haiphamcse commited on
Commit
3e31ef1
·
verified ·
1 Parent(s): df4962b

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. SAEDashboard/.dockerignore +184 -0
  2. SAEDashboard/.flake8 +8 -0
  3. SAEDashboard/.github/workflows/ci.yaml +117 -0
  4. SAEDashboard/.gitignore +204 -0
  5. SAEDashboard/.vscode/settings.json +18 -0
  6. SAEDashboard/CHANGELOG.md +1263 -0
  7. SAEDashboard/Dockerfile +45 -0
  8. SAEDashboard/LICENSE +21 -0
  9. SAEDashboard/Makefile +27 -0
  10. SAEDashboard/README.md +221 -0
  11. SAEDashboard/docker/docker-entrypoint.sh +11 -0
  12. SAEDashboard/docker/docker-hub.yaml +57 -0
  13. SAEDashboard/neuronpedia_vector_pipeline_demo.ipynb +282 -0
  14. SAEDashboard/notebooks/experiment_gemma_2_9b_dashboard_generation_np.py +52 -0
  15. SAEDashboard/notebooks/sae_dashboard_demo_gemma_2_9b.ipynb +618 -0
  16. SAEDashboard/pyproject.toml +70 -0
  17. SAEDashboard/sae_dashboard/__init__.py +10 -0
  18. SAEDashboard/sae_dashboard/__pycache__/__init__.cpython-313.pyc +0 -0
  19. SAEDashboard/sae_dashboard/__pycache__/components.cpython-313.pyc +0 -0
  20. SAEDashboard/sae_dashboard/__pycache__/components_config.cpython-313.pyc +0 -0
  21. SAEDashboard/sae_dashboard/__pycache__/data_parsing_fns.cpython-313.pyc +0 -0
  22. SAEDashboard/sae_dashboard/__pycache__/data_writing_fns.cpython-313.pyc +0 -0
  23. SAEDashboard/sae_dashboard/__pycache__/dfa_calculator.cpython-313.pyc +0 -0
  24. SAEDashboard/sae_dashboard/__pycache__/feature_data.cpython-313.pyc +0 -0
  25. SAEDashboard/sae_dashboard/__pycache__/feature_data_generator.cpython-313.pyc +0 -0
  26. SAEDashboard/sae_dashboard/__pycache__/html_fns.cpython-313.pyc +0 -0
  27. SAEDashboard/sae_dashboard/__pycache__/layout.cpython-313.pyc +0 -0
  28. SAEDashboard/sae_dashboard/__pycache__/sae_vis_data.cpython-313.pyc +0 -0
  29. SAEDashboard/sae_dashboard/__pycache__/sae_vis_runner.cpython-313.pyc +0 -0
  30. SAEDashboard/sae_dashboard/__pycache__/sequence_data_generator.cpython-313.pyc +0 -0
  31. SAEDashboard/sae_dashboard/__pycache__/transformer_lens_wrapper.cpython-313.pyc +0 -0
  32. SAEDashboard/sae_dashboard/__pycache__/utils_fns.cpython-313.pyc +0 -0
  33. SAEDashboard/sae_dashboard/__pycache__/vector_vis_data.cpython-313.pyc +0 -0
  34. SAEDashboard/sae_dashboard/clt_layer_wrapper.py +697 -0
  35. SAEDashboard/sae_dashboard/components.py +774 -0
  36. SAEDashboard/sae_dashboard/components_config.py +206 -0
  37. SAEDashboard/sae_dashboard/css/dropdown.css +40 -0
  38. SAEDashboard/sae_dashboard/css/general.css +53 -0
  39. SAEDashboard/sae_dashboard/css/sequences.css +61 -0
  40. SAEDashboard/sae_dashboard/css/tables.css +81 -0
  41. SAEDashboard/sae_dashboard/data_parsing_fns.py +412 -0
  42. SAEDashboard/sae_dashboard/data_writing_fns.py +210 -0
  43. SAEDashboard/sae_dashboard/dfa_calculator.py +159 -0
  44. SAEDashboard/sae_dashboard/feature_data.py +211 -0
  45. SAEDashboard/sae_dashboard/feature_data_generator.py +313 -0
  46. SAEDashboard/sae_dashboard/html/acts_histogram_template.html +2 -0
  47. SAEDashboard/sae_dashboard/html/feature_tables_template.html +2 -0
  48. SAEDashboard/sae_dashboard/html/logits_histogram_template.html +2 -0
  49. SAEDashboard/sae_dashboard/html/logits_table_template.html +2 -0
  50. SAEDashboard/sae_dashboard/html/sequences_group_template.html +2 -0
SAEDashboard/.dockerignore ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # ruff
147
+ .ruff_cache
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+
165
+ *.pkl
166
+ *.pt
167
+ sae_vis/archive_fns.py
168
+ *__pycache__
169
+ mats_sae_training
170
+ callum_instructions.md
171
+ april-fools
172
+ *large.html
173
+ requirements.txt
174
+ tests/fixtures/cache_benchmark/
175
+ tests/fixtures/cache_unit/
176
+
177
+ neuronpedia_outputs/
178
+ cached_activations/
179
+ wandb/
180
+
181
+
182
+ **.safetensors
183
+ **flamegraph.html
184
+ artifacts/
SAEDashboard/.flake8 ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ [flake8]
2
+ extend-ignore = E203, E266, E501, W503, E721, F722, E731, E402, F821
3
+ max-line-length = 79
4
+ max-complexity = 25
5
+ extend-select = E9, F63, F7, F82
6
+ show-source = true
7
+ statistics = true
8
+ exclude = ./wandb/*, ./research/wandb/*, .venv/*
SAEDashboard/.github/workflows/ci.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: "ci"
2
+
3
+ on:
4
+ pull_request:
5
+ branches: ["**"]
6
+ push:
7
+ branches: ["**"]
8
+
9
+ jobs:
10
+ build:
11
+ runs-on: ubuntu-latest
12
+ strategy:
13
+ matrix:
14
+ python-version: ["3.10", "3.11", "3.12"]
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+
18
+ - name: Set up Python ${{ matrix.python-version }}
19
+ uses: actions/setup-python@v5
20
+ with:
21
+ python-version: ${{ matrix.python-version }}
22
+
23
+ - name: Cache Huggingface assets
24
+ uses: actions/cache@v4
25
+ with:
26
+ key: huggingface-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
27
+ path: ~/.cache/huggingface
28
+ restore-keys: |
29
+ huggingface-${{ runner.os }}-${{ matrix.python-version }}-
30
+
31
+ - name: Load cached Poetry installation
32
+ id: cached-poetry
33
+ uses: actions/cache@v4
34
+ with:
35
+ path: ~/.local
36
+ key: poetry-${{ runner.os }}-${{ matrix.python-version }}-1 # Incremented to reset cache
37
+
38
+ - name: Install Poetry
39
+ if: steps.cached-poetry.outputs.cache-hit != 'true'
40
+ uses: snok/install-poetry@v1
41
+ with:
42
+ version: 1.5.1 # Specify a version explicitly
43
+ virtualenvs-create: true
44
+ virtualenvs-in-project: true
45
+ installer-parallel: true
46
+
47
+ - name: Check Poetry Version
48
+ run: poetry --version
49
+
50
+ - name: Load cached venv
51
+ id: cached-poetry-dependencies
52
+ uses: actions/cache@v4
53
+ with:
54
+ path: .venv
55
+ key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}-1 # Incremented to reset cache
56
+ restore-keys: |
57
+ venv-${{ runner.os }}-${{ matrix.python-version }}-
58
+
59
+ - name: Install dependencies
60
+ if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
61
+ run: poetry install --no-interaction
62
+
63
+ - name: List installed packages
64
+ run: poetry run pip list
65
+
66
+ - name: Check flake8 installation
67
+ run: poetry run which flake8
68
+
69
+ - name: check linting
70
+ run: poetry run flake8 .
71
+
72
+ - name: check formatting
73
+ run: poetry run black --check .
74
+
75
+ - name: check types
76
+ run: poetry run pyright .
77
+
78
+ - name: test
79
+ run: poetry run pytest --cov=sae_dashboard --cov-report=term-missing tests/unit
80
+
81
+ - name: build
82
+ run: poetry build
83
+
84
+ release:
85
+ needs: build
86
+ permissions:
87
+ contents: write
88
+ id-token: write
89
+ if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):')
90
+ runs-on: ubuntu-latest
91
+ concurrency: release
92
+ environment:
93
+ name: pypi
94
+ steps:
95
+ - uses: actions/checkout@v4
96
+ with:
97
+ fetch-depth: 0
98
+
99
+ - uses: actions/setup-python@v5
100
+ with:
101
+ python-version: "3.11"
102
+
103
+ - name: Semantic Release
104
+ id: release
105
+ uses: python-semantic-release/python-semantic-release@v9.8.8
106
+ with:
107
+ github_token: ${{ secrets.GITHUB_TOKEN }}
108
+
109
+ - name: Publish package distributions to PyPI
110
+ uses: pypa/gh-action-pypi-publish@release/v1
111
+ if: steps.release.outputs.released == 'true'
112
+
113
+ - name: Publish package distributions to GitHub Releases
114
+ uses: python-semantic-release/upload-to-gh-release@main
115
+ if: steps.release.outputs.released == 'true'
116
+ with:
117
+ github_token: ${{ secrets.GITHUB_TOKEN }}
SAEDashboard/.gitignore ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # ruff
147
+ .ruff_cache
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+
165
+ *.pkl
166
+ *.pt
167
+ sae_vis/archive_fns.py
168
+ *__pycache__
169
+ mats_sae_training
170
+ callum_instructions.md
171
+ april-fools
172
+ *large.html
173
+ requirements.txt
174
+ tests/fixtures/cache_benchmark/
175
+ tests/fixtures/cache_unit/
176
+
177
+ neuronpedia_outputs/
178
+ cached_activations/
179
+ wandb/
180
+ demo_activations_cache/
181
+ test_activations_cache/
182
+ demo_feature_dashboards.html
183
+
184
+ **.safetensors
185
+ **flamegraph.html
186
+ artifacts/
187
+ prof/
188
+
189
+ .vscode/settings.json
190
+ dfa_tests.ipynb
191
+
192
+ .DS_Store
193
+
194
+ # Test and temporary directories
195
+ crosslayer-coding/
196
+ SAELens/
197
+ clt_test*/
198
+ test_output/
199
+ test_outputs/
200
+ clt-technical-description.md
201
+
202
+ ignore_data/
203
+
204
+ outputs/
SAEDashboard/.vscode/settings.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "python.testing.pytestArgs": [
3
+ "tests"
4
+ ],
5
+ "python.testing.unittestEnabled": false,
6
+ "python.testing.pytestEnabled": true,
7
+
8
+ "[python]": {
9
+ "editor.defaultFormatter": "ms-python.black-formatter",
10
+ "editor.formatOnSave": true,
11
+ "editor.codeActionsOnSave": {
12
+ "source.organizeImports": "explicit"
13
+ }
14
+ },
15
+ "isort.args": ["--profile", "black"],
16
+ "editor.defaultFormatter": "mikoz.black-py",
17
+ "liveServer.settings.port": 5501
18
+ }
SAEDashboard/CHANGELOG.md ADDED
@@ -0,0 +1,1263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CHANGELOG
2
+
3
+ ## v0.7.3 (2025-10-11)
4
+
5
+ ### Fix
6
+
7
+ * fix: broken dependencies ([`25ce6e8`](https://github.com/jbloomAus/SAEDashboard/commit/25ce6e8ae2debe232b0eff4ba910fa10fc816480))
8
+
9
+ ### Unknown
10
+
11
+ * Merge pull request #70 from jbloomAus/fix_deps
12
+
13
+ fix: broken dependencies ([`352b9b2`](https://github.com/jbloomAus/SAEDashboard/commit/352b9b2148b62c8fabb7adcf6bf0cbacfa345a74))
14
+
15
+ * update .gitignore ([`6133ca6`](https://github.com/jbloomAus/SAEDashboard/commit/6133ca67b39bf6e033e8b4792f5eba9850668821))
16
+
17
+ ## v0.7.2 (2025-09-01)
18
+
19
+ ### Fix
20
+
21
+ * fix: use clt pypi library ([`a3197b8`](https://github.com/jbloomAus/SAEDashboard/commit/a3197b8870d43107bba0356af11c64ddc054392d))
22
+
23
+ ## v0.7.1 (2025-08-31)
24
+
25
+ ### Fix
26
+
27
+ * fix: force build ([`f432ba2`](https://github.com/jbloomAus/SAEDashboard/commit/f432ba2d14edd4c11ac71e947fbb1e97790e753e))
28
+
29
+ ## v0.7.0 (2025-08-31)
30
+
31
+ ### Feature
32
+
33
+ * feat: Merge pull request #69 from jbloomAus/qwen-transcoder
34
+
35
+ Transcoder Support + SAELens v6 ([`8f8651e`](https://github.com/jbloomAus/SAEDashboard/commit/8f8651edaf8c20bc8eaff09e05de238a4ce780fb))
36
+
37
+ ### Fix
38
+
39
+ * fix: relax saelens to not break saelens demo project ([`e38d140`](https://github.com/jbloomAus/SAEDashboard/commit/e38d1408f30f47f9e91646fa8646300750b23fd3))
40
+
41
+ ### Unknown
42
+
43
+ * Merge branch 'main' into qwen-transcoder ([`ae9dede`](https://github.com/jbloomAus/SAEDashboard/commit/ae9dedead671996ebf4f7de84eb4252b2af85fc2))
44
+
45
+ * Merge pull request #67 from jbloomAus/relax_saelens
46
+
47
+ fix: relax saelens to not break saelens demo project ([`fa1691a`](https://github.com/jbloomAus/SAEDashboard/commit/fa1691ab224e684618b2e800b8fda8af741eb81b))
48
+
49
+ ## v0.6.11 (2025-08-05)
50
+
51
+ ### Fix
52
+
53
+ * fix: fixes tool.semantic_release subtable (#66) ([`eb36157`](https://github.com/jbloomAus/SAEDashboard/commit/eb361571550a4653f7fbcc5a9cc2c98c329aaf41))
54
+
55
+ * fix: fixes tool.semantic_release subtable (#66) ([`725d76d`](https://github.com/jbloomAus/SAEDashboard/commit/725d76d9ac00e3b295c6b11f4657f4432c925e9e))
56
+
57
+ ### Unknown
58
+
59
+ * upgrades python-semantic-release (#65) ([`fcdae8b`](https://github.com/jbloomAus/SAEDashboard/commit/fcdae8b0e18bbf2c5184977bdf14cc14280fb6bc))
60
+
61
+ * fix CI ([`d7dffdd`](https://github.com/jbloomAus/SAEDashboard/commit/d7dffdd79b0165cbdfb360f13622385f884a2158))
62
+
63
+ * upgrades python-semantic-release (#65) ([`721d683`](https://github.com/jbloomAus/SAEDashboard/commit/721d683b437f378f22c3e713cb4e3f16bdc82e1a))
64
+
65
+ * converter ([`30ff988`](https://github.com/jbloomAus/SAEDashboard/commit/30ff9881f4ea1adc3305675088f5fb808367d5eb))
66
+
67
+ * bos override ([`bd04133`](https://github.com/jbloomAus/SAEDashboard/commit/bd04133033ce5b422f62fc12099eb6527d6f8070))
68
+
69
+ * add prefix tokens to cli ([`4560dd7`](https://github.com/jbloomAus/SAEDashboard/commit/4560dd7301a713b3b45d836df3e43bf25ba52c64))
70
+
71
+ * add prefix tokens to cli ([`50636fb`](https://github.com/jbloomAus/SAEDashboard/commit/50636fb667384fbf1df8d7809cfe3c5ebc44beab))
72
+
73
+ * top acts group 20 ([`1e3d3d4`](https://github.com/jbloomAus/SAEDashboard/commit/1e3d3d4228b5f526f9a2e1d90cb51ff74d8e60b8))
74
+
75
+ * temp updates for qwen transcoder ([`5727ac9`](https://github.com/jbloomAus/SAEDashboard/commit/5727ac944feed9eb78f60f6553feb40c7b6622d8))
76
+
77
+ * some config fixes ([`51d903a`](https://github.com/jbloomAus/SAEDashboard/commit/51d903a98ebc2402f6d4863f1efc1062420ea3eb))
78
+
79
+ * olved double normalization ([`5293073`](https://github.com/jbloomAus/SAEDashboard/commit/5293073adef90efc0804c61d9dbe3ab55430f62b))
80
+
81
+ * updated readme ([`6ea06f2`](https://github.com/jbloomAus/SAEDashboard/commit/6ea06f2371e80191d8a29cca0c5e134943db02d0))
82
+
83
+ * formatting changes ([`ad7422e`](https://github.com/jbloomAus/SAEDashboard/commit/ad7422e11a5525d8584991e278a3342ebf4ff892))
84
+
85
+ * Update CLT test script parameters ([`9961859`](https://github.com/jbloomAus/SAEDashboard/commit/996185947217ccd13d6763462d764cbd2977a28e))
86
+
87
+ * Merge pull request #64 from jbloomAus/clt-support
88
+
89
+ CLT Support ([`57971a9`](https://github.com/jbloomAus/SAEDashboard/commit/57971a9e85b62cc2a9e7bf04ee8d8c26fac9cecc))
90
+
91
+ * Add Cross-Layer Transcoder (CLT) support to SAEDashboard
92
+
93
+ - Add CLTLayerWrapper to provide SAE-compatible interface for CLTs
94
+ - Integrate CLT loading into NeuronpediaRunner with --use-clt flag
95
+ - Add CLT-specific configuration parameters (clt_layer_idx, clt_weights_filename)
96
+ - Support JumpReLU activation with learned thresholds
97
+ - Add normalization statistics loading from norm_stats.json
98
+ - Handle CLT-specific hook naming conventions (tl_input_template)
99
+ - Add comprehensive unit tests for CLT functionality
100
+ - Fix existing unit tests to use StandardSAE/StandardSAEConfig
101
+
102
+ 🤖 Generated with [Claude Code](https://claude.ai/code)
103
+
104
+ Co-Authored-By: Claude <noreply@anthropic.com> ([`fab4c6c`](https://github.com/jbloomAus/SAEDashboard/commit/fab4c6cc7a026bbf9beed229734a34323f22c158))
105
+
106
+ * Add CLT (Cross-Layer Transcoder) support
107
+
108
+ - Add CLTLayerWrapper to wrap CLT models for SAE-compatible interface
109
+ - Add CLT loading logic in neuronpedia_runner with local file support
110
+ - Add conditional logic to skip fold_W_dec_norm() for CLT wrappers
111
+ - Add conditional logic to skip hook_z_reshaping_mode for CLT wrappers
112
+ - Add support for additional hook types (hook_mlp_out, hook_attn_out, etc.)
113
+ - Add CLI arguments for CLT configuration (--use-clt, --clt-layer-idx, etc.)
114
+ - Ensure set_use_hook_mlp_in is called for CLT models ([`0c58760`](https://github.com/jbloomAus/SAEDashboard/commit/0c587608d2da05aa5ed3656afc0ca6b001fbf79b))
115
+
116
+ * script for CLT dashboard generation ([`31e7154`](https://github.com/jbloomAus/SAEDashboard/commit/31e7154363af98af2d6ec269ee9a78f9723048bb))
117
+
118
+ * formatting ([`210fdc4`](https://github.com/jbloomAus/SAEDashboard/commit/210fdc4bd3d8e2dc957c2925719a7a7f7bca1de3))
119
+
120
+ * simplified init function ([`883deb4`](https://github.com/jbloomAus/SAEDashboard/commit/883deb46d75df68af96c2366f59a41dd4a6db964))
121
+
122
+ * formatted tests ([`676b0f1`](https://github.com/jbloomAus/SAEDashboard/commit/676b0f17db4ad919c951414e29c1eb224efa8cef))
123
+
124
+ * added tests and formatting ([`70cbad3`](https://github.com/jbloomAus/SAEDashboard/commit/70cbad3074d2c0a86929921d78d20954e1caf6d8))
125
+
126
+ * Fix compatibility with new SAELens API structure
127
+
128
+ - Fix hook_layer extraction from hook_name when not in config
129
+ - Remove deprecated unpacking of SAE.from_pretrained() return values
130
+ - Handle prepend_bos in both config and metadata locations
131
+ - Add support for extracting layer number from hook_name pattern
132
+
133
+ Co-Authored-By: Claude <noreply@anthropic.com> ([`9018991`](https://github.com/jbloomAus/SAEDashboard/commit/9018991a40b62b7f4ece4f91aa8f07b9f2119d9f))
134
+
135
+ * Fix indentation and hook_name access for transcoders
136
+
137
+ - Fix indentation errors in neuronpedia_runner.py
138
+ - Fix hook_name access - it's always in metadata for both SAEs and transcoders
139
+ - Add test scripts for transcoder functionality
140
+ - Successfully tested transcoder dashboard generation
141
+
142
+ Co-Authored-By: Claude <noreply@anthropic.com> ([`31368e4`](https://github.com/jbloomAus/SAEDashboard/commit/31368e4387757063df87e3b04bd512e13a8cc7d5))
143
+
144
+ * Update .gitignore to exclude test directories and submodules ([`5bf67c5`](https://github.com/jbloomAus/SAEDashboard/commit/5bf67c56e242f496b849957675dd610863485aeb))
145
+
146
+ * Add transcoder support to SAEDashboard
147
+
148
+ - Update imports from sae_lens to use new API structure
149
+ - Add support for loading Transcoder and SkipTranscoder
150
+ - Handle differences between SAE and Transcoder configs
151
+ - Add support for normalized hooks in transformer_lens_wrapper
152
+ - Fix architecture handling in FeatureMaskingContext
153
+ - Update ActivationsStore.from_sae() to include dataset parameter
154
+
155
+ Co-Authored-By: Claude <noreply@anthropic.com> ([`02f78a0`](https://github.com/jbloomAus/SAEDashboard/commit/02f78a0b2a6d08729e60caf06649fca4dfc38ec7))
156
+
157
+ ## v0.6.10 (2025-07-16)
158
+
159
+ ### Fix
160
+
161
+ * fix: relax SAELens requirement ([`a83147e`](https://github.com/jbloomAus/SAEDashboard/commit/a83147efbf30ef4c4380f306a03468a0c8d41be0))
162
+
163
+ * fix: Merge pull request #45 from Hzfinfdu/main
164
+
165
+ fix: reading model_from_pretrained_kwargs from SAELens config with th… ([`0a509fe`](https://github.com/jbloomAus/SAEDashboard/commit/0a509fede04737b8087140fb4fe5f7addc259806))
166
+
167
+ * fix: reading model_from_pretrained_kwargs from SAELens config with the correct key ([`9938812`](https://github.com/jbloomAus/SAEDashboard/commit/9938812ad209764ceb021eedffb08c0fc5a31c89))
168
+
169
+ ### Unknown
170
+
171
+ * Merge pull request #60 from jbloomAus/fix-unit-tests
172
+
173
+ fixes unit tests ([`1a3975d`](https://github.com/jbloomAus/SAEDashboard/commit/1a3975df60198dc169dfcf3354a8e8da5383029f))
174
+
175
+ * fixes unit tests ([`2a35d5d`](https://github.com/jbloomAus/SAEDashboard/commit/2a35d5dd4ad08e3d3532158030d86dbef93ad309))
176
+
177
+ * dedupes get_tokens() (#55)
178
+
179
+ * dedupes get_tokens()
180
+
181
+ * adds newline ([`faeb6f1`](https://github.com/jbloomAus/SAEDashboard/commit/faeb6f119d35a275a304d39c6e8cc9c7c40d31ce))
182
+
183
+ * fixes make commands (#57) ([`cb74411`](https://github.com/jbloomAus/SAEDashboard/commit/cb74411039d0c9d0b0883c85434f27601cf940a5))
184
+
185
+ * deletes print statements in tests (#56) ([`026ba30`](https://github.com/jbloomAus/SAEDashboard/commit/026ba305f4e31f5b47d4f9ada04c7cb0c3aae7f0))
186
+
187
+ * deletes unused direct_effect_feature_ablation_experiment() (#52) ([`391ff94`](https://github.com/jbloomAus/SAEDashboard/commit/391ff949a997a99b605bd55a706d6fed2892249c))
188
+
189
+ * removes unused files (#54) ([`5381cc7`](https://github.com/jbloomAus/SAEDashboard/commit/5381cc7118c7655c6c14cdbbd12e1f6c00278fc2))
190
+
191
+ * Merge pull request #47 from Marlon154/main
192
+
193
+ Fixing deprecated fn call for SAE Lens ([`61c9bd4`](https://github.com/jbloomAus/SAEDashboard/commit/61c9bd4ad8ccd5d96cb5c89eb961db0e7fbc2ab0))
194
+
195
+ * Merge branch 'main' into main ([`50b202a`](https://github.com/jbloomAus/SAEDashboard/commit/50b202a0b2fef413dd46b4ce2838bae27c0ac252))
196
+
197
+ * Merge pull request #35 from chanind/relax-saelens-dep
198
+
199
+ fix: relax SAELens and einops requirements ([`6c71bbf`](https://github.com/jbloomAus/SAEDashboard/commit/6c71bbfd7b6f1562093f1192616a7a55188631d3))
200
+
201
+ * fixing type checking ([`42a9845`](https://github.com/jbloomAus/SAEDashboard/commit/42a9845bba856c942f6d70182429cdb49e0ea917))
202
+
203
+ * Merge branch 'main' into relax-saelens-dep ([`3e6c870`](https://github.com/jbloomAus/SAEDashboard/commit/3e6c8703afd5ce80c29ec1ed0fc729def3f7f8fa))
204
+
205
+ * also relax einops ([`62614ac`](https://github.com/jbloomAus/SAEDashboard/commit/62614ac27ca50527556cc7c891e589e63a14e9bc))
206
+
207
+ * fix type checks ([`5a2cca0`](https://github.com/jbloomAus/SAEDashboard/commit/5a2cca0334a0907e7685cbef798cda71cd249ba4))
208
+
209
+ * Fixing deprecated fn call ([`f1da0e6`](https://github.com/jbloomAus/SAEDashboard/commit/f1da0e6ea7d663e5ff54612d7979d1b1ed9a6b77))
210
+
211
+ ## v0.6.9 (2025-02-25)
212
+
213
+ ### Fix
214
+
215
+ * fix: Merge pull request #44 from jbloomAus/update_saelens
216
+
217
+ fix: don't use sparsity ([`f30a19b`](https://github.com/jbloomAus/SAEDashboard/commit/f30a19b9cb42f15302848c31ddf1d14462209a42))
218
+
219
+ * fix: don't use sparsity ([`d5ba79b`](https://github.com/jbloomAus/SAEDashboard/commit/d5ba79bbf3d51cbc67e276297d18d85add9d33e7))
220
+
221
+ * fix: update SAELens version and remove unsupported load_sparsity ([`63192ba`](https://github.com/jbloomAus/SAEDashboard/commit/63192ba7d9de7afae9cb67f65db2e79a39b898c6))
222
+
223
+ ### Unknown
224
+
225
+ * Merge pull request #43 from jbloomAus/update_saelens
226
+
227
+ fix: update SAELens version and remove unsupported load_sparsity ([`c083723`](https://github.com/jbloomAus/SAEDashboard/commit/c083723237090165725e587f8bdb8f01338394b4))
228
+
229
+ ## v0.6.8 (2025-02-15)
230
+
231
+ ### Fix
232
+
233
+ * fix: prepended chat template text should not be in activations ([`f3c20ee`](https://github.com/jbloomAus/SAEDashboard/commit/f3c20eec31976db48c1f1d37aabd077e068f66ac))
234
+
235
+ ### Unknown
236
+
237
+ * Merge pull request #42 from jbloomAus/prepend_text_fix
238
+
239
+ fix: prepended chat template text should not be in activations ([`eea0b83`](https://github.com/jbloomAus/SAEDashboard/commit/eea0b830e97e791571986cf1ccae1605606ddb4f))
240
+
241
+ ## v0.6.7 (2025-02-13)
242
+
243
+ ### Fix
244
+
245
+ * fix: force build ([`9b96ac5`](https://github.com/jbloomAus/SAEDashboard/commit/9b96ac57e0b23c1a4cc73fbd9fd855ab0961cce7))
246
+
247
+ ### Unknown
248
+
249
+ * Merge pull request #41 from jbloomAus/prepend_chat_template
250
+
251
+ feat: Prepend chat template and activation threshold ([`c7347fa`](https://github.com/jbloomAus/SAEDashboard/commit/c7347faa7d1c800dae398ee8dbded53afced9aa4))
252
+
253
+ * add example ([`9ab42d6`](https://github.com/jbloomAus/SAEDashboard/commit/9ab42d66a2a3e445b300bf83f677f143da3fecd9))
254
+
255
+ * proper 'activation threshold' ([`a6d7c1c`](https://github.com/jbloomAus/SAEDashboard/commit/a6d7c1c8ca4a1ec46676298c4661f623bada9049))
256
+
257
+ * prepend chat template text ([`c8829a1`](https://github.com/jbloomAus/SAEDashboard/commit/c8829a14d6351b5cfd52af927942af5c6897db60))
258
+
259
+ ## v0.6.6 (2025-02-11)
260
+
261
+ ### Fix
262
+
263
+ * fix: run_settings.json should properly log model_id and layer ([`2e661d9`](https://github.com/jbloomAus/SAEDashboard/commit/2e661d95f30bc28e7d818bbd67de931a334d837f))
264
+
265
+ ### Unknown
266
+
267
+ * Merge pull request #40 from jbloomAus/run_settings_fix
268
+
269
+ fix: run_settings.json should properly log model_id and layer ([`f3bde39`](https://github.com/jbloomAus/SAEDashboard/commit/f3bde395843720674d4c60e21bc2453d958ff402))
270
+
271
+ ## v0.6.5 (2025-02-11)
272
+
273
+ ### Fix
274
+
275
+ * fix: Force Build ([`2e4979c`](https://github.com/jbloomAus/SAEDashboard/commit/2e4979c07ad8bcd2760ee0981ee415d17fef2e5a))
276
+
277
+ ### Unknown
278
+
279
+ * Merge pull request #39 from jbloomAus/allow_vector_output
280
+
281
+ feat: allow outputting raw vector in neuronpedia outputs ([`1444786`](https://github.com/jbloomAus/SAEDashboard/commit/14447862418b18d112053a7af8810b049400089a))
282
+
283
+ * remove debug log ([`6efeb6c`](https://github.com/jbloomAus/SAEDashboard/commit/6efeb6c1976dea8374800e7625d63a14a3b6438d))
284
+
285
+ * allow outputting vector ([`4c6cb35`](https://github.com/jbloomAus/SAEDashboard/commit/4c6cb35752317db9f22476d26dc1bab7e4d6e511))
286
+
287
+ * Merge pull request #37 from jbloomAus/feature/vector-dashboards
288
+
289
+ Feature/vector dashboards ([`64c44a9`](https://github.com/jbloomAus/SAEDashboard/commit/64c44a9c11b2dce26b030d5e7bbf782ef90a2985))
290
+
291
+ * typing ([`09aeeab`](https://github.com/jbloomAus/SAEDashboard/commit/09aeeabb4c0f45f4bdabb884f683208eb7073142))
292
+
293
+ * Fixed missing parameter ([`a91d9f5`](https://github.com/jbloomAus/SAEDashboard/commit/a91d9f5dc8c65249c032dc4088aead4364bc42e9))
294
+
295
+ * Fixed parameterization and formatting ([`1956fbc`](https://github.com/jbloomAus/SAEDashboard/commit/1956fbc0ab6e4200d771ad4b504946ff81707969))
296
+
297
+ * Renamed demo notebook, some cleanup ([`6a486a5`](https://github.com/jbloomAus/SAEDashboard/commit/6a486a5834377823f380f59cafcbc3debbdcc3ed))
298
+
299
+ * Working pipeline flow ([`fdb2292`](https://github.com/jbloomAus/SAEDashboard/commit/fdb2292ad84b083246fdf3be2820e0b31168dce2))
300
+
301
+ * First draft of vector vis pipeline ([`4351ef9`](https://github.com/jbloomAus/SAEDashboard/commit/4351ef938a82d2b5e5a37391c236791cc23b41e5))
302
+
303
+ * Merge pull request #38 from jbloomAus/feature/hf-model-override
304
+
305
+ enable passing custom HF model to replace model weights ([`5d98417`](https://github.com/jbloomAus/SAEDashboard/commit/5d98417877c2cfe52bb09ddada0b4b53849b344a))
306
+
307
+ * enable passing custom HF model to replace model weights ([`b2d6ae5`](https://github.com/jbloomAus/SAEDashboard/commit/b2d6ae5446fb79f4662bcdac6030cb6072b09b60))
308
+
309
+ * Don't copy to output folder by default ([`4dbde12`](https://github.com/jbloomAus/SAEDashboard/commit/4dbde1214d49eaaf9b591f083f34e57c8c0c1dbd))
310
+
311
+ * Don't save html file for NP outputs ([`a160bff`](https://github.com/jbloomAus/SAEDashboard/commit/a160bff204b7464d2de00e3f80c255123d11171b))
312
+
313
+ ## v0.6.4 (2024-10-24)
314
+
315
+ ### Fix
316
+
317
+ * fix: Merge pull request #33 from jbloomAus/fix/topk-selection-purview
318
+
319
+ Fix/topk selection purview ([`afccd5a`](https://github.com/jbloomAus/SAEDashboard/commit/afccd5aaa00d00672eb1270b258b69f0e51c046a))
320
+
321
+ ### Unknown
322
+
323
+ * updated formatting/typing ([`fb141ae`](https://github.com/jbloomAus/SAEDashboard/commit/fb141ae991261408d296286bf6777b2ec5f1f319))
324
+
325
+ * TopK will now select from all latents regardless of feature batch size ([`c1f0e14`](https://github.com/jbloomAus/SAEDashboard/commit/c1f0e14dda7aa3364bfd78ca2b8c04c95b2d14b3))
326
+
327
+ * Update README.md ([`8235a9e`](https://github.com/jbloomAus/SAEDashboard/commit/8235a9e3adaea50b6b9f26f575e25a254d67a135))
328
+
329
+ * Merge pull request #32 from jbloomAus/docs/readme-update
330
+
331
+ docs: updated readme ([`b5e5480`](https://github.com/jbloomAus/SAEDashboard/commit/b5e54808ee05fc75e68d74ec319bf49826b45508))
332
+
333
+ * Update README.md ([`a1546fd`](https://github.com/jbloomAus/SAEDashboard/commit/a1546fdef32745cdc862a5a2dd0478e57e45320d))
334
+
335
+ * Removed outdated vis type ([`b0676af`](https://github.com/jbloomAus/SAEDashboard/commit/b0676afcca0845b73a54d983eaa9d72b0e9dff05))
336
+
337
+ * Update README.md ([`9b8446a`](https://github.com/jbloomAus/SAEDashboard/commit/9b8446aa47f287ba80bf0ac4a39f7c77f0492990))
338
+
339
+ * Updated format ([`90e4a09`](https://github.com/jbloomAus/SAEDashboard/commit/90e4a09eedd7f428b64e58d5ca2fd1cfa658b0da))
340
+
341
+ * Updated readme ([`f6819a6`](https://github.com/jbloomAus/SAEDashboard/commit/f6819a6da594673cad65c9ccd3a4f67746de796d))
342
+
343
+ ## v0.6.3 (2024-10-23)
344
+
345
+ ### Fix
346
+
347
+ * fix: update cached_activations directory to include number of prompts ([`0308cb1`](https://github.com/jbloomAus/SAEDashboard/commit/0308cb146bf2eb9cee26f03d3098511d03022485))
348
+
349
+ ## v0.6.2 (2024-10-23)
350
+
351
+ ### Fix
352
+
353
+ * fix: lint ([`3fc0e2c`](https://github.com/jbloomAus/SAEDashboard/commit/3fc0e2ccb39ed1d3e31d66ae0aba2b2b367d46aa))
354
+
355
+ ### Unknown
356
+
357
+ * Merge branch 'main' of https://github.com/jbloomAus/SAEDashboard ([`8f74a96`](https://github.com/jbloomAus/SAEDashboard/commit/8f74a969f48a7e0fd8de17cc983acf3886db95ef))
358
+
359
+ ## v0.6.1 (2024-10-22)
360
+
361
+ ### Unknown
362
+
363
+ * Fix: divide by zero, cached_activations folder name ([`1792298`](https://github.com/jbloomAus/SAEDashboard/commit/179229805ae6489d86e235240c65d26db64b5cd7))
364
+
365
+ * Merge branch 'main' of https://github.com/jbloomAus/SAEDashboard ([`508a74d`](https://github.com/jbloomAus/SAEDashboard/commit/508a74df8ff279716501e4179c501b5089a8d706))
366
+
367
+ ## v0.6.0 (2024-10-21)
368
+
369
+ ### Feature
370
+
371
+ * feat: np sae id suffix ([`448b14e`](https://github.com/jbloomAus/SAEDashboard/commit/448b14e0b3aea8ff854a5365f164b6ce5f419f0d))
372
+
373
+ ### Fix
374
+
375
+ * fix: update saelens to v4 ([`ef1a330`](https://github.com/jbloomAus/SAEDashboard/commit/ef1a3302d0483eddb247defab5c88816850f7f63))
376
+
377
+ ### Unknown
378
+
379
+ * Merge pull request #31 from jbloomAus/fix/reduce-mem
380
+
381
+ fix: added mem cleanup ([`60bd716`](https://github.com/jbloomAus/SAEDashboard/commit/60bd716c7b52bb0eaea0937e097eb77ed78bd33d))
382
+
383
+ * Fixed formatting ([`f1fab0c`](https://github.com/jbloomAus/SAEDashboard/commit/f1fab0c1fd5be281e2162ab3f54ffc7f4c09a1ce))
384
+
385
+ * Added cleanup ([`305c46d`](https://github.com/jbloomAus/SAEDashboard/commit/305c46d7a30330bbae6893b83cb6d498c2c975f1))
386
+
387
+ * Merge pull request #30 from jbloomAus/feat-mask-via-position
388
+
389
+ feat: prepending/appending tokens for prompt template + feat mask via Position ([`4c60e4c`](https://github.com/jbloomAus/SAEDashboard/commit/4c60e4c834dfb5759ce55dc90d1f88768abfea0d))
390
+
391
+ * add a few tests ([`96247d5`](https://github.com/jbloomAus/SAEDashboard/commit/96247d5afaf141b8b1279c17fd135240b0d8e869))
392
+
393
+ * handle prefixes / suffixes and ignored positions ([`bff7fd9`](https://github.com/jbloomAus/SAEDashboard/commit/bff7fd98b09318a1b01d2bc4a06467f8afa156f9))
394
+
395
+ * simplify masking ([`385b6e1`](https://github.com/jbloomAus/SAEDashboard/commit/385b6e116ecac53ad4df8585f7513c3416707d8b))
396
+
397
+ * add option for ignoring tokens at particular positions ([`ed3426d`](https://github.com/jbloomAus/SAEDashboard/commit/ed3426de5cb1495c138f770eefa5f941408aa390))
398
+
399
+ * Merge pull request #29 from jbloomAus/refactor/optimize-dfa-speed
400
+
401
+ Sped up DFA calculation 60x ([`f992e3c`](https://github.com/jbloomAus/SAEDashboard/commit/f992e3cf116189625b3a92529cf68d6226a1221c))
402
+
403
+ * Sped up DFA calculation ([`be11cd5`](https://github.com/jbloomAus/SAEDashboard/commit/be11cd5652f0f8a8ae425555666b747b9b99314e))
404
+
405
+ * Added test to check for decoder weight dist (head dist) ([`f147696`](https://github.com/jbloomAus/SAEDashboard/commit/f1476967af5fee95313264ccaee668605d23b9ad))
406
+
407
+ * Merge pull request #28 from jbloomAus/feature/np-topk-size-arg
408
+
409
+ Feature/np topk size arg ([`c5c1365`](https://github.com/jbloomAus/SAEDashboard/commit/c5c136576609991177d3a8924b5bf75a42b66399))
410
+
411
+ * Simply updated default value for top K ([`5c855fe`](https://github.com/jbloomAus/SAEDashboard/commit/5c855fec0e58a114a537590d1400eaa42dd3610c))
412
+
413
+ * Testing variable topk sizes ([`79fe14b`](https://github.com/jbloomAus/SAEDashboard/commit/79fe14b840991bd1f8ada8462aeb65d72821c4aa))
414
+
415
+ * Merge pull request #25 from jbloomAus/fix/dfa-for-gqa
416
+
417
+ Fix/dfa for gqa ([`85c345f`](https://github.com/jbloomAus/SAEDashboard/commit/85c345f3ad8069a59be8d495242395c50381ab01))
418
+
419
+ * Fixed formatting ([`48a67c7`](https://github.com/jbloomAus/SAEDashboard/commit/48a67c79247d05745d355e6a4bf380e9df20474e))
420
+
421
+ * Removed redundant code from rebase ([`a71fb9d`](https://github.com/jbloomAus/SAEDashboard/commit/a71fb9dde6e880b0f4297277d27696c9d524d052))
422
+
423
+ * fixed rebase ([`57ee280`](https://github.com/jbloomAus/SAEDashboard/commit/57ee28021efd3678bcd9d12d55e048c14a2f2d47))
424
+
425
+ * Added tests for DFA for GQA ([`3b99e36`](https://github.com/jbloomAus/SAEDashboard/commit/3b99e36c74d2c61617cfed107bee3b0eb3b63294))
426
+
427
+ * Removed duplicate code ([`7093773`](https://github.com/jbloomAus/SAEDashboard/commit/7093773d079cd235aea99273a1365363a5bf8b6d))
428
+
429
+ * More rebasing stuff ([`59c6cd8`](https://github.com/jbloomAus/SAEDashboard/commit/59c6cd85ead287b2774aa591463d131840c7f270))
430
+
431
+ * Fixed formatting ([`ed7d3b1`](https://github.com/jbloomAus/SAEDashboard/commit/ed7d3b16a99e3e3a272e73356cc0509b2c59a292))
432
+
433
+ * Removed debugging statements ([`6489d1c`](https://github.com/jbloomAus/SAEDashboard/commit/6489d1c5b52ed86cb280c237c08e10238e0d0564))
434
+
435
+ * more debug prints x3 ([`5ba2b8a`](https://github.com/jbloomAus/SAEDashboard/commit/5ba2b8a69f1881b901131976c7d52f142068dbd2))
436
+
437
+ * more debug prints x2 ([`e124ff9`](https://github.com/jbloomAus/SAEDashboard/commit/e124ff906ec7b37083af4e4721b9e33902146e47))
438
+
439
+ * more debug prints ([`e2b0c35`](https://github.com/jbloomAus/SAEDashboard/commit/e2b0c35467e5d405abd3cca664dfd1960dbba0eb))
440
+
441
+ * temp print statements ([`95df55b`](https://github.com/jbloomAus/SAEDashboard/commit/95df55b29f9250f67c5b986216e587c37f72aa9e))
442
+
443
+ * Lowered default threshold ([`dc1f31a`](https://github.com/jbloomAus/SAEDashboard/commit/dc1f31a55400231e46feb58a8c100f66472baa1b))
444
+
445
+ * updated ignore ([`eb0d56a`](https://github.com/jbloomAus/SAEDashboard/commit/eb0d56a9f813b9cf82742093fae00bb0ccfdac45))
446
+
447
+ * Reduced memory load of GQA DFA ([`05867f1`](https://github.com/jbloomAus/SAEDashboard/commit/05867f1d0c8b5f2a5b76f3ea45ab9c87eaae9c09))
448
+
449
+ * DFA will now work for models with grouped query attention ([`91a5dd1`](https://github.com/jbloomAus/SAEDashboard/commit/91a5dd17a2e567efa7d8a89d228eb7de47ae6766))
450
+
451
+ * Added head attr weights functionality for when DFA is use ([`03a615f`](https://github.com/jbloomAus/SAEDashboard/commit/03a615f7c70a6f6e634845dab4051874698fac5b))
452
+
453
+ * Edited default chunk size ([`7d68f9e`](https://github.com/jbloomAus/SAEDashboard/commit/7d68f9e7131b8c5558e886022625dac267f20aab))
454
+
455
+ * Fixed formatting ([`4d5f38b`](https://github.com/jbloomAus/SAEDashboard/commit/4d5f38beca15f2ce05c89f83eb3e955c291f9687))
456
+
457
+ * Removed debugging statements and added device changes ([`76e17c9`](https://github.com/jbloomAus/SAEDashboard/commit/76e17c91a41b5df6047baa5bcfa33d253b029d29))
458
+
459
+ * more debug prints x3 ([`06535d3`](https://github.com/jbloomAus/SAEDashboard/commit/06535d3df168d92ac79d2f5a14b345c757dfd9de))
460
+
461
+ * more debug prints x2 ([`26e8297`](https://github.com/jbloomAus/SAEDashboard/commit/26e8297888de066f0097e3b73245eb149bfb327f))
462
+
463
+ * more debug prints ([`9ded356`](https://github.com/jbloomAus/SAEDashboard/commit/9ded356ea8c3c5dd841bf5a45ea65ae8c67935f5))
464
+
465
+ * temp print statements ([`024ad57`](https://github.com/jbloomAus/SAEDashboard/commit/024ad578b65b8f3592b42b66dc6a56aeae2a3116))
466
+
467
+ * Lowered default threshold ([`a3b5977`](https://github.com/jbloomAus/SAEDashboard/commit/a3b5977c0f1bb7a865f7349304a5dd8092f7c2e8))
468
+
469
+ * updated ignore ([`d5d325a`](https://github.com/jbloomAus/SAEDashboard/commit/d5d325a63b3b26b890c2bab512f2a8473bdc926a))
470
+
471
+ * Reduced memory load of GQA DFA ([`93eb1a9`](https://github.com/jbloomAus/SAEDashboard/commit/93eb1a9a92320d9f4645b500e22a566135918e3d))
472
+
473
+ * DFA will now work for models with grouped query attention ([`6594155`](https://github.com/jbloomAus/SAEDashboard/commit/65941559bac03a3e4fb128d5327033e01f19c18d))
474
+
475
+ * Added head attr weights functionality for when DFA is use ([`9312d90`](https://github.com/jbloomAus/SAEDashboard/commit/9312d901bf17e14400199c86e0284be6c750162a))
476
+
477
+ * Added tests for DFA for GQA ([`fcfac37`](https://github.com/jbloomAus/SAEDashboard/commit/fcfac37e148461e585f38fddf868ad2a32d908a8))
478
+
479
+ * Removed duplicate code ([`cc00944`](https://github.com/jbloomAus/SAEDashboard/commit/cc00944855720d5b8139d4267b44c1a230ef5319))
480
+
481
+ * Fixed formatting ([`50b08b4`](https://github.com/jbloomAus/SAEDashboard/commit/50b08b4eb50734afe0f085274ccaee71ec4017a4))
482
+
483
+ * Removed debugging statements ([`f7b949b`](https://github.com/jbloomAus/SAEDashboard/commit/f7b949b4af6bc8ca7557bfa5fa2441fbaa0284a0))
484
+
485
+ * more debug prints x3 ([`53536b0`](https://github.com/jbloomAus/SAEDashboard/commit/53536b03d624783b6b2f95b07b9318139ef0c49e))
486
+
487
+ * more debug prints x2 ([`6f2c504`](https://github.com/jbloomAus/SAEDashboard/commit/6f2c504a355f9071e766fc7fa3b6aad9890572a8))
488
+
489
+ * more debug prints ([`e1bef90`](https://github.com/jbloomAus/SAEDashboard/commit/e1bef90d16e8c73c9532b19a08c842757828c7ed))
490
+
491
+ * temp print statements ([`fd75714`](https://github.com/jbloomAus/SAEDashboard/commit/fd75714ee4631463c1f754d68f83b9ef75eb2285))
492
+
493
+ * updated ignore ([`c01062f`](https://github.com/jbloomAus/SAEDashboard/commit/c01062faecfaa132d87c56a7ba7add573c6b0f4e))
494
+
495
+ * Reduced memory load of GQA DFA ([`1ae40e9`](https://github.com/jbloomAus/SAEDashboard/commit/1ae40e9d487af7e8a7b148629588ef87fdd0a6e5))
496
+
497
+ * DFA will now work for models with grouped query attention ([`c66c90f`](https://github.com/jbloomAus/SAEDashboard/commit/c66c90f5d51961cafd5f13c26a94193ee38f828a))
498
+
499
+ * Edited default chunk size ([`3c78bdc`](https://github.com/jbloomAus/SAEDashboard/commit/3c78bdcfda12e5873de082a7f1e631a801bd9407))
500
+
501
+ * Fixed formatting ([`10a36e3`](https://github.com/jbloomAus/SAEDashboard/commit/10a36e3e8da3c7593058d3638ac3b7a32953b1b0))
502
+
503
+ * Removed debugging statements and added device changes ([`0f51dd9`](https://github.com/jbloomAus/SAEDashboard/commit/0f51dd953cd214244c71e8b9156b90483ceaa2be))
504
+
505
+ * more debug prints x3 ([`112ef42`](https://github.com/jbloomAus/SAEDashboard/commit/112ef4292b81a64f6168e7527ec583faa9ba20a4))
506
+
507
+ * more debug prints x2 ([`ef154d6`](https://github.com/jbloomAus/SAEDashboard/commit/ef154d6044bb67d17a2aa225ddf4099ccfc16b55))
508
+
509
+ * more debug prints ([`1b18d14`](https://github.com/jbloomAus/SAEDashboard/commit/1b18d141dd33e3a99c2abd5a6d195ab5142890d8))
510
+
511
+ * temp print statements ([`2194d2c`](https://github.com/jbloomAus/SAEDashboard/commit/2194d2cea16856c96ace47ad5ac560f088e769b0))
512
+
513
+ * Lowered default threshold ([`a49d1e5`](https://github.com/jbloomAus/SAEDashboard/commit/a49d1e5b94c8ef680448f20ded849c7752fb5131))
514
+
515
+ * updated ignore ([`2067655`](https://github.com/jbloomAus/SAEDashboard/commit/20676554541d29fddd87215a47e8e94891e342ac))
516
+
517
+ * Reduced memory load of GQA DFA ([`8ec1956`](https://github.com/jbloomAus/SAEDashboard/commit/8ec19566e8898413d349fe3f2e43fbff232ffa62))
518
+
519
+ * DFA will now work for models with grouped query attention ([`8f3cf55`](https://github.com/jbloomAus/SAEDashboard/commit/8f3cf5532e57abc6e694fb11c5f9c7c2915215c0))
520
+
521
+ * Added head attr weights functionality for when DFA is use ([`234ea32`](https://github.com/jbloomAus/SAEDashboard/commit/234ea3211ce7dbf84d101c4e8bfe844c3903b16a))
522
+
523
+ * Merge pull request #27 from jbloomAus/fix/resolve-duplication
524
+
525
+ Removed sources of duplicate sequences ([`525bffe`](https://github.com/jbloomAus/SAEDashboard/commit/525bffee516a630c4b4f033d3971fad8c6dd5a74))
526
+
527
+ * Updated location of wandb finish() ([`921da77`](https://github.com/jbloomAus/SAEDashboard/commit/921da77132a560505fa61decf287ca3833f96ec7))
528
+
529
+ * Added two sets of tests for duplication checks ([`3e95ffd`](https://github.com/jbloomAus/SAEDashboard/commit/3e95ffd1dafd01deb1f7817845ccb6229fb4ae09))
530
+
531
+ * Restored original random indices function as it seemed ok ([`388719b`](https://github.com/jbloomAus/SAEDashboard/commit/388719bec99b4306e81e0cdb772b9924db210774))
532
+
533
+ * Removed sources of duplicate sequences ([`853306c`](https://github.com/jbloomAus/SAEDashboard/commit/853306c4e08d9ec95674fdc5c87f807019055d0d))
534
+
535
+ ## v0.5.1 (2024-08-27)
536
+
537
+ ### Fix
538
+
539
+ * fix: multi-gpu-tlens
540
+
541
+ fix: handle multiple tlens devices ([`ed1e967`](https://github.com/jbloomAus/SAEDashboard/commit/ed1e967d44b887f4b99d2257934ca920d5c6a508))
542
+
543
+ * fix: handle multiple tlens devices ([`ba5368f`](https://github.com/jbloomAus/SAEDashboard/commit/ba5368f9999f08332c153816ba5836f8a1eb9ba1))
544
+
545
+ ## v0.5.0 (2024-08-25)
546
+
547
+ ### Feature
548
+
549
+ * feat: accelerate caching. Torch load / save faster when files are small.
550
+
551
+ Refactor/accelerate caching ([`6027d0a`](https://github.com/jbloomAus/SAEDashboard/commit/6027d0a3fc0d70908bad036a9658caa406d9f809))
552
+
553
+ ### Unknown
554
+
555
+ * Updated formatting ([`c1ea288`](https://github.com/jbloomAus/SAEDashboard/commit/c1ea2882a17e0d1b7b28743a34fca9d0754bd8a7))
556
+
557
+ * Sped up caching with native torch functions ([`230840a`](https://github.com/jbloomAus/SAEDashboard/commit/230840aea50b8b7055a6aa61961d7ac50855b763))
558
+
559
+ * Increased cache loading speed ([`83fe5f4`](https://github.com/jbloomAus/SAEDashboard/commit/83fe5f4bdf1252d533f203bc3f53ea9f71880ab8))
560
+
561
+ ## v0.4.0 (2024-08-22)
562
+
563
+ ### Feature
564
+
565
+ * feat: Refactor json writer and trigger DFA release
566
+
567
+ JSON writer has been refactored for reusability and readability ([`664f487`](https://github.com/jbloomAus/SAEDashboard/commit/664f4874b585c5510d2d3dd639c5e893023f6332))
568
+
569
+ ### Unknown
570
+
571
+ * Refactored JSON creation from the neuronpedia runner ([`d6bb24b`](https://github.com/jbloomAus/SAEDashboard/commit/d6bb24b6d773874d8e99be4d84402d559741907b))
572
+
573
+ * Merge pull request #20 from jbloomAus/feature/dfa
574
+
575
+ SAEVisRunner DFA Implementation ([`926ea87`](https://github.com/jbloomAus/SAEDashboard/commit/926ea87dd344548489201f68cc92b33662430813))
576
+
577
+ * Update ci.yaml ([`4b2807d`](https://github.com/jbloomAus/SAEDashboard/commit/4b2807dd865904120d236b355c0ccb1680c2919e))
578
+
579
+ * Fixed formatting ([`a62cc8f`](https://github.com/jbloomAus/SAEDashboard/commit/a62cc8f1bdd4c6e49b76d2d594e5a6b4b8183a8c))
580
+
581
+ * Fixed target index ([`ca2668d`](https://github.com/jbloomAus/SAEDashboard/commit/ca2668da03ea4d06cdc9f198988b80e0db844316))
582
+
583
+ * Corrected DFA indexing ([`d5028ae`](https://github.com/jbloomAus/SAEDashboard/commit/d5028aec875db4c03196726400c3b90b5d9d4d01))
584
+
585
+ * Adding temporary testing notebook ([`98e4b2f`](https://github.com/jbloomAus/SAEDashboard/commit/98e4b2f93d300ad4e94985d8d2594739a277e0c8))
586
+
587
+ * Added DFA output to neuronpedia runner ([`68eeff3`](https://github.com/jbloomAus/SAEDashboard/commit/68eeff3172b0c8637a6566c07951c28fd14a1c03))
588
+
589
+ * Fixed test typehints ([`d358e6f`](https://github.com/jbloomAus/SAEDashboard/commit/d358e6f5cc37304935eed949a0b0b985ba12b94f))
590
+
591
+ * Fixed formatting ([`5cb19e2`](https://github.com/jbloomAus/SAEDashboard/commit/5cb19e241051503730b6982813a6730556990c92))
592
+
593
+ * Corrected typehints ([`6173fbd`](https://github.com/jbloomAus/SAEDashboard/commit/6173fbd3824b7cba58e1cf0c7ee239762ee533ce))
594
+
595
+ * Removed another unused import ([`8be1572`](https://github.com/jbloomAus/SAEDashboard/commit/8be1572370b1adf341e2a650953bf17cd179808d))
596
+
597
+ * Removed unused imports ([`9071210`](https://github.com/jbloomAus/SAEDashboard/commit/90712105f74b287d77a06c045e8c32fd05f2e668))
598
+
599
+ * Added support for DFA calculations up to SAE Vis runner ([`4a08ffd`](https://github.com/jbloomAus/SAEDashboard/commit/4a08ffd13a8f29ff16808a20cd663c9d2d369e6a))
600
+
601
+ * Added activation collection flow for DFA ([`0ebb1f3`](https://github.com/jbloomAus/SAEDashboard/commit/0ebb1f3ca61603662f4f2cc8b1341470bf75b5d1))
602
+
603
+ * Merge pull request #19 from jbloomAus/fix/remove_precision_reduction
604
+
605
+ Removed precision reduction option ([`a5f8df1`](https://github.com/jbloomAus/SAEDashboard/commit/a5f8df15ef8619c4d08655e777d379a05b453346))
606
+
607
+ * Removed float16 option entirely from quantile calc ([`1b6a4a9`](https://github.com/jbloomAus/SAEDashboard/commit/1b6a4a93403ca2e9a869aa73600f37960090f03d))
608
+
609
+ * Removed precision reduction option ([`cd03ffb`](https://github.com/jbloomAus/SAEDashboard/commit/cd03ffb182e93a42480c01408b47ebae94d4c349))
610
+
611
+ ## v0.3.0 (2024-08-15)
612
+
613
+ ### Feature
614
+
615
+ * feat: seperate files per dashboard html ([`cd8d050`](https://github.com/jbloomAus/SAEDashboard/commit/cd8d050218ae3c6eeb7a9779072e60b78bfe0b58))
616
+
617
+ ### Unknown
618
+
619
+ * Merge pull request #17 from jbloomAus/refactor/remove_enc_b
620
+
621
+ Removed all encoder B code ([`67c9c3f`](https://github.com/jbloomAus/SAEDashboard/commit/67c9c3fdc8bd220938f65c1f97214034cc7528b4))
622
+
623
+ * Removed all encoder B code ([`5174e2e`](https://github.com/jbloomAus/SAEDashboard/commit/5174e2e161030dc756c148f1740e50c52baf6a91))
624
+
625
+ * Merge pull request #18 from jbloomAus/feat-seperate-files-per-html-dashboard
626
+
627
+ feat: seperate files per dashboard html ([`8ff69ba`](https://github.com/jbloomAus/SAEDashboard/commit/8ff69ba207692d4acb8d5fc19d038090067690df))
628
+
629
+ * Merge pull request #16 from jbloomAus/performance_refactor
630
+
631
+ Create() will now reduce precision by default ([`fb07b90`](https://github.com/jbloomAus/SAEDashboard/commit/fb07b90eaac395a58f02ba927460dcc2c9e61d1a))
632
+
633
+ * Removed line ([`d795490`](https://github.com/jbloomAus/SAEDashboard/commit/d795490c1c9d8193c8cf84d0352b9d93c41947fe))
634
+
635
+ * Removed unnecessary print ([`4544f86`](https://github.com/jbloomAus/SAEDashboard/commit/4544f86472480f0df00344fa84111a7c2a52fcef))
636
+
637
+ * Precision will now be reduced by default for quantile calc ([`539d222`](https://github.com/jbloomAus/SAEDashboard/commit/539d222ded9e3a0944f5240f3a4cd84497d11a74))
638
+
639
+ * Merge pull request #15 from jbloomAus/quantile_efficiency
640
+
641
+ Quantile OOM prevention ([`4a40c37`](https://github.com/jbloomAus/SAEDashboard/commit/4a40c3704aab9363163fef3e2830d42f2fecdc6b))
642
+
643
+ * Made quantile batch optional and removed sampling code ([`2df51d3`](https://github.com/jbloomAus/SAEDashboard/commit/2df51d353f818a196916a15f2bc56f70480dd853))
644
+
645
+ * Added device check for test ([`afbb960`](https://github.com/jbloomAus/SAEDashboard/commit/afbb960d3c9376ad512607146826b7d1c1e68d48))
646
+
647
+ * Added parameter for quantile calculation batching ([`49d0a7a`](https://github.com/jbloomAus/SAEDashboard/commit/49d0a7ab37896a085f80409900e3d0b261b8c9e0))
648
+
649
+ * Added type annotation ([`c71c4aa`](https://github.com/jbloomAus/SAEDashboard/commit/c71c4aa1c8bc25d85b9a955b482823cbde445a51))
650
+
651
+ * Removed unused imports ([`ec01bfe`](https://github.com/jbloomAus/SAEDashboard/commit/ec01bfefc2f0f4d880cd5744ff6a2ea71991349b))
652
+
653
+ * Added float16 version of quantile calculation ([`2f01eb8`](https://github.com/jbloomAus/SAEDashboard/commit/2f01eb8d9f84a20918f19e81c23df86ddc9d7f0c))
654
+
655
+ * Merge pull request #13 from jbloomAus/hook_z_support
656
+
657
+ fix: restore hook_z support following regression. ([`ea87559`](https://github.com/jbloomAus/SAEDashboard/commit/ea87559359f9821e352dcab582e23b42fef1cebf))
658
+
659
+ * format ([`21e3617`](https://github.com/jbloomAus/SAEDashboard/commit/21e3617196ef57944c141563e9263101baf9c7f1))
660
+
661
+ * make sure hook_z works ([`efaeec0`](https://github.com/jbloomAus/SAEDashboard/commit/efaeec0fdf8c2c43bb13bfd652b812a38ebc0200))
662
+
663
+ * Merge pull request #12 from jbloomAus/use_sae_lens_loading
664
+
665
+ Use sae lens loading ([`89bba3e`](https://github.com/jbloomAus/SAEDashboard/commit/89bba3e7a10877782608c50f4b8dd9054f204381))
666
+
667
+ * add settings.json ([`d8f3034`](https://github.com/jbloomAus/SAEDashboard/commit/d8f3034c0ed7241c35e9761d60a9ee4072403fd0))
668
+
669
+ * add dtype ([`0d8008a`](https://github.com/jbloomAus/SAEDashboard/commit/0d8008afe93a2a2a5bfc954571c680a529ab883f))
670
+
671
+ * cli util ([`9da440e`](https://github.com/jbloomAus/SAEDashboard/commit/9da440eb3d50d48a7fdc4d3ee3d26de13a458593))
672
+
673
+ * wandb logging improvement ([`a077369`](https://github.com/jbloomAus/SAEDashboard/commit/a077369ca43009f4e50c0b1e7176cae398703856))
674
+
675
+ * add override for np set name ([`8906d10`](https://github.com/jbloomAus/SAEDashboard/commit/8906d103ab8d10bd01b791331dfc5485ac047a4f))
676
+
677
+ * auto add folder path to output dir ([`35e06ab`](https://github.com/jbloomAus/SAEDashboard/commit/35e06ab89bce257fc15ffaa4918b9598577d6df0))
678
+
679
+ * update tests ([`50163b0`](https://github.com/jbloomAus/SAEDashboard/commit/50163b04ca29b492b9fb71244aa26798655b663f))
680
+
681
+ * first step towards sae_lens remote loading ([`415a2d1`](https://github.com/jbloomAus/SAEDashboard/commit/415a2d1e484e9ea2351bf98de221f6a83a805107))
682
+
683
+ ## v0.2.3 (2024-08-06)
684
+
685
+ ### Fix
686
+
687
+ * fix: neuronpedia uses api_key for uploading features, and update sae_id->sae_set ([`0336a35`](https://github.com/jbloomAus/SAEDashboard/commit/0336a3587f825f0be15af79cc9a0033dda3d4a3f))
688
+
689
+ ### Unknown
690
+
691
+ * Merge pull request #11 from jbloomAus/ignore_bos_option
692
+
693
+ Ignore bos option ([`ae34b70`](https://github.com/jbloomAus/SAEDashboard/commit/ae34b70b61993b4cce49a758bf85514410c67bd8))
694
+
695
+ * change threshold ([`4a0be67`](https://github.com/jbloomAus/SAEDashboard/commit/4a0be67622826f879191ced225c8c075d34bfe56))
696
+
697
+ * type fix ([`525b6a1`](https://github.com/jbloomAus/SAEDashboard/commit/525b6a10331b9fa0a464ae0c7f01af90ae97d0bb))
698
+
699
+ * default ignore bos eos pad ([`d2396a7`](https://github.com/jbloomAus/SAEDashboard/commit/d2396a714dd9ea3d59e516aa0fe30a9c9225e22f))
700
+
701
+ * ignore bos tokens ([`96cf6e9`](https://github.com/jbloomAus/SAEDashboard/commit/96cf6e9427cadf13fa13b55b7d1bc83ae81d9ec0))
702
+
703
+ * jump relu support in feature masking context ([`a1ba87a`](https://github.com/jbloomAus/SAEDashboard/commit/a1ba87a5c5e03687d7d7b5c5677bd9773fa49517))
704
+
705
+ * depend on latest sae lens ([`4988207`](https://github.com/jbloomAus/SAEDashboard/commit/4988207abaca24256f52235e474fe5fbb5028c1a))
706
+
707
+ * Merge pull request #10 from jbloomAus/auth_and_sae_set
708
+
709
+ fix: neuronpedia uses api_key for uploading features, and update sae_id -> sae_set ([`4684aca`](https://github.com/jbloomAus/SAEDashboard/commit/4684aca54b69dbc913c1122f1a322ed4d808dce0))
710
+
711
+ * Combine upload-features and upload-dead-stubs ([`faac839`](https://github.com/jbloomAus/SAEDashboard/commit/faac8398fee8582b12c2d1a29df6d4de7e542bed))
712
+
713
+ * Activation store device should be cuda when available ([`93050b1`](https://github.com/jbloomAus/SAEDashboard/commit/93050b1f5c2b87c8e889fe3449d440016c996762))
714
+
715
+ * Activation store device should be cuda when available ([`4469066`](https://github.com/jbloomAus/SAEDashboard/commit/4469066af06bb4944832f2e596e36afa09adf160))
716
+
717
+ * Better support for huggingface dataset path ([`3dc4b78`](https://github.com/jbloomAus/SAEDashboard/commit/3dc4b783a1ced7b938ab45c4d10effedd148a829))
718
+
719
+ * Docker tweak ([`a1a70cb`](https://github.com/jbloomAus/SAEDashboard/commit/a1a70cb28c726887de9439024b7b1d01082d3932))
720
+
721
+ ## v0.2.2 (2024-07-12)
722
+
723
+ ### Fix
724
+
725
+ * fix: don't sample too many tokens + other fixes
726
+
727
+ fix: don't sample too many tokens ([`b2554b0`](https://github.com/jbloomAus/SAEDashboard/commit/b2554b017e75d14b38b343fc6e0c1bcc32be2359))
728
+
729
+ * fix: don't sample too many tokens ([`0cbb2ed`](https://github.com/jbloomAus/SAEDashboard/commit/0cbb2edb480b83823dc1a98dd7e5978ecdda0d81))
730
+
731
+ ### Unknown
732
+
733
+ * - Don't force manual overrides for dtype - default to SAE's dtype
734
+ - Add n_prompts_in_forward_pass to neuronpedia.py
735
+ - Add n_prompts_total, n_tokens_in_prompt, and dataset to neuronpedia artifact
736
+ - Remove NPDashboardSettings for now (just save the NPRunnerConfig later)
737
+ - Fix lint error
738
+ - Consolidate minibatch_size_features/tokens to n_feats_at_a_time and n_prompts_in_fwd_pass
739
+ - Update/Fix NP acceptance test ([`b6282c8`](https://github.com/jbloomAus/SAEDashboard/commit/b6282c83e1898e356e271af0926e2271fb23f707))
740
+
741
+ * Merge pull request #7 from jbloomAus/performance-improvement
742
+
743
+ feat: performance improvement ([`f98b3dc`](https://github.com/jbloomAus/SAEDashboard/commit/f98b3dcf84c42687dfc92fa38377edd1c3f6fa30))
744
+
745
+ * delete unused snapshots ([`4210b48`](https://github.com/jbloomAus/SAEDashboard/commit/4210b48608792adc9b841ea92a64050311e66cd6))
746
+
747
+ * format ([`de57a2d`](https://github.com/jbloomAus/SAEDashboard/commit/de57a2d84564fc0eb7d5e42799c00f73c7007cf8))
748
+
749
+ * linter ([`4725ffa`](https://github.com/jbloomAus/SAEDashboard/commit/4725ffa2cbe743aa0bb615213f11105b6911f10d))
750
+
751
+ * hope flaky tests start passing ([`8ac9e8e`](https://github.com/jbloomAus/SAEDashboard/commit/8ac9e8e93127d4ab811019fc62bbe050a9a00e2c))
752
+
753
+ * np.memmap caching and more explicit hyperparams ([`9a24186`](https://github.com/jbloomAus/SAEDashboard/commit/9a24186cc1c118725c6db7dc3c77feb815cf938f))
754
+
755
+ * Move docker" ([`27b1a27`](https://github.com/jbloomAus/SAEDashboard/commit/27b1a27118bcccf54576eb1891b936bd92848f3f))
756
+
757
+ * Add docker to workflow ([`a354fa4`](https://github.com/jbloomAus/SAEDashboard/commit/a354fa47cfb005dd2304b4237f9182e2408daeed))
758
+
759
+ * Dockerignore file ([`ed9fcf3`](https://github.com/jbloomAus/SAEDashboard/commit/ed9fcf3a634cd57f6517170784d56d86431e1710))
760
+
761
+ * new versions ([`f64e54d`](https://github.com/jbloomAus/SAEDashboard/commit/f64e54df5c1b643fc3acaff7f4d40d5597edf61a))
762
+
763
+ * Add tools to docker image ([`2a70f64`](https://github.com/jbloomAus/SAEDashboard/commit/2a70f64cfd4177d807a8345e64699054dd103e8d))
764
+
765
+ * Fix docker ([`3805f20`](https://github.com/jbloomAus/SAEDashboard/commit/3805f20bff622582d16fd6603bef4b77e6bada9e))
766
+
767
+ * Fix docker image ([`7f9ff2f`](https://github.com/jbloomAus/SAEDashboard/commit/7f9ff2f9b10ce08264b2153e8191eca32f9ee48a))
768
+
769
+ * Fix NP simple test, remove check for correlated neurons/features ([`355fad5`](https://github.com/jbloomAus/SAEDashboard/commit/355fad58ab2ab036a33375c02d9006db634702b9))
770
+
771
+ * Dockerfile, small batching fix ([`4df4c51`](https://github.com/jbloomAus/SAEDashboard/commit/4df4c5138341a1c233c3d0fe1a3d399846e92407))
772
+
773
+ * set sae_device, activation_store device ([`6d65b22`](https://github.com/jbloomAus/SAEDashboard/commit/6d65b22ef541326cc9558119b40baeb95cc2e47e))
774
+
775
+ * Fix NP dtype error ([`8bb4d9d`](https://github.com/jbloomAus/SAEDashboard/commit/8bb4d9de0c75ffed5daaba4d5ec563fbbee38f86))
776
+
777
+ * format ([`f667d92`](https://github.com/jbloomAus/SAEDashboard/commit/f667d92d9359e5c7976e21e821ac0dde8a081da6))
778
+
779
+ * depend on latest sae_lens ([`4a2a6a0`](https://github.com/jbloomAus/SAEDashboard/commit/4a2a6a0fd70d7b4a3f1f870a510a800b31f57264))
780
+
781
+ * use a much better method for getting subsets of feature activations ([`7101f13`](https://github.com/jbloomAus/SAEDashboard/commit/7101f13e13b4de5659623433ec359ecf2142daef))
782
+
783
+ * add to gitignore ([`20180e0`](https://github.com/jbloomAus/SAEDashboard/commit/20180e06a279ef93d6127b467511911db352bce5))
784
+
785
+ * add isort ([`3ab0fda`](https://github.com/jbloomAus/SAEDashboard/commit/3ab0fdaf75f735ec2eedc904529909111d0db0de))
786
+
787
+ ## v0.2.1 (2024-07-08)
788
+
789
+ ### Fix
790
+
791
+ * fix: trigger release ([`87bf0b5`](https://github.com/jbloomAus/SAEDashboard/commit/87bf0b5f21f0d1f5397e514090601ec21c718e35))
792
+
793
+ ### Unknown
794
+
795
+ * Merge pull request #6 from jbloomAus/fix-bfloat16
796
+
797
+ fix bfloat 16 error ([`2f3c597`](https://github.com/jbloomAus/SAEDashboard/commit/2f3c597c1795357679e92caec3dd7e522c669fdb))
798
+
799
+ * fix bfloat 16 error ([`63c3c62`](https://github.com/jbloomAus/SAEDashboard/commit/63c3c62f0a03e5656ed78cc0e8f853bea3f0938e))
800
+
801
+ * Merge pull request #5 from jbloomAus/np-updates
802
+
803
+ Updates + fixes for Neuronpedia ([`9e6b5c4`](https://github.com/jbloomAus/SAEDashboard/commit/9e6b5c427024b8a468b0d06e4e096c2561c35d5d))
804
+
805
+ * Fix SAELens compatibility ([`139e1a2`](https://github.com/jbloomAus/SAEDashboard/commit/139e1a2f219d790c6f8faa9be34d9fbc9403dda3))
806
+
807
+ * Rename file ([`16709ad`](https://github.com/jbloomAus/SAEDashboard/commit/16709add9ee5063b3682be34eef0aea2ddf4eceb))
808
+
809
+ * Fix type ([`6b20386`](https://github.com/jbloomAus/SAEDashboard/commit/6b2038682ca41423dda3a3597bbe88120b120262))
810
+
811
+ * Make Neuronpedia outputs an object, and add a real acceptance test ([`a5db256`](https://github.com/jbloomAus/SAEDashboard/commit/a5db2560e5f90a49257124635b3fdbee117ed860))
812
+
813
+ * Np Runner: Multi-gpu defaults ([`07f7128`](https://github.com/jbloomAus/SAEDashboard/commit/07f71282681ffa801dd15f9265be349cd5745b42))
814
+
815
+ * Ensure minibatch is on correct device ([`e206546`](https://github.com/jbloomAus/SAEDashboard/commit/e2065462c445df0e0985fb6588d4c01cb39bbef5))
816
+
817
+ * NP Runner: Automatically use multi-gpu, devices ([`bf280e6`](https://github.com/jbloomAus/SAEDashboard/commit/bf280e685dc4dd2018cd41aa94a29bc853fcee18))
818
+
819
+ * Allow dtype override ([`a40077d`](https://github.com/jbloomAus/SAEDashboard/commit/a40077dac1fa2ae880fcdabe3227878ef2cfaebe))
820
+
821
+ * NP-Runner: Remove unnecessary layer of batching. ([`e2ac92b`](https://github.com/jbloomAus/SAEDashboard/commit/e2ac92b036d0192e132c8a8700a5a2f448d1983b))
822
+
823
+ * NP Runner: Allow skipping sparsity check ([`ef74d2a`](https://github.com/jbloomAus/SAEDashboard/commit/ef74d2aeea2463afe150a5e8824da5a5206cd3d0))
824
+
825
+ * Merge pull request #2 from jbloomAus/multiple-devices
826
+
827
+ feat: Multiple devices ([`535e6c9`](https://github.com/jbloomAus/SAEDashboard/commit/535e6c9689d855f82a6ddfd9f169720fe367bde3))
828
+
829
+ * format ([`7f892ad`](https://github.com/jbloomAus/SAEDashboard/commit/7f892ad0efb42025df0bcf26bdddd6fac4c2d8b1))
830
+
831
+ * NP runner takes device args seperately ([`8fc31dd`](https://github.com/jbloomAus/SAEDashboard/commit/8fc31dd6ccd59f4f35742a4e15c380673c8cb2a3))
832
+
833
+ * multi-gpu-support ([`5e24e4e`](https://github.com/jbloomAus/SAEDashboard/commit/5e24e4e6598dd7943f8d677042dcf84bc6f7a0a6))
834
+
835
+ ## v0.2.0 (2024-06-10)
836
+
837
+ ### Feature
838
+
839
+ * feat: experimental release 2 ([`e264f97`](https://github.com/jbloomAus/SAEDashboard/commit/e264f97d90299f6ade294db8ed03aed9cd7491ee))
840
+
841
+ ## v0.1.0 (2024-06-10)
842
+
843
+ ### Feature
844
+
845
+ * feat: experimental release ([`d79310a`](https://github.com/jbloomAus/SAEDashboard/commit/d79310a7b6599f7b813e214c9268d736e0cb87f0))
846
+
847
+ ### Unknown
848
+
849
+ * fix pyproject.toml ([`a27c87d`](https://github.com/jbloomAus/SAEDashboard/commit/a27c87da987f043b470abce3404e305ec3f0d620))
850
+
851
+ * test deployment ([`288a2d9`](https://github.com/jbloomAus/SAEDashboard/commit/288a2d9bf797a1a2f9947b1ceac5e47edc1684ba))
852
+
853
+ * refactor np runner and add acceptance test ([`212593c`](https://github.com/jbloomAus/SAEDashboard/commit/212593c33b3aec33078a121738c0a826f705722f))
854
+
855
+ * Fix: Default context tokens length for neuronpedia runner ([`aefe95c`](https://github.com/jbloomAus/SAEDashboard/commit/aefe95cb1be4139ac45f042abdc78e0feccfb490))
856
+
857
+ * Allow custom context tokens length for Neuronpedia runner ([`d204cc8`](https://github.com/jbloomAus/SAEDashboard/commit/d204cc8fbb2ef376a1a5e00cd4f1cc5db2afb279))
858
+
859
+ * Fix: Streaming default true ([`1b91dff`](https://github.com/jbloomAus/SAEDashboard/commit/1b91dff045fdbd8c118c5f209750eca60c260f5f))
860
+
861
+ * Fix n_devices error for non-cuda ([`70b2dbd`](https://github.com/jbloomAus/SAEDashboard/commit/70b2dbdb2da51f5d78b1c2ce3210865fc259c97b))
862
+
863
+ * fix import path for ci ([`3bd4687`](https://github.com/jbloomAus/SAEDashboard/commit/3bd468727e2ab0b7d77224b7c0dad88e0727b773))
864
+
865
+ * make pyright happy, start config ([`b39ae85`](https://github.com/jbloomAus/SAEDashboard/commit/b39ae85d938a0db7c70b7dff9683f68f255dfb67))
866
+
867
+ * add black ([`236855b`](https://github.com/jbloomAus/SAEDashboard/commit/236855be1ef1464ea85b2afc6aaee963326f9257))
868
+
869
+ * fix ci ([`12818d7`](https://github.com/jbloomAus/SAEDashboard/commit/12818d7e6cd3e483258598b668805c1a9a048049))
870
+
871
+ * add pytest cov ([`aae0571`](https://github.com/jbloomAus/SAEDashboard/commit/aae057159639cd247a82fdeda9eddb98612ceec6))
872
+
873
+ * bring checks in line with sae_lens ([`7cd9679`](https://github.com/jbloomAus/SAEDashboard/commit/7cd9679cc18c64a7c8a0a07a1f12e6fc87543537))
874
+
875
+ * activation scaling factor ([`333d377`](https://github.com/jbloomAus/SAEDashboard/commit/333d3770d0d1d3c40dfeb3335dcfc46e9b7da717))
876
+
877
+ * Move Neuronpedia runner to SAEDashboard ([`4e691ea`](https://github.com/jbloomAus/SAEDashboard/commit/4e691eaad919e12b9cae6ff707eaa3cf322ea030))
878
+
879
+ * fold w_dec norm by default ([`b6c9bc7`](https://github.com/jbloomAus/SAEDashboard/commit/b6c9bc70dc419d1e32bfb5580997369215e15429))
880
+
881
+ * rename sae_vis to sae_dashboard ([`f0f5341`](https://github.com/jbloomAus/SAEDashboard/commit/f0f5341ffdf31a11884777d6ba8100cd302b9dab))
882
+
883
+ * rename feature data generator ([`e02ed0a`](https://github.com/jbloomAus/SAEDashboard/commit/e02ed0a18e92c497aea3e137cf43e9f354f8f30f))
884
+
885
+ * update demo ([`8aa9e52`](https://github.com/jbloomAus/SAEDashboard/commit/8aa9e5272f54d04b741e63aa335bfa1212a2d0f7))
886
+
887
+ * add demo ([`dd3036f`](https://github.com/jbloomAus/SAEDashboard/commit/dd3036f90e6a4ed459ec21647744d491911900ac))
888
+
889
+ * delete old demo files ([`3d86202`](https://github.com/jbloomAus/SAEDashboard/commit/3d8620204cf6acb21b5e7f9983c300341345cd88))
890
+
891
+ * remove unnecessary print statement ([`9d3d937`](https://github.com/jbloomAus/SAEDashboard/commit/9d3d937e74f5575dde68d5a21fb73ce6f826d0d4))
892
+
893
+ * set sae lens version ([`87a7691`](https://github.com/jbloomAus/SAEDashboard/commit/87a76911ff0f0d46ab421d9b5107aef27216e88b))
894
+
895
+ * update older readme ([`c5c98e5`](https://github.com/jbloomAus/SAEDashboard/commit/c5c98e53531874efab5bc16235d9c72816fa61d5))
896
+
897
+ * test ([`923da42`](https://github.com/jbloomAus/SAEDashboard/commit/923da427b56178acd99b988d6d6b51368b5d2359))
898
+
899
+ * remove sae lens dep ([`2c26d5f`](https://github.com/jbloomAus/SAEDashboard/commit/2c26d5f4c40c41f750971601968577f316e15598))
900
+
901
+ * Merge branch 'refactor_b' ([`3154d63`](https://github.com/jbloomAus/SAEDashboard/commit/3154d636e1a9f8a30b54c17e62a842bed3f8b2a1))
902
+
903
+ * pass linting ([`0c079a1`](https://github.com/jbloomAus/SAEDashboard/commit/0c079a105b1b98e0edf2ff1a15593567c81bb103))
904
+
905
+ * format ([`6f37e2e`](https://github.com/jbloomAus/SAEDashboard/commit/6f37e2eb050a3207a2d3b9defd5d416645215c7c))
906
+
907
+ * run ci on all branches ([`faa0cc4`](https://github.com/jbloomAus/SAEDashboard/commit/faa0cc4eed4ff35f1e04656a968214c4fefbd573))
908
+
909
+ * don't use feature ablations ([`dc6e6dc`](https://github.com/jbloomAus/SAEDashboard/commit/dc6e6dc2d2affce331894d8bb61942e103182652))
910
+
911
+ * mock information in sequences to make normal sequence generation pass ([`c87b82f`](https://github.com/jbloomAus/SAEDashboard/commit/c87b82fdcc5e849d970cdc8bd1e841ec3e3e48ce))
912
+
913
+ * Remove resid ([`ff83737`](https://github.com/jbloomAus/SAEDashboard/commit/ff837373b65e60d8a9ba7c6e61f78bddc4d170f2))
914
+
915
+ * adding a test for direct_effect_feature_ablation_experiment ([`a9f3d1b`](https://github.com/jbloomAus/SAEDashboard/commit/a9f3d1b8021d8eeb60cf465934037d07583fa0b2))
916
+
917
+ * shortcut direct_effect_feature_ablation_experiment if everything is zero ([`2c68ff0`](https://github.com/jbloomAus/SAEDashboard/commit/2c68ff0c8496c58cc0732f3c51905c9c9f405393))
918
+
919
+ * fixing CI and replacing manual snapshots with syrupy snapshots ([`3b97640`](https://github.com/jbloomAus/SAEDashboard/commit/3b97640803cab3e3915202ac80c43b855c69c1cb))
920
+
921
+ * more refactor, WIP ([`81657c8`](https://github.com/jbloomAus/SAEDashboard/commit/81657c8c897a81102c0df7b29c49d526e639bb44))
922
+
923
+ * continue refactor, make data generator ([`eb1ae0f`](https://github.com/jbloomAus/SAEDashboard/commit/eb1ae0fc621407b33481c50b78b041079b08393d))
924
+
925
+ * add use of safetensors cache for repeated calculations ([`a241c32`](https://github.com/jbloomAus/SAEDashboard/commit/a241c322334340a84c2a252bc0b4a40ed2f19bc9))
926
+
927
+ * more refactor / benchmarking ([`d65ee87`](https://github.com/jbloomAus/SAEDashboard/commit/d65ee87cd191b2ed279f9f6efabb9e98bb700855))
928
+
929
+ * only run unit tests ([`5f11ddd`](https://github.com/jbloomAus/SAEDashboard/commit/5f11ddd9bc25f9c9bb7cbeba11224ba12b260ea8))
930
+
931
+ * fix lint issue ([`24daf17`](https://github.com/jbloomAus/SAEDashboard/commit/24daf17cb92534901681affbcebea314e2cf6580))
932
+
933
+ * format ([`83e89ed`](https://github.com/jbloomAus/SAEDashboard/commit/83e89ed4860d886ccf19be591bf72d0e029e7344))
934
+
935
+ * organise tests, make sure only unit tests run on CI ([`21f5fb1`](https://github.com/jbloomAus/SAEDashboard/commit/21f5fb155665329531b16d10673ddd988e7034ea))
936
+
937
+ * see if we can do some caching ([`c1dca6f`](https://github.com/jbloomAus/SAEDashboard/commit/c1dca6faa61de0849453acc83ae23baab6cf48be))
938
+
939
+ * more refactoring ([`b3f0f41`](https://github.com/jbloomAus/SAEDashboard/commit/b3f0f41f36f0eee57a08142880d4b6654309e62c))
940
+
941
+ * further refactor, possible significant speed up ([`ddd3496`](https://github.com/jbloomAus/SAEDashboard/commit/ddd3496206c0f3e751b596ca51e3544c77ddaf94))
942
+
943
+ * more refactor ([`a5f6deb`](https://github.com/jbloomAus/SAEDashboard/commit/a5f6deb4263c58803e7af23d767fc5cb17dfd2b2))
944
+
945
+ * refactoring in progress ([`d210b60`](https://github.com/jbloomAus/SAEDashboard/commit/d210b6056aa5316d2fd917e24ca8a819331a8114))
946
+
947
+ * use named arguments ([`4a81053`](https://github.com/jbloomAus/SAEDashboard/commit/4a8105355d3b86e460f32cd5c736dde0dbeaa2e3))
948
+
949
+ * remove create method ([`43b2018`](https://github.com/jbloomAus/SAEDashboard/commit/43b20184ed5ed0c2f08cfd13423f2271fd871274))
950
+
951
+ * move chunk ([`0f26aa8`](https://github.com/jbloomAus/SAEDashboard/commit/0f26aa85bc9fbe358f4c5f90971d51b86159f095))
952
+
953
+ * use fixtures ([`7c11dd9`](https://github.com/jbloomAus/SAEDashboard/commit/7c11dd914d467957e5c00b914f302c291924e411))
954
+
955
+ * refactor to create runner ([`9202c19`](https://github.com/jbloomAus/SAEDashboard/commit/9202c19f4ad6134eb6b68f857c9e4bfd0b911cf8))
956
+
957
+ * format ([`abd8747`](https://github.com/jbloomAus/SAEDashboard/commit/abd87472b76cfc151abfe2a6e312ea43b29c2250))
958
+
959
+ * target ci at this branch ([`ea3b2a3`](https://github.com/jbloomAus/SAEDashboard/commit/ea3b2a3181f2eb1ff52d83b2040b586d6fdfef4a))
960
+
961
+ * comment out release process for now ([`7084b5b`](https://github.com/jbloomAus/SAEDashboard/commit/7084b5ba3325bb559a8377d379bd2f3ba6d68348))
962
+
963
+ * test generated output ([`7b8b2ab`](https://github.com/jbloomAus/SAEDashboard/commit/7b8b2abd94213d67c378b7746107f6a7c811d93c))
964
+
965
+ * commit current demo html ([`00a03a0`](https://github.com/jbloomAus/SAEDashboard/commit/00a03a02fbf181caa55704defac25578b4444452))
966
+
967
+ ## v0.0.1 (2024-04-25)
968
+
969
+ ### Chore
970
+
971
+ * chore: setting up pytest ([`2079d00`](https://github.com/jbloomAus/SAEDashboard/commit/2079d00911d1a00ee19cde478b5cab61ca9c0495))
972
+
973
+ * chore: setting up semantic-release ([`09075af`](https://github.com/jbloomAus/SAEDashboard/commit/09075afbec279fb89d157f73e9a0ed47ba66d3c8))
974
+
975
+ ### Fix
976
+
977
+ * fix: remove circular dep with sae lens ([`1dd9f6c`](https://github.com/jbloomAus/SAEDashboard/commit/1dd9f6cd22f879e8d6904ba72f3e52b4344433cd))
978
+
979
+ ### Unknown
980
+
981
+ * Merge pull request #44 from chanind/pytest-setup
982
+
983
+ chore: setting up pytest ([`034eefa`](https://github.com/jbloomAus/SAEDashboard/commit/034eefa5a4163e9a560b574e2e255cd06f8f49a1))
984
+
985
+ * Merge pull request #43 from callummcdougall/move_saelens_dep
986
+
987
+ Remove dependency on saelens from pyproject, add to demo.ipynb ([`147d87e`](https://github.com/jbloomAus/SAEDashboard/commit/147d87ee9534d30e764851cbe73aadb5783d2515))
988
+
989
+ * Add missing matplotlib ([`572a3cc`](https://github.com/jbloomAus/SAEDashboard/commit/572a3cc79709a14117bbeafb871a33f0107600d8))
990
+
991
+ * Remove dependency on saelens from pyproject, add to demo.ipynb ([`1e6f3cf`](https://github.com/jbloomAus/SAEDashboard/commit/1e6f3cf9b2bcfb381a73d9333581c430faa531fd))
992
+
993
+ * Merge branch 'main' of https://github.com/callummcdougall/sae_vis ([`4e7a24c`](https://github.com/jbloomAus/SAEDashboard/commit/4e7a24c37444f11d718035eede68ac728d949a20))
994
+
995
+ * Merge pull request #41 from callummcdougall/allow_disable_buffer
996
+
997
+ oops I forgot to switch back to main before pushing ([`1312cd0`](https://github.com/jbloomAus/SAEDashboard/commit/1312cd09d6e274b1163e79d2ac01f2df54c65157))
998
+
999
+ * Merge branch 'main' into allow_disable_buffer ([`e7edf5a`](https://github.com/jbloomAus/SAEDashboard/commit/e7edf5a9bae4714bf4983ce6a19a0fe6fdf1f118))
1000
+
1001
+ * Merge pull request #40 from chanind/semantic-release-autodeploy
1002
+
1003
+ chore: setting up semantic-release for auto-deploy ([`a4d44d1`](https://github.com/jbloomAus/SAEDashboard/commit/a4d44d1a0e86055fb82ef41f51f0adbb7868df3c))
1004
+
1005
+ * Merge pull request #38 from chanind/type-checking
1006
+
1007
+ Enabling type checking with Pyright ([`f1fd792`](https://github.com/jbloomAus/SAEDashboard/commit/f1fd7926f46f00dca46024377f53aa8f2db98773))
1008
+
1009
+ * enabling type checking with Pyright ([`05d14ea`](https://github.com/jbloomAus/SAEDashboard/commit/05d14eafea707d3db81e78b4be87199087cb8e37))
1010
+
1011
+ * Merge pull request #39 from callummcdougall/fix_loading_saelens_sae
1012
+
1013
+ FIX: SAELens new format has "scaling_factor" key, which causes assert to fail ([`983aee5`](https://github.com/jbloomAus/SAEDashboard/commit/983aee562aea31e90657caf8c6ab6e450e952120))
1014
+
1015
+ * Fix Formatting ([`13b8106`](https://github.com/jbloomAus/SAEDashboard/commit/13b81062485f5dce2568e7832bfb2aae218dd4e9))
1016
+
1017
+ * Merge branch 'main' into fix_loading_saelens_sae ([`21b0086`](https://github.com/jbloomAus/SAEDashboard/commit/21b0086b8af3603441795e925a15e7cded122acb))
1018
+
1019
+ * format ([`8f1506b`](https://github.com/jbloomAus/SAEDashboard/commit/8f1506b6eb7dc0a2d4437d2aa23a0898c46a156d))
1020
+
1021
+ * Allow SAELens autoencoder keys to be superset of required keys, instead of exact match ([`6852170`](https://github.com/jbloomAus/SAEDashboard/commit/6852170d55e7d3cf22632c5807cfab219516da98))
1022
+
1023
+ * v0.2.17 ([`2bb14da`](https://github.com/jbloomAus/SAEDashboard/commit/2bb14daa88a0af601e13f4e51b50a2b00cd75b48))
1024
+
1025
+ * Use main branch of SAELens ([`2b34505`](https://github.com/jbloomAus/SAEDashboard/commit/2b345052bdc92ee9c1255cab0978916307a0a9dc))
1026
+
1027
+ * Update version 0.2.16 ([`bf90293`](https://github.com/jbloomAus/SAEDashboard/commit/bf902930844db9b0f8db4fbe8b3610557352660b))
1028
+
1029
+ * Merge pull request #36 from callummcdougall/allow_disable_buffer
1030
+
1031
+ FEATURE: Allow setting buffer to None, which gives the whole activation sequence ([`f5f9594`](https://github.com/jbloomAus/SAEDashboard/commit/f5f9594fcaf5edb6036a85446e092278004ea200))
1032
+
1033
+ * 16 ([`64e7018`](https://github.com/jbloomAus/SAEDashboard/commit/64e701849570d9e172dc065812c9a3e7149a9176))
1034
+
1035
+ * version 0.2.16 ([`afca0be`](https://github.com/jbloomAus/SAEDashboard/commit/afca0be8826e0c007b5730fa9fa18454699d16a3))
1036
+
1037
+ * Fix version ([`5a43916`](https://github.com/jbloomAus/SAEDashboard/commit/5a43916cbd9836396f051f7a258fdca8664e05e9))
1038
+
1039
+ * fix all indices view ([`5f87d52`](https://github.com/jbloomAus/SAEDashboard/commit/5f87d52154d6a8e8c8984836bbe8f85ee25f279d))
1040
+
1041
+ * Merge branch 'fix_gpt2_demo' into allow_disable_buffer ([`ea57bfc`](https://github.com/jbloomAus/SAEDashboard/commit/ea57bfc2ee1e23666810982abf32e6e9cbb74193))
1042
+
1043
+ * Allow disabling the buffer ([`c1be9f8`](https://github.com/jbloomAus/SAEDashboard/commit/c1be9f8e4b8ee6d8f18c4a1a0445840304440c1d))
1044
+
1045
+ * fix conflicts ([`ea3d624`](https://github.com/jbloomAus/SAEDashboard/commit/ea3d624013b9aa7cbd2d6eaa7212a1f7c4ee8e28))
1046
+
1047
+ * Merge pull request #35 from callummcdougall/fix_gpt2_demo
1048
+
1049
+ Fix usage of SAELens and demo notebook ([`88b5933`](https://github.com/jbloomAus/SAEDashboard/commit/88b59338d3cadbd5c70f0c1117dff00f01a54e6a))
1050
+
1051
+ * Import updated SAELens, use correct tokens, fix missing file cfg.json file error. ([`14ba9b0`](https://github.com/jbloomAus/SAEDashboard/commit/14ba9b03d4ce791ba8f4cac553fb82a93c47dfb8))
1052
+
1053
+ * Merge pull request #34 from ArthurConmy/patch-1
1054
+
1055
+ Update README.md ([`3faac82`](https://github.com/jbloomAus/SAEDashboard/commit/3faac82686f546800492d8aeb5e1d5919cbf1517))
1056
+
1057
+ * Update README.md ([`416eca8`](https://github.com/jbloomAus/SAEDashboard/commit/416eca8073c6cb2b120c759330ec47f52ab32d1e))
1058
+
1059
+ * Merge pull request #33 from chanind/setup-poetry-and-ruff
1060
+
1061
+ Setting up poetry / ruff / github actions ([`287f30f`](https://github.com/jbloomAus/SAEDashboard/commit/287f30f1d8fc39ab583f202c9277e07e5eeeaf62))
1062
+
1063
+ * setting up poetry and ruff for linting/formatting ([`0e0eba9`](https://github.com/jbloomAus/SAEDashboard/commit/0e0eba9e4d54c746cddc835ef4f6ddf2bab96844))
1064
+
1065
+ * fix feature vis demo gpt ([`821781e`](https://github.com/jbloomAus/SAEDashboard/commit/821781e96b732a5909d8735714482c965891b2ea))
1066
+
1067
+ * add scatter plot support ([`6eab28b`](https://github.com/jbloomAus/SAEDashboard/commit/6eab28bef9ef5cd9360fef73e02763301fa1a028))
1068
+
1069
+ * update setup ([`8d2ca53`](https://github.com/jbloomAus/SAEDashboard/commit/8d2ca53e8a6bba860fe71368741d06a718adaa27))
1070
+
1071
+ * fix setup ([`9cae8f4`](https://github.com/jbloomAus/SAEDashboard/commit/9cae8f461bd780e23eb2d994f56b495ede16201a))
1072
+
1073
+ * Merge branch 'main' of https://github.com/callummcdougall/sae_vis ([`ed8f8cb`](https://github.com/jbloomAus/SAEDashboard/commit/ed8f8cb7ad1fba2383dcdd471c33ce4a1b9f32e3))
1074
+
1075
+ * Merge pull request #27 from wllgrnt/will-add-eindex-dependency
1076
+
1077
+ Update setup.py with eindex dependency ([`8d7ed12`](https://github.com/jbloomAus/SAEDashboard/commit/8d7ed123505ac7ecf93dd310f57888547aead1d7))
1078
+
1079
+ * two more deps ([`7f231a8`](https://github.com/jbloomAus/SAEDashboard/commit/7f231a83acfef2494c1866249f57e10c21a1a443))
1080
+
1081
+ * Update setup.py with eindex
1082
+
1083
+ Without this, 'pip install sae-vis' will cause errors when e.g. you do 'from sae_vis.data_fetching_fns import get_feature_data' ([`a9d7de9`](https://github.com/jbloomAus/SAEDashboard/commit/a9d7de90b492f7305758e15303ba890fb9b503d0))
1084
+
1085
+ * fix sae bug ([`247d14b`](https://github.com/jbloomAus/SAEDashboard/commit/247d14b55f209ed9ccf50e5ce091ed66ffbf19d2))
1086
+
1087
+ * Merge pull request #32 from hijohnnylin/pin_older_sae_training
1088
+
1089
+ Demo notebook errors under "Multi-layer models" vis ([`9ac1dac`](https://github.com/jbloomAus/SAEDashboard/commit/9ac1dac51af32909666977cb5b3794965c70f62f))
1090
+
1091
+ * Pin older commit of mats_sae_training ([`8ca7ac1`](https://github.com/jbloomAus/SAEDashboard/commit/8ca7ac14b919fedb91240630ac7072cac40a6d6a))
1092
+
1093
+ * update version number ([`72e584b`](https://github.com/jbloomAus/SAEDashboard/commit/72e584b6492ed1ef3989968f6588a17fca758650))
1094
+
1095
+ * add gifs to readme ([`1393740`](https://github.com/jbloomAus/SAEDashboard/commit/13937405da31cca70cd1027aaca6c9cc84797ff1))
1096
+
1097
+ * test gif ([`4fbafa6`](https://github.com/jbloomAus/SAEDashboard/commit/4fbafa69343dc58dc18d0f78e393b5fcc9e24c0c))
1098
+
1099
+ * fix height issue ([`3f272f6`](https://github.com/jbloomAus/SAEDashboard/commit/3f272f61a954effef7bd648cc8117346da3bb971))
1100
+
1101
+ * fix pypi ([`7151164`](https://github.com/jbloomAus/SAEDashboard/commit/7151164cc0df8af278617147f07cbfbe3977cfeb))
1102
+
1103
+ * update setup ([`8c43478`](https://github.com/jbloomAus/SAEDashboard/commit/8c43478ad2eba8d3d4106fe4239c1229b8720fe6))
1104
+
1105
+ * Merge pull request #26 from hijohnnylin/update_html_anomalies
1106
+
1107
+ Update and add some HTML_ANOMALIES ([`1874a47`](https://github.com/jbloomAus/SAEDashboard/commit/1874a47a099ce32795bdbb5f98b9167dcca85ff2))
1108
+
1109
+ * Update and add some HTML_ANOMALIES ([`c541b7f`](https://github.com/jbloomAus/SAEDashboard/commit/c541b7f06108046ad1e2eb82c89f30f061f4411e))
1110
+
1111
+ * 0.2.9 ([`a5c8a6d`](https://github.com/jbloomAus/SAEDashboard/commit/a5c8a6d2008b818db90566cba50211845c753444))
1112
+
1113
+ * fix readme ([`5a8a7e3`](https://github.com/jbloomAus/SAEDashboard/commit/5a8a7e3173fc50fdb5ff0e56d7fa83e475af38a3))
1114
+
1115
+ * include feature tables ([`7c4c263`](https://github.com/jbloomAus/SAEDashboard/commit/7c4c263a2e069482d341b6265015664792bde817))
1116
+
1117
+ * add license ([`fa02a3d`](https://github.com/jbloomAus/SAEDashboard/commit/fa02a3dc93b721322b3902e2ac416ed156bf9d80))
1118
+
1119
+ * Merge branch 'main' of https://github.com/callummcdougall/sae_vis ([`ca5efcd`](https://github.com/jbloomAus/SAEDashboard/commit/ca5efcdc81074d3c3002bd997b35e326a44a4a25))
1120
+
1121
+ * Merge pull request #24 from chanind/fix-pypi-repo-link
1122
+
1123
+ fixing repo URL in setup.py ([`14a0be5`](https://github.com/jbloomAus/SAEDashboard/commit/14a0be54a57b1bc73ac4741611f9c8d1bd229e6f))
1124
+
1125
+ * fixing repo URL in setup.py ([`4faeca5`](https://github.com/jbloomAus/SAEDashboard/commit/4faeca5da06c0bb4384e202a91d895a217365d30))
1126
+
1127
+ * re-fix html anomalies ([`2fbae4c`](https://github.com/jbloomAus/SAEDashboard/commit/2fbae4c9a7dd663737bae25e73e978d40c59064a))
1128
+
1129
+ * fix hook point bug ([`9b573b2`](https://github.com/jbloomAus/SAEDashboard/commit/9b573b27590db1cbd6c8ef08fca7ff8c9d26b340))
1130
+
1131
+ * Merge pull request #20 from chanind/fix-final-resid-layer
1132
+
1133
+ fixing bug if hook_point == hook_point_resid_final ([`d6882e3`](https://github.com/jbloomAus/SAEDashboard/commit/d6882e3f813ef0d399e07548871f61b1f6a98ac6))
1134
+
1135
+ * fixing bug using hook_point_resid_final ([`cfe9b30`](https://github.com/jbloomAus/SAEDashboard/commit/cfe9b3042cfe127d5f7958064ffe817c25a19b56))
1136
+
1137
+ * fix indexing speed ([`865ff64`](https://github.com/jbloomAus/SAEDashboard/commit/865ff64329538641cd863dc7668dfc77907fb384))
1138
+
1139
+ * enable JSON saving ([`feea47a`](https://github.com/jbloomAus/SAEDashboard/commit/feea47a342d52296b72784ed18ea628848d4c7d4))
1140
+
1141
+ * Merge pull request #19 from chanind/support-mlp-and-attn-out
1142
+
1143
+ supporting mlp and attn out hooks ([`1c5463b`](https://github.com/jbloomAus/SAEDashboard/commit/1c5463b12f85cd0598b4e2fba5c556b1e9c0fbbe))
1144
+
1145
+ * supporting mlp and attn out hooks ([`a100e58`](https://github.com/jbloomAus/SAEDashboard/commit/a100e586498e8cae14df475bc7924cdecaed71ea))
1146
+
1147
+ * Merge branch 'main' of https://github.com/callummcdougall/sae_vis ([`083aeba`](https://github.com/jbloomAus/SAEDashboard/commit/083aeba0e4048d9976ec5cbee8df7dc8fd4db4e9))
1148
+
1149
+ * Merge pull request #18 from chanind/remove-build-artifacts
1150
+
1151
+ removing Python build artifacts and adding to .gitignore ([`b0e0594`](https://github.com/jbloomAus/SAEDashboard/commit/b0e0594590b4472b34052c6eb3ebceb6c9f58a11))
1152
+
1153
+ * removing Python build artifacts and adding to .gitignore ([`b6486f5`](https://github.com/jbloomAus/SAEDashboard/commit/b6486f56bea9d4bb7544c36afe70e6f891101b63))
1154
+
1155
+ * fix variable naming ([`2507918`](https://github.com/jbloomAus/SAEDashboard/commit/25079186b3f31d2271b1ecdb11f26904af7146d2))
1156
+
1157
+ * update readme ([`0ee3608`](https://github.com/jbloomAus/SAEDashboard/commit/0ee3608af396a1a6586dfb809f2f6480bb4f6390))
1158
+
1159
+ * update readme ([`f8351f8`](https://github.com/jbloomAus/SAEDashboard/commit/f8351f88e8432ccd4b2206e859daea316304d6c6))
1160
+
1161
+ * update version number ([`1e74408`](https://github.com/jbloomAus/SAEDashboard/commit/1e7440883f44a92705299430215f802fea4e1915))
1162
+
1163
+ * fix formatting and docstrings ([`b9fe2bb`](https://github.com/jbloomAus/SAEDashboard/commit/b9fe2bbb15a48e4b0415f6f4240d895990d54c9a))
1164
+
1165
+ * Merge pull request #17 from jordansauce/sae-agnostic-functions-new
1166
+
1167
+ Added SAE class agnostic functions ([`0039c6f`](https://github.com/jbloomAus/SAEDashboard/commit/0039c6f8f99d6e8a1b2ff56aa85f60a3eba3afb0))
1168
+
1169
+ * Added sae class agnostic functions
1170
+
1171
+ Added parse_feature_data() and parse_prompt_data() ([`e2709d0`](https://github.com/jbloomAus/SAEDashboard/commit/e2709d0b4c55d73d6026f3b9ce534f59ce61f344))
1172
+
1173
+ * add to pypi ([`02a5b9a`](https://github.com/jbloomAus/SAEDashboard/commit/02a5b9acd15433cc59d438271b9bd5e12d62b662))
1174
+
1175
+ * update notebook images ([`b87ad4d`](https://github.com/jbloomAus/SAEDashboard/commit/b87ad4d256f12c23605b0e7db307ee56913c93ef))
1176
+
1177
+ * fix layer parse and custom device ([`14c7ae9`](https://github.com/jbloomAus/SAEDashboard/commit/14c7ae9d0c8b7dad21b953cfc93fe7f34c74e149))
1178
+
1179
+ * update dropdown styling ([`83be219`](https://github.com/jbloomAus/SAEDashboard/commit/83be219bfe31b985a26762e06345c574aa0e6fe1))
1180
+
1181
+ * add custom prompt vis ([`cabdc5c`](https://github.com/jbloomAus/SAEDashboard/commit/cabdc5cb31f881cddf236490c41332c525d2ee74))
1182
+
1183
+ * d3 & multifeature refactor ([`f79a919`](https://github.com/jbloomAus/SAEDashboard/commit/f79a919691862f60a9e30fe0f79fd8e771bc932a))
1184
+
1185
+ * remove readme links ([`4bcef48`](https://github.com/jbloomAus/SAEDashboard/commit/4bcef489b644dd3357b1975f3245d534f6f0d2e0))
1186
+
1187
+ * add demo html ([`629c713`](https://github.com/jbloomAus/SAEDashboard/commit/629c713345407562dc4ccd9875bf3cfab5480bdd))
1188
+
1189
+ * remove demos ([`beedea9`](https://github.com/jbloomAus/SAEDashboard/commit/beedea9667761534a5293015aff9cc17638666a5))
1190
+
1191
+ * fix quantile error ([`3a23cfd`](https://github.com/jbloomAus/SAEDashboard/commit/3a23cfd56f21fe0775a1a9957db340d15f75f51a))
1192
+
1193
+ * width 425 ([`f25c776`](https://github.com/jbloomAus/SAEDashboard/commit/f25c776d5cb746916d3f2fdf368cbd5448742949))
1194
+
1195
+ * fix device bug ([`85dfa49`](https://github.com/jbloomAus/SAEDashboard/commit/85dfa497bc804945911e80607ac31cf3afbdc759))
1196
+
1197
+ * dont return vocab dict ([`b4c7138`](https://github.com/jbloomAus/SAEDashboard/commit/b4c713873870acb4035986cc5bff3a4ce1e466c9))
1198
+
1199
+ * save as JSON, fix device ([`eba2cff`](https://github.com/jbloomAus/SAEDashboard/commit/eba2cff3eb6215558577a6b4d4f8cc716766b927))
1200
+
1201
+ * simple fixed and issues ([`b28a0f7`](https://github.com/jbloomAus/SAEDashboard/commit/b28a0f7c7e936f4bea05528d952dfcd438533cce))
1202
+
1203
+ * Merge pull request #8 from lucyfarnik/topk-empty-mask
1204
+
1205
+ Topk error handling for empty masks ([`2740c00`](https://github.com/jbloomAus/SAEDashboard/commit/2740c0047e78df7e56d7bcf707c909ac18e71c1f))
1206
+
1207
+ * Topk error handling for empty masks ([`1c2627e`](https://github.com/jbloomAus/SAEDashboard/commit/1c2627e237f8f67725fc44e60a190bc141d36fc8))
1208
+
1209
+ * viz to vis ([`216d02b`](https://github.com/jbloomAus/SAEDashboard/commit/216d02b550d6fbcb9b37d39c1b272a7dda91aadc))
1210
+
1211
+ * update readme links ([`f9b3f95`](https://github.com/jbloomAus/SAEDashboard/commit/f9b3f95e31e7150024be27ec62246f43bf9bcbb8))
1212
+
1213
+ * update for TL ([`1941db1`](https://github.com/jbloomAus/SAEDashboard/commit/1941db1e22093d6fc88fb3fcd6f4c7d535d8b3b4))
1214
+
1215
+ * Merge pull request #5 from lucyfarnik/transformer-lens-models
1216
+
1217
+ Compatibility with TransformerLens models ([`8d59c6c`](https://github.com/jbloomAus/SAEDashboard/commit/8d59c6c5a5f2b98c486e5c74130371ad9254d1c9))
1218
+
1219
+ * Merge branch 'main' into transformer-lens-models ([`73057d7`](https://github.com/jbloomAus/SAEDashboard/commit/73057d7e2a3e4e9669fc0556e64190811ac8b52d))
1220
+
1221
+ * Merge pull request #4 from lucyfarnik/resid-saes-support
1222
+
1223
+ Added support for residual-adjacent SAEs ([`b02e98b`](https://github.com/jbloomAus/SAEDashboard/commit/b02e98b3b852c0613a890f8949d04b5560fb6fd6))
1224
+
1225
+ * Added support for residual-adjacent SAEs ([`89aacf1`](https://github.com/jbloomAus/SAEDashboard/commit/89aacf1b22aa81b393b10eca8611c9dbf406c638))
1226
+
1227
+ * Merge pull request #7 from lucyfarnik/fix-histogram-div-zero
1228
+
1229
+ Fixed division by zero in histogram calculation ([`3aee20e`](https://github.com/jbloomAus/SAEDashboard/commit/3aee20ea7f99cc07e6c5085fddb70cadd8327f4d))
1230
+
1231
+ * Fixed division by zero in histogram calculation ([`e986e90`](https://github.com/jbloomAus/SAEDashboard/commit/e986e907cc42790efc93ce75ebf7b28a0278aaa2))
1232
+
1233
+ * Merge pull request #6 from lucyfarnik/handling-dead-features
1234
+
1235
+ Edge case handling for dead features ([`9e43c30`](https://github.com/jbloomAus/SAEDashboard/commit/9e43c308e58769828234e1505f1c1102ba651dfd))
1236
+
1237
+ * Edge case handling for dead features ([`5197aee`](https://github.com/jbloomAus/SAEDashboard/commit/5197aee2c9f92bce7c5fd6d22201152a68c2e6ca))
1238
+
1239
+ * add features argument ([`f24ef7e`](https://github.com/jbloomAus/SAEDashboard/commit/f24ef7ebebb3d4fd92e299858dbd5b968b78c69e))
1240
+
1241
+ * fix image link ([`22c8734`](https://github.com/jbloomAus/SAEDashboard/commit/22c873434dfa84e3aed5ee0aab0fd25b288428a6))
1242
+
1243
+ * Merge pull request #1 from lucyfarnik/read-me-links-fix
1244
+
1245
+ Fixed readme links pointing to the old colab ([`86f8e20`](https://github.com/jbloomAus/SAEDashboard/commit/86f8e2012e376b6c498e5e708324f812af6fbc98))
1246
+
1247
+ * Fixed readme links pointing to the old colab ([`28ef1cb`](https://github.com/jbloomAus/SAEDashboard/commit/28ef1cbd1b91f6c09c842f48e1f997d189ca04e7))
1248
+
1249
+ * Added readme section about models ([`7523e7f`](https://github.com/jbloomAus/SAEDashboard/commit/7523e7f6363e030196496b3c6a3dc70b234c2d9a))
1250
+
1251
+ * Compatibility with TransformerLens models ([`ba708e9`](https://github.com/jbloomAus/SAEDashboard/commit/ba708e987be6cc7a09d34ea8fb83de009312684d))
1252
+
1253
+ * Added support for MPS ([`196c0a2`](https://github.com/jbloomAus/SAEDashboard/commit/196c0a24d0e8277b327eb2d57662075f9106990b))
1254
+
1255
+ * black font ([`d81e74d`](https://github.com/jbloomAus/SAEDashboard/commit/d81e74d575326ef786881fb9182a768f9de2cb70))
1256
+
1257
+ * fix html bug ([`265dedd`](https://github.com/jbloomAus/SAEDashboard/commit/265dedd376991230e2041fd37d5b6a0eda048545))
1258
+
1259
+ * add jax and dataset deps ([`f1caeaf`](https://github.com/jbloomAus/SAEDashboard/commit/f1caeafc9613e27c7663447cf862301ac11d842d))
1260
+
1261
+ * remove TL dependency ([`155991f`](https://github.com/jbloomAus/SAEDashboard/commit/155991fe61d0199d081d344ac44996edce35d118))
1262
+
1263
+ * first commit ([`7782eb6`](https://github.com/jbloomAus/SAEDashboard/commit/7782eb6d5058372630c5bbb8693eb540a7bceaf4))
SAEDashboard/Dockerfile ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # docker build --target development -t decoderesearch/saedashboard-cuda --file Dockerfile .
2
+ # docker run --entrypoint /bin/bash -it decoderesearch/saedashboard-cuda
3
+
4
+ ARG APP_NAME=sae_dashboard
5
+ ARG APP_PATH=/opt/$APP_NAME
6
+ ARG PYTHON_VERSION=3.12.2
7
+ ARG POETRY_VERSION=1.8.3
8
+
9
+ FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel AS staging
10
+ ARG APP_NAME
11
+ ARG APP_PATH
12
+ ARG POETRY_VERSION
13
+
14
+ ENV \
15
+ PYTHONDONTWRITEBYTECODE=1 \
16
+ PYTHONUNBUFFERED=1 \
17
+ PYTHONFAULTHANDLER=1
18
+ ENV \
19
+ POETRY_VERSION=$POETRY_VERSION \
20
+ POETRY_HOME="/opt/poetry" \
21
+ POETRY_VIRTUALENVS_IN_PROJECT=true \
22
+ POETRY_NO_INTERACTION=1
23
+
24
+ RUN apt-get update && apt-get install --no-install-recommends -y curl git-lfs vim && rm -rf /var/lib/apt/lists/*
25
+
26
+ RUN curl -sSL https://install.python-poetry.org | python
27
+ ENV PATH="$POETRY_HOME/bin:$PATH"
28
+
29
+ WORKDIR $APP_PATH
30
+ COPY ./pyproject.toml ./README.md ./
31
+ COPY ./$APP_NAME ./$APP_NAME
32
+
33
+ FROM staging AS development
34
+ ARG APP_NAME
35
+ ARG APP_PATH
36
+
37
+ WORKDIR $APP_PATH
38
+ RUN poetry lock
39
+ RUN poetry install --no-dev --no-cache
40
+
41
+ RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash
42
+ RUN git lfs install
43
+
44
+ ENTRYPOINT ["/bin/bash"]
45
+ CMD ["poetry", "shell"]
SAEDashboard/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Decode Research
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
SAEDashboard/Makefile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ format:
2
+ poetry run black .
3
+ poetry run isort .
4
+
5
+ lint:
6
+ poetry run flake8 .
7
+ poetry run black --check .
8
+ poetry run isort --check-only --diff .
9
+
10
+ check-type:
11
+ poetry run pyright .
12
+
13
+ test:
14
+ poetry run pytest --cov=sae_dashboard --cov-report=term-missing tests/unit
15
+
16
+ check-ci:
17
+ make format
18
+ make lint
19
+ make check-type
20
+ make test
21
+
22
+ profile-memory-unit:
23
+ poetry run pytest --memray tests/unit
24
+
25
+ profile-speed-unit:
26
+ poetry run py.test tests/unit --profile-svg -k "test_SaeVisData_create_results_look_reasonable[Default]"
27
+ open prof/combined.svg
SAEDashboard/README.md ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SAEDashboard
2
+
3
+ SAEDashboard is a tool for visualizing and analyzing Sparse Autoencoders (SAEs) in neural networks. This repository is an adaptation and extension of Callum McDougal's [SAEVis](https://github.com/callummcdougall/sae_vis/tree/main), providing enhanced functionality for feature visualization and analysis as well as feature dashboard creation at scale.
4
+
5
+ ## Overview
6
+
7
+ This codebase was originally designed to replicate Anthropic's sparse autoencoder visualizations, which you can see [here](https://transformer-circuits.pub/2023/monosemantic-features/vis/a1.html). SAEDashboard primarily provides visualizations of features, including their activations, logits, and correlations--similar to what is shown in the Anthropic link.
8
+
9
+ <img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/feature-vis-video.gif" width="800">
10
+
11
+ ## Features
12
+
13
+ - Customizable dashboards with various plots and data representations for SAE features
14
+ - Support for any SAE in the SAELens library
15
+ - Neuronpedia integration for hosting and comprehensive neuron analysis (note: this requires a Neuronpedia account and is currently only used internally)
16
+ - Ability to handle large datasets and models efficiently
17
+
18
+ ## Installation
19
+
20
+ Install SAEDashboard using pip:
21
+ ```bash
22
+ pip install sae-dashboard
23
+ ```
24
+
25
+
26
+ ## Quick Start
27
+
28
+ Here's a basic example of how to use SAEDashboard with SaeVisRunner:
29
+
30
+ ```python
31
+ from sae_lens import SAE
32
+ from transformer_lens import HookedTransformer
33
+ from sae_dashboard.sae_vis_data import SaeVisConfig
34
+ from sae_dashboard.sae_vis_runner import SaeVisRunner
35
+
36
+ # Load model and SAE
37
+ model = HookedTransformer.from_pretrained("gpt2-small", device="cuda", dtype="bfloat16")
38
+ sae, _, _ = SAE.from_pretrained(
39
+ release="gpt2-small-res-jb",
40
+ sae_id="blocks.6.hook_resid_pre",
41
+ device="cuda"
42
+ )
43
+ sae.fold_W_dec_norm()
44
+
45
+ # Configure visualization
46
+ config = SaeVisConfig(
47
+ hook_point=sae.cfg.hook_name,
48
+ features=list(range(256)),
49
+ minibatch_size_features=64,
50
+ minibatch_size_tokens=256,
51
+ device="cuda",
52
+ dtype="bfloat16"
53
+ )
54
+
55
+ # Generate data
56
+ data = SaeVisRunner(config).run(encoder=sae, model=model, tokens=your_token_dataset)
57
+
58
+ # Save feature-centric visualization
59
+ from sae_dashboard.data_writing_fns import save_feature_centric_vis
60
+ save_feature_centric_vis(sae_vis_data=data, filename="feature_dashboard.html")
61
+ ```
62
+
63
+ For a more detailed tutorial, check out our [demo notebook](https://colab.research.google.com/drive/1oqDS35zibmL1IUQrk_OSTxdhcGrSS6yO?usp=drive_link).
64
+
65
+ ## Advanced Usage: Neuronpedia Runner
66
+
67
+ For internal use or advanced analysis, SAEDashboard provides a Neuronpedia runner that generates data compatible with Neuronpedia. Here's a basic example:
68
+
69
+ ```python
70
+ from sae_dashboard.neuronpedia.neuronpedia_runner_config import NeuronpediaRunnerConfig
71
+ from sae_dashboard.neuronpedia.neuronpedia_runner import NeuronpediaRunner
72
+
73
+ config = NeuronpediaRunnerConfig(
74
+ sae_set="your_sae_set",
75
+ sae_path="path/to/sae",
76
+ np_set_name="your_neuronpedia_set_name",
77
+ huggingface_dataset_path="dataset/path",
78
+ n_prompts_total=1000,
79
+ n_features_at_a_time=64
80
+ )
81
+
82
+ runner = NeuronpediaRunner(config)
83
+ runner.run()
84
+ ```
85
+
86
+ For more options and detailed configuration, refer to the `NeuronpediaRunnerConfig` class in the code.
87
+
88
+ ## Cross-Layer Transcoder (CLT) Support
89
+
90
+ SAEDashboard now supports visualization of Cross-Layer Transcoders (CLTs), which are a variant of SAEs that process activations across transformer layers. To use CLT visualization:
91
+
92
+ ### Required Files
93
+
94
+ When using a CLT model, you'll need these files in your CLT model directory:
95
+
96
+ 1. **Model weights**: A `.safetensors` or `.pt` file containing the CLT weights
97
+ 2. **Configuration**: A `cfg.json` file with the CLT configuration, including:
98
+ - `num_features`: Number of features in the CLT
99
+ - `num_layers`: Number of transformer layers
100
+ - `d_model`: Model dimension
101
+ - `activation_fn`: Activation function (e.g., "jumprelu", "relu")
102
+ - `normalization_method`: How inputs are normalized (e.g., "mean_std", "none")
103
+ - `tl_input_template`: TransformerLens hook template (e.g., "blocks.{}.ln2.hook_normalized"). Note that this will usually differ from the hook name in the model's cfg.json, which is based on NNsight/transformers. You will need to find the corresponding TransformerLens hook name.
104
+ 3. **Normalization statistics** (if `normalization_method` is "mean_std"): A `norm_stats.json` file containing the mean and standard deviation for each layer's inputs, generated from the dataset when activations were generated (or afterwards). The file should have this structure:
105
+ ```json
106
+ {
107
+ "0": {
108
+ "inputs": {
109
+ "mean": [0.1, -0.2, ...], // Array of d_model values
110
+ "std": [1.0, 0.9, ...] // Array of d_model values
111
+ }
112
+ },
113
+ "1": {
114
+ "inputs": {
115
+ "mean": [...],
116
+ "std": [...]
117
+ }
118
+ },
119
+ // ... entries for each layer
120
+ }
121
+ ```
122
+
123
+ ### Example Usage
124
+
125
+ ```python
126
+ from sae_dashboard.neuronpedia.neuronpedia_runner_config import NeuronpediaRunnerConfig
127
+ from sae_dashboard.neuronpedia.neuronpedia_runner import NeuronpediaRunner
128
+
129
+ config = NeuronpediaRunnerConfig(
130
+ sae_set="your_clt_set",
131
+ sae_path="/path/to/clt/model/directory", # Directory containing the files above
132
+ model_id="gpt2", # Base model the CLT was trained on
133
+ outputs_dir="clt_outputs",
134
+ huggingface_dataset_path="your/dataset",
135
+ use_clt=True, # Enable CLT mode
136
+ clt_layer_idx=5, # Which layer to visualize (0-indexed)
137
+ clt_weights_filename="model.safetensors", # Optional: specify exact weights file
138
+ n_prompts_total=1000,
139
+ n_features_at_a_time=64
140
+ )
141
+
142
+ runner = NeuronpediaRunner(config)
143
+ runner.run()
144
+ ```
145
+
146
+ ### Notes on CLT Support
147
+
148
+ - CLTs must be loaded from local files (HuggingFace Hub loading not yet supported)
149
+ - The `--use-clt` flag is mutually exclusive with `--use-transcoder` and `--use-skip-transcoder`
150
+ - JumpReLU activation functions with learned thresholds are supported
151
+ - The visualization will show features for the specified layer only
152
+
153
+ ## Configuration Options
154
+
155
+ SAEDashboard offers a wide range of configuration options for both SaeVisRunner and NeuronpediaRunner. Key options include:
156
+
157
+ - `hook_point`: The layer to analyze in the model
158
+ - `features`: List of feature indices to visualize
159
+ - `minibatch_size_features`: Number of features to process in each batch
160
+ - `minibatch_size_tokens`: Number of tokens to process in each forward pass
161
+ - `device`: Computation device (e.g., "cuda", "cpu")
162
+ - `dtype`: Data type for computations
163
+ - `sparsity_threshold`: Threshold for feature sparsity (Neuronpedia runner)
164
+ - `n_prompts_total`: Total number of prompts to analyze
165
+ - `use_wandb`: Enable logging with Weights & Biases
166
+
167
+ Refer to `SaeVisConfig` and `NeuronpediaRunnerConfig` for full lists of options.
168
+
169
+ ## Contributing
170
+
171
+ This project uses [Poetry](https://python-poetry.org/) for dependency management. After cloning the repo, install dependencies with `poetry lock && poetry install`.
172
+
173
+ We welcome contributions to SAEDashboard! Please follow these steps:
174
+
175
+ 1. Fork the repository
176
+ 2. Create a new branch for your feature
177
+ 3. Implement your changes
178
+ 4. Run tests and checks:
179
+ - Use `make format` to format your code
180
+ - Use `make check-ci` to run all checks and tests
181
+ 5. Submit a pull request
182
+
183
+ Ensure your code passes all checks, including:
184
+ - Black and Flake8 for formatting and linting
185
+ - Pyright for type-checking
186
+ - Pytest for tests
187
+
188
+ ## Citing This Work
189
+
190
+ To cite SAEDashboard in your research, please use the following BibTeX entry:
191
+
192
+ ```bibtex
193
+ @misc{sae_dashboard,
194
+ title = {{SAE Dashboard}},
195
+ author = {Decode Research},
196
+ howpublished = {\url{https://github.com/jbloomAus/sae-dashboard}},
197
+ year = {2024}
198
+ }
199
+ ```
200
+
201
+ ## License
202
+
203
+ SAE Dashboard is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
204
+
205
+ ## Acknowledgment and Citation
206
+
207
+ This project is based on the work by Callum McDougall. If you use SAEDashboard in your research, please cite the original SAEVis project as well:
208
+
209
+ ```bibtex
210
+ @misc{sae_vis,
211
+ title = {{SAE Visualizer}},
212
+ author = {Callum McDougall},
213
+ howpublished = {\url{https://github.com/callummcdougall/sae_vis}},
214
+ year = {2024}
215
+ }
216
+ ```
217
+
218
+ ## Contact
219
+
220
+ For questions or support, please [open an issue](https://github.com/jbloomAus/SAEDashboard/issues) on our GitHub repository.
221
+
SAEDashboard/docker/docker-entrypoint.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+
3
+ set -e
4
+
5
+ # activate our virtual environment here
6
+ . /opt/pysetup/.venv/bin/activate
7
+
8
+ # You can put other setup logic here
9
+
10
+ # Evaluating passed command:
11
+ exec "$@"
SAEDashboard/docker/docker-hub.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Build and Push Docker Image to Docker Hub
2
+
3
+ on:
4
+ push:
5
+ branches: [ "main" ]
6
+ pull_request:
7
+ branches: [ "main" ]
8
+
9
+ env:
10
+ REGISTRY: docker.io
11
+ IMAGE_NAME: decoderesearch/saedashboard-cuda
12
+
13
+ jobs:
14
+
15
+ build:
16
+
17
+ runs-on: ubuntu-latest
18
+
19
+ steps:
20
+ - uses: actions/checkout@v3
21
+ - name: Build the Docker image
22
+ run: docker build --target development -t ${{ env.IMAGE_NAME }} --file Dockerfile .
23
+ # test:
24
+ # runs-on: ubuntu-latest
25
+ # steps:
26
+ # - uses: actions/checkout@v2
27
+ # - name: Test the Docker image
28
+ # run: docker-compose up -d
29
+ push_to_registry:
30
+ name: Push Docker image to Docker Hub
31
+ runs-on: ubuntu-latest
32
+ steps:
33
+ - name: Check out the repo
34
+ uses: actions/checkout@v3
35
+
36
+ - name: Set up Docker Buildx
37
+ uses: docker/setup-buildx-action@v2
38
+
39
+ - name: Log in to Docker Hub
40
+ uses: docker/login-action@v3
41
+ with:
42
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
43
+ password: ${{ secrets.DOCKERHUB_PASSWORD }}
44
+
45
+ - name: Extract metadata (tags, labels) for Docker
46
+ id: meta
47
+ uses: docker/metadata-action@v5
48
+ with:
49
+ images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
50
+
51
+ - name: Build and push Docker image
52
+ uses: docker/build-push-action@v2
53
+ with:
54
+ context: "{{defaultContext}}"
55
+ push: true
56
+ tags: ${{ steps.meta.outputs.tags }}
57
+ labels: ${{ steps.meta.outputs.labels }}
SAEDashboard/neuronpedia_vector_pipeline_demo.ipynb ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Number of vectors: 1\n",
13
+ "Vector dimension: 768\n",
14
+ "Vector names: ['sentiment_vector']\n"
15
+ ]
16
+ }
17
+ ],
18
+ "source": [
19
+ "# Example usage\n",
20
+ "import json\n",
21
+ "import torch\n",
22
+ "from pathlib import Path\n",
23
+ "from sae_dashboard.neuronpedia.vector_set import VectorSet\n",
24
+ "\n",
25
+ "\n",
26
+ "# Load vector from file. Note that the vectors should be stored in this format, as a list of lists of floats:\n",
27
+ "# {\n",
28
+ "# \"vectors\": [\n",
29
+ "# [vector_1],\n",
30
+ "# [vector_2],\n",
31
+ "# ...\n",
32
+ "# ]\n",
33
+ "# }\n",
34
+ "json_path = Path(\"test_vectors/logistic_direction.json\")\n",
35
+ "\n",
36
+ "# Load the vector into a VectorSet\n",
37
+ "vector_set = VectorSet.from_json(\n",
38
+ " json_path=json_path,\n",
39
+ " d_model=768, # Example dimension for GPT-2 Small\n",
40
+ " hook_point=\"blocks.7.hook_resid_pre\",\n",
41
+ " hook_layer=7,\n",
42
+ " model_name=\"gpt2\",\n",
43
+ " names=[\"sentiment_vector\"], # Optional custom name\n",
44
+ ")\n",
45
+ "\n",
46
+ "# Now you can use the vector set\n",
47
+ "print(f\"Number of vectors: {vector_set.vectors.shape[0]}\")\n",
48
+ "print(f\"Vector dimension: {vector_set.vectors.shape[1]}\")\n",
49
+ "print(f\"Vector names: {vector_set.names}\")"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 3,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "# You can also save and load the vector set as a VectorSet object as opposed to a simple list of lists of floats\n",
59
+ "vector_set.save(Path(\"test_vectors/logistic_direction_vector_set.json\"))\n",
60
+ "vector_set = VectorSet.load(Path(\"test_vectors/logistic_direction_vector_set.json\"))"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 4,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "from sae_dashboard.neuronpedia.neuronpedia_vector_runner import (\n",
70
+ " NeuronpediaVectorRunner,\n",
71
+ " NeuronpediaVectorRunnerConfig,\n",
72
+ ")\n",
73
+ "\n",
74
+ "cfg = NeuronpediaVectorRunnerConfig(\n",
75
+ " outputs_dir=\"test_outputs/\",\n",
76
+ " huggingface_dataset_path=\"monology/pile-uncopyrighted\",\n",
77
+ " vector_dtype=\"float32\",\n",
78
+ " model_dtype=\"float32\",\n",
79
+ " # Small test settings\n",
80
+ " n_prompts_total=16384,\n",
81
+ " n_tokens_in_prompt=128, # Shorter sequences\n",
82
+ " n_prompts_in_forward_pass=256,\n",
83
+ " n_vectors_at_a_time=1,\n",
84
+ " use_wandb=False, # Disable wandb for testing\n",
85
+ ")"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": 5,
91
+ "metadata": {},
92
+ "outputs": [
93
+ {
94
+ "name": "stdout",
95
+ "output_type": "stream",
96
+ "text": [
97
+ "Device Count: 1\n",
98
+ "Using specified vector dtype: float32\n",
99
+ "SAE Device: mps\n",
100
+ "Model Device: mps\n",
101
+ "Model Num Devices: 1\n",
102
+ "Activation Store Device: mps\n",
103
+ "Dataset Path: monology/pile-uncopyrighted\n",
104
+ "Forward Pass size: 128\n",
105
+ "Total number of tokens: 2097152\n",
106
+ "Total number of contexts (prompts): 16384\n",
107
+ "Vector DType: float32\n",
108
+ "Model DType: float32\n"
109
+ ]
110
+ },
111
+ {
112
+ "name": "stderr",
113
+ "output_type": "stream",
114
+ "text": [
115
+ "/Users/curttigges/miniconda3/envs/sae-d/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
116
+ " warnings.warn(\n"
117
+ ]
118
+ },
119
+ {
120
+ "name": "stdout",
121
+ "output_type": "stream",
122
+ "text": [
123
+ "Loaded pretrained model gpt2 into HookedTransformer\n"
124
+ ]
125
+ },
126
+ {
127
+ "data": {
128
+ "application/vnd.jupyter.widget-view+json": {
129
+ "model_id": "f1a49eee02cd482e9de6deaa88e4afde",
130
+ "version_major": 2,
131
+ "version_minor": 0
132
+ },
133
+ "text/plain": [
134
+ "Resolving data files: 0%| | 0/30 [00:00<?, ?it/s]"
135
+ ]
136
+ },
137
+ "metadata": {},
138
+ "output_type": "display_data"
139
+ },
140
+ {
141
+ "name": "stdout",
142
+ "output_type": "stream",
143
+ "text": [
144
+ "Warning: Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info.\n",
145
+ "Tokens don't exist, making them.\n"
146
+ ]
147
+ },
148
+ {
149
+ "name": "stderr",
150
+ "output_type": "stream",
151
+ "text": [
152
+ " 0%| | 0/2048 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (3180 > 1024). Running this sequence through the model will result in indexing errors\n",
153
+ "100%|██████████| 2048/2048 [00:18<00:00, 108.67it/s]\n",
154
+ "0it [00:00, ?it/s]"
155
+ ]
156
+ },
157
+ {
158
+ "name": "stdout",
159
+ "output_type": "stream",
160
+ "text": [
161
+ "========== Running Batch #0 ==========\n"
162
+ ]
163
+ },
164
+ {
165
+ "data": {
166
+ "application/vnd.jupyter.widget-view+json": {
167
+ "model_id": "418cf3ccf15d4ae597e06d24e4c89b11",
168
+ "version_major": 2,
169
+ "version_minor": 0
170
+ },
171
+ "text/plain": [
172
+ "Forward passes to cache data for vis: 0%| | 0/60 [00:00<?, ?it/s]"
173
+ ]
174
+ },
175
+ "metadata": {},
176
+ "output_type": "display_data"
177
+ },
178
+ {
179
+ "data": {
180
+ "application/vnd.jupyter.widget-view+json": {
181
+ "model_id": "2a0ddbda94d0407598edf564b4487407",
182
+ "version_major": 2,
183
+ "version_minor": 0
184
+ },
185
+ "text/plain": [
186
+ "Extracting vis data from cached data: 0%| | 0/1 [00:00<?, ?it/s]"
187
+ ]
188
+ },
189
+ "metadata": {},
190
+ "output_type": "display_data"
191
+ },
192
+ {
193
+ "name": "stderr",
194
+ "output_type": "stream",
195
+ "text": [
196
+ "/Users/curttigges/Projects/SAEDashboard/sae_dashboard/vector_data_generator.py:205: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
197
+ " return torch.load(\n"
198
+ ]
199
+ },
200
+ {
201
+ "name": "stdout",
202
+ "output_type": "stream",
203
+ "text": [
204
+ "feature_indices: [0]\n"
205
+ ]
206
+ },
207
+ {
208
+ "data": {
209
+ "text/html": [
210
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━┳━━━━━━┳━━━━━━━┓\n",
211
+ "┃<span style=\"font-weight: bold\"> Task </span>┃<span style=\"font-weight: bold\"> Time </span>┃<span style=\"font-weight: bold\"> Pct % </span>┃\n",
212
+ "┡━━━━━━╇━━━━━━╇━━━━━━━┩\n",
213
+ "└──────┴──────┴───────┘\n",
214
+ "</pre>\n"
215
+ ],
216
+ "text/plain": [
217
+ "┏━━━━━━┳━━━━━━┳━━━━━━━┓\n",
218
+ "┃\u001b[1m \u001b[0m\u001b[1mTask\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mTime\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mPct %\u001b[0m\u001b[1m \u001b[0m┃\n",
219
+ "┡━━━━━━╇━━━━━━╇━━━━━━━┩\n",
220
+ "└──────┴──────┴───────┘\n"
221
+ ]
222
+ },
223
+ "metadata": {},
224
+ "output_type": "display_data"
225
+ },
226
+ {
227
+ "name": "stderr",
228
+ "output_type": "stream",
229
+ "text": [
230
+ "1it [00:02, 2.65s/it]"
231
+ ]
232
+ },
233
+ {
234
+ "name": "stdout",
235
+ "output_type": "stream",
236
+ "text": [
237
+ "Output written to test_outputs/gpt2_blocks.7.hook_resid_pre/batch-0.json\n"
238
+ ]
239
+ },
240
+ {
241
+ "name": "stderr",
242
+ "output_type": "stream",
243
+ "text": [
244
+ "\n"
245
+ ]
246
+ }
247
+ ],
248
+ "source": [
249
+ "runner = NeuronpediaVectorRunner(vector_set, cfg)\n",
250
+ "runner.run()"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": null,
256
+ "metadata": {},
257
+ "outputs": [],
258
+ "source": []
259
+ }
260
+ ],
261
+ "metadata": {
262
+ "kernelspec": {
263
+ "display_name": "sae-d",
264
+ "language": "python",
265
+ "name": "python3"
266
+ },
267
+ "language_info": {
268
+ "codemirror_mode": {
269
+ "name": "ipython",
270
+ "version": 3
271
+ },
272
+ "file_extension": ".py",
273
+ "mimetype": "text/x-python",
274
+ "name": "python",
275
+ "nbconvert_exporter": "python",
276
+ "pygments_lexer": "ipython3",
277
+ "version": "3.12.4"
278
+ }
279
+ },
280
+ "nbformat": 4,
281
+ "nbformat_minor": 2
282
+ }
SAEDashboard/notebooks/experiment_gemma_2_9b_dashboard_generation_np.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # I'm running this in an A100 with 90GB of GPU RAM.
2
+ # I'm using TransformerLens 2.2 which I manually installed from source.
3
+ # I've made a few edits to fix bfloat16 errors (but I've since made PRs so latest SAE Lens / SAE dashboard should be fine here).
4
+ import os
5
+
6
+ from sae_dashboard.neuronpedia.neuronpedia_runner import (
7
+ NeuronpediaRunner,
8
+ NeuronpediaRunnerConfig,
9
+ )
10
+
11
+ # GET WEIGHTS FROM WANDB
12
+ # import wandb
13
+ # run = wandb.init()
14
+ # artifact = run.use_artifact('jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7', type='model')
15
+ # artifact_dir = artifact.download()
16
+
17
+
18
+ # Get Sparsity from Wandb (and manually move it accross)
19
+ # import wandb
20
+ # run = wandb.init()
21
+ # artifact = run.use_artifact('jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688_log_feature_sparsity:v7', type='log_feature_sparsity')
22
+ # artifact_dir = artifact.download()
23
+
24
+ NP_OUTPUT_FOLDER = "neuronpedia_outputs/gemma-2-9b-test"
25
+ SAE_SET = "res-jb-test"
26
+ SAE_PATH = "artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7"
27
+ print(SAE_PATH)
28
+
29
+ # delete output files if present
30
+ os.system(f"rm -rf {NP_OUTPUT_FOLDER}")
31
+ cfg = NeuronpediaRunnerConfig(
32
+ sae_set=SAE_SET,
33
+ sae_path=SAE_PATH,
34
+ outputs_dir=NP_OUTPUT_FOLDER,
35
+ sparsity_threshold=-6,
36
+ n_prompts_total=4096,
37
+ huggingface_dataset_path="monology/pile-uncopyrighted",
38
+ n_features_at_a_time=1024,
39
+ n_tokens_in_prompt=128,
40
+ start_batch=0,
41
+ end_batch=8,
42
+ use_wandb=True,
43
+ sae_device="cuda",
44
+ model_device="cuda",
45
+ model_n_devices=1,
46
+ activation_store_device="cuda",
47
+ model_dtype="bfloat16",
48
+ sae_dtype="float32",
49
+ )
50
+
51
+ runner = NeuronpediaRunner(cfg)
52
+ runner.run()
SAEDashboard/notebooks/sae_dashboard_demo_gemma_2_9b.ipynb ADDED
@@ -0,0 +1,618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Demo Notebook"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "Steps:\n",
15
+ "1. Download SAE with SAE Lens.\n",
16
+ "2. Create a dataset consistent with that SAE. \n",
17
+ "3. Fold the SAE decoder norm weights so that feature activations are \"correct\".\n",
18
+ "4. Estimate the activation normalization constant if needed, and fold it into the SAE weights.\n",
19
+ "5. Run the SAE generator for the features you want."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {},
25
+ "source": [
26
+ "# Set Up"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "# Download Gemma-2-9b weights\n",
36
+ "\n",
37
+ "import wandb\n",
38
+ "\n",
39
+ "run = wandb.init()\n",
40
+ "artifact = run.use_artifact(\n",
41
+ " \"jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7\",\n",
42
+ " type=\"model\",\n",
43
+ ")\n",
44
+ "artifact_dir = artifact.download()"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "import wandb\n",
54
+ "\n",
55
+ "run = wandb.init()\n",
56
+ "artifact = run.use_artifact(\n",
57
+ " \"jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688_log_feature_sparsity:v7\",\n",
58
+ " type=\"log_feature_sparsity\",\n",
59
+ ")\n",
60
+ "artifact_dir = artifact.download()"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "import torch\n",
70
+ "import matplotlib.pyplot as plt\n",
71
+ "from safetensors.torch import load_file\n",
72
+ "\n",
73
+ "# Assume we have a PyTorch tensor\n",
74
+ "feature_sparsity = load_file(\n",
75
+ " \"artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7/sparsity.safetensors\"\n",
76
+ ")[\"sparsity\"]\n",
77
+ "\n",
78
+ "# Convert the tensor to a numpy array\n",
79
+ "data = feature_sparsity.numpy()\n",
80
+ "\n",
81
+ "# Create the histogram\n",
82
+ "plt.hist(data, bins=30, edgecolor=\"black\")\n",
83
+ "\n",
84
+ "# Add labels and title\n",
85
+ "plt.xlabel(\"Value\")\n",
86
+ "plt.ylabel(\"Frequency\")\n",
87
+ "plt.title(\"Histogram of PyTorch Tensor\")\n",
88
+ "\n",
89
+ "# Show the plot\n",
90
+ "plt.show()"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "import torch\n",
100
+ "from transformer_lens import HookedTransformer\n",
101
+ "from sae_lens import ActivationsStore, SAE\n",
102
+ "from importlib import reload\n",
103
+ "import sae_dashboard\n",
104
+ "\n",
105
+ "torch.set_grad_enabled(False)\n",
106
+ "\n",
107
+ "reload(sae_dashboard)"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "MODEL = \"gemma-2-9b\"\n",
117
+ "\n",
118
+ "if torch.backends.mps.is_available():\n",
119
+ " device = \"mps\"\n",
120
+ "else:\n",
121
+ " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
122
+ "\n",
123
+ "print(f\"Device: {device}\")\n",
124
+ "\n",
125
+ "model = HookedTransformer.from_pretrained(MODEL, device=device, dtype=\"bfloat16\")"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "sae = SAE.load_from_pretrained(\n",
135
+ " \"artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7\"\n",
136
+ ")\n",
137
+ "sae.fold_W_dec_norm()"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "# _, cache = model.run_with_cache(\"Wasssssup\", names_filter = sae.cfg.hook_name)\n",
147
+ "# sae_in = cache[sae.cfg.hook_name]\n",
148
+ "# print(sae_in.shape)\n",
149
+ "sae_in = torch.rand((1, 4, 3584)).to(sae.device)\n",
150
+ "sae_out = sae(sae_in)"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "metadata": {},
157
+ "outputs": [],
158
+ "source": [
159
+ "# # the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n",
160
+ "# # Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n",
161
+ "# # We also return the feature sparsities which are stored in HF for convenience.\n",
162
+ "# sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
163
+ "# release = \"mistral-7b-res-wg\", # see other options in sae_lens/pretrained_saes.yaml\n",
164
+ "# sae_id = \"blocks.8.hook_resid_pre\", # won't always be a hook point\n",
165
+ "# device = \"cuda:3\",\n",
166
+ "# )\n",
167
+ "# # fold w_dec norm so feature activations are accurate\n",
168
+ "#\n",
169
+ "activations_store = ActivationsStore.from_sae(\n",
170
+ " model=model,\n",
171
+ " sae=sae,\n",
172
+ " streaming=True,\n",
173
+ " store_batch_size_prompts=8,\n",
174
+ " n_batches_in_buffer=8,\n",
175
+ " device=\"cpu\",\n",
176
+ ")"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "sae.encode_fn"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "from sae_lens import run_evals\n",
195
+ "\n",
196
+ "eval_metrics = run_evals(\n",
197
+ " sae=sae,\n",
198
+ " activation_store=activations_store,\n",
199
+ " model=model,\n",
200
+ " n_eval_batches=3,\n",
201
+ " eval_batch_size_prompts=8,\n",
202
+ ")\n",
203
+ "\n",
204
+ "# CE Loss score should be high for residual stream SAEs\n",
205
+ "print(eval_metrics[\"metrics/CE_loss_score\"])\n",
206
+ "\n",
207
+ "# ce loss without SAE should be fairly low < 3.5 suggesting the Model is being run correctly\n",
208
+ "print(eval_metrics[\"metrics/ce_loss_without_sae\"])\n",
209
+ "\n",
210
+ "# ce loss with SAE shouldn't be massively higher\n",
211
+ "print(eval_metrics[\"metrics/ce_loss_with_sae\"])"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "from tqdm import tqdm\n",
221
+ "\n",
222
+ "\n",
223
+ "from sae_dashboard.utils_fns import get_tokens\n",
224
+ "\n",
225
+ "# 1000 prompts is plenty for a demo.\n",
226
+ "token_dataset = get_tokens(activations_store, 4096)"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "# torch.save(token_dataset, \"to\")"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": null,
241
+ "metadata": {},
242
+ "outputs": [],
243
+ "source": [
244
+ "# torch.save(token_dataset, \"token_dataset.pt\")\n",
245
+ "token_dataset = torch.load(\"token_dataset.pt\")"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "metadata": {},
252
+ "outputs": [],
253
+ "source": [
254
+ "import os\n",
255
+ "\n",
256
+ "os.rmdir(\"demo_activations_cache\")"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "import torch\n",
266
+ "\n",
267
+ "\n",
268
+ "def select_indices_in_range(tensor, min_val, max_val, num_samples=None):\n",
269
+ " \"\"\"\n",
270
+ " Select indices of a tensor where values fall within a specified range.\n",
271
+ "\n",
272
+ " Args:\n",
273
+ " tensor (torch.Tensor): Input tensor with values between -10 and 0.\n",
274
+ " min_val (float): Minimum value of the range (inclusive).\n",
275
+ " max_val (float): Maximum value of the range (inclusive).\n",
276
+ " num_samples (int, optional): Number of indices to randomly select. If None, return all indices.\n",
277
+ "\n",
278
+ " Returns:\n",
279
+ " torch.Tensor: Tensor of selected indices.\n",
280
+ " \"\"\"\n",
281
+ " # Ensure the input range is valid\n",
282
+ " if not (-10 <= min_val <= max_val <= 0):\n",
283
+ " raise ValueError(\n",
284
+ " \"Range must be within -10 to 0, and min_val must be <= max_val\"\n",
285
+ " )\n",
286
+ "\n",
287
+ " # Find indices where values are within the specified range\n",
288
+ " mask = (tensor >= min_val) & (tensor <= max_val)\n",
289
+ " indices = mask.nonzero().squeeze()\n",
290
+ "\n",
291
+ " # If num_samples is specified and less than the total number of valid indices,\n",
292
+ " # randomly select that many indices\n",
293
+ " if num_samples is not None and num_samples < indices.numel():\n",
294
+ " perm = torch.randperm(indices.numel())\n",
295
+ " indices = indices[perm[:num_samples]]\n",
296
+ "\n",
297
+ " return indices\n",
298
+ "\n",
299
+ "\n",
300
+ "n_features = 4096\n",
301
+ "feature_idxs = select_indices_in_range(feature_sparsity, -4, -2, 4096)\n",
302
+ "feature_sparsity[feature_idxs.tolist()]"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": null,
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": []
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": null,
315
+ "metadata": {},
316
+ "outputs": [],
317
+ "source": [
318
+ "from importlib import reload\n",
319
+ "import sys\n",
320
+ "\n",
321
+ "\n",
322
+ "def reload_user_modules(module_names):\n",
323
+ " \"\"\"Reload specified user modules.\"\"\"\n",
324
+ " for name in module_names:\n",
325
+ " if name in sys.modules:\n",
326
+ " reload(sys.modules[name])\n",
327
+ "\n",
328
+ "\n",
329
+ "# List of your module names\n",
330
+ "user_modules = [\n",
331
+ " \"sae_dashboard\",\n",
332
+ " \"sae_dashboard.sae_vis_runner\",\n",
333
+ " \"sae_dashboard.data_parsing_fns\",\n",
334
+ " \"sae_dashboard.feature_data_generator\",\n",
335
+ "]\n",
336
+ "\n",
337
+ "# Reload modules\n",
338
+ "reload_user_modules(user_modules)\n",
339
+ "\n",
340
+ "# Re-import after reload\n",
341
+ "from sae_dashboard.feature_data_generator import FeatureDataGenerator"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": [
350
+ "from pathlib import Path\n",
351
+ "\n",
352
+ "test_feature_idx_gpt = feature_idxs.tolist()\n",
353
+ "\n",
354
+ "feature_vis_config_gpt = sae_vis_runner.SaeVisConfig(\n",
355
+ " hook_point=sae.cfg.hook_name,\n",
356
+ " features=test_feature_idx_gpt,\n",
357
+ " minibatch_size_features=16,\n",
358
+ " minibatch_size_tokens=4096, # this is really prompt with the number of tokens determined by the sequence length\n",
359
+ " verbose=True,\n",
360
+ " device=\"cuda\",\n",
361
+ " cache_dir=Path(\n",
362
+ " \"demo_activations_cache\"\n",
363
+ " ), # this will enable us to skip running the model for subsequent features.\n",
364
+ " dtype=\"bfloat16\",\n",
365
+ ")\n",
366
+ "\n",
367
+ "runner = sae_vis_runner.SaeVisRunner(feature_vis_config_gpt)\n",
368
+ "\n",
369
+ "data = runner.run(\n",
370
+ " encoder=sae, # type: ignore\n",
371
+ " model=model,\n",
372
+ " tokens=token_dataset[:1024],\n",
373
+ ")"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "code",
378
+ "execution_count": null,
379
+ "metadata": {},
380
+ "outputs": [],
381
+ "source": []
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": null,
386
+ "metadata": {},
387
+ "outputs": [],
388
+ "source": [
389
+ "from sae_dashboard.data_writing_fns import save_feature_centric_vis\n",
390
+ "\n",
391
+ "filename = f\"demo_feature_dashboards.html\"\n",
392
+ "save_feature_centric_vis(sae_vis_data=data, filename=filename)"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": null,
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": []
401
+ },
402
+ {
403
+ "cell_type": "markdown",
404
+ "metadata": {},
405
+ "source": [
406
+ "# Quick Profiling experiment"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": null,
412
+ "metadata": {},
413
+ "outputs": [],
414
+ "source": [
415
+ "def mock_feature_acts_subset_for_now(sae: SAE):\n",
416
+ "\n",
417
+ " @torch.no_grad()\n",
418
+ " def sae_lens_get_feature_acts_subset(x: torch.Tensor, feature_idx): # type: ignore\n",
419
+ " \"\"\"\n",
420
+ " Get a subset of the feature activations for a dataset.\n",
421
+ " \"\"\"\n",
422
+ " original_device = x.device\n",
423
+ " feature_activations = sae.encode_fn(x.to(device=sae.device, dtype=sae.dtype))\n",
424
+ " return feature_activations[..., feature_idx].to(original_device)\n",
425
+ "\n",
426
+ " sae.get_feature_acts_subset = sae_lens_get_feature_acts_subset # type: ignore\n",
427
+ "\n",
428
+ " return sae\n",
429
+ "\n",
430
+ "\n",
431
+ "sae = mock_feature_acts_subset_for_now(sae)\n",
432
+ "feature_idxs = list(range(128))\n",
433
+ "sae_in = torch.rand((1, 4, 3584)).to(sae.device)\n",
434
+ "sae.get_feature_acts_subset(sae_in, feature_idxs)"
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "execution_count": null,
440
+ "metadata": {},
441
+ "outputs": [],
442
+ "source": [
443
+ "for k, v in sae.named_parameters():\n",
444
+ " print(k, v.shape)"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": null,
450
+ "metadata": {},
451
+ "outputs": [],
452
+ "source": [
453
+ "from torch import nn\n",
454
+ "from typing import List\n",
455
+ "\n",
456
+ "\n",
457
+ "class FeatureMaskingContext:\n",
458
+ " def __init__(self, sae: SAE, feature_idxs: List):\n",
459
+ " self.sae = sae\n",
460
+ " self.feature_idxs = feature_idxs\n",
461
+ " self.original_weight = {}\n",
462
+ "\n",
463
+ " def __enter__(self):\n",
464
+ "\n",
465
+ " ## W_dec\n",
466
+ " self.original_weight[\"W_dec\"] = getattr(self.sae, \"W_dec\").data.clone()\n",
467
+ " # mask the weight\n",
468
+ " masked_weight = sae.W_dec[self.feature_idxs]\n",
469
+ " # set the weight\n",
470
+ " setattr(self.sae, \"W_dec\", nn.Parameter(masked_weight))\n",
471
+ "\n",
472
+ " ## W_enc\n",
473
+ " # clone the weight.\n",
474
+ " self.original_weight[\"W_enc\"] = getattr(self.sae, \"W_enc\").data.clone()\n",
475
+ " # mask the weight\n",
476
+ " masked_weight = sae.W_enc[:, self.feature_idxs]\n",
477
+ " # set the weight\n",
478
+ " setattr(self.sae, \"W_enc\", nn.Parameter(masked_weight))\n",
479
+ "\n",
480
+ " if self.sae.cfg.architecture == \"standard\":\n",
481
+ "\n",
482
+ " ## b_enc\n",
483
+ " self.original_weight[\"b_enc\"] = getattr(self.sae, \"b_enc\").data.clone()\n",
484
+ " # mask the weight\n",
485
+ " masked_weight = sae.b_enc[self.feature_idxs]\n",
486
+ " # set the weight\n",
487
+ " setattr(self.sae, \"b_enc\", nn.Parameter(masked_weight))\n",
488
+ "\n",
489
+ " elif self.sae.cfg.architecture == \"gated\":\n",
490
+ "\n",
491
+ " ## b_gate\n",
492
+ " self.original_weight[\"b_gate\"] = getattr(self.sae, \"b_gate\").data.clone()\n",
493
+ " # mask the weight\n",
494
+ " masked_weight = sae.b_gate[self.feature_idxs]\n",
495
+ " # set the weight\n",
496
+ " setattr(self.sae, \"b_gate\", nn.Parameter(masked_weight))\n",
497
+ "\n",
498
+ " ## r_mag\n",
499
+ " self.original_weight[\"r_mag\"] = getattr(self.sae, \"r_mag\").data.clone()\n",
500
+ " # mask the weight\n",
501
+ " masked_weight = sae.r_mag[self.feature_idxs]\n",
502
+ " # set the weight\n",
503
+ " setattr(self.sae, \"r_mag\", nn.Parameter(masked_weight))\n",
504
+ "\n",
505
+ " ## b_mag\n",
506
+ " self.original_weight[\"b_mag\"] = getattr(self.sae, \"b_mag\").data.clone()\n",
507
+ " # mask the weight\n",
508
+ " masked_weight = sae.b_mag[self.feature_idxs]\n",
509
+ " # set the weight\n",
510
+ " setattr(self.sae, \"b_mag\", nn.Parameter(masked_weight))\n",
511
+ " else:\n",
512
+ " raise (ValueError(\"Invalid architecture\"))\n",
513
+ "\n",
514
+ " return self\n",
515
+ "\n",
516
+ " def __exit__(self, exc_type, exc_value, traceback):\n",
517
+ "\n",
518
+ " # set everything back to normal\n",
519
+ " for key, value in self.original_weight.items():\n",
520
+ " setattr(self.sae, key, nn.Parameter(value))"
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": null,
526
+ "metadata": {},
527
+ "outputs": [],
528
+ "source": [
529
+ "import gc\n",
530
+ "import torch\n",
531
+ "\n",
532
+ "gc.collect()\n",
533
+ "torch.cuda.empty_cache()\n",
534
+ "torch.set_grad_enabled(False)\n",
535
+ "\n",
536
+ "\n",
537
+ "def my_function(sae_in):\n",
538
+ " # Your PyTorch code here\n",
539
+ " feature_idxs = list(range(2048))\n",
540
+ " with FeatureMaskingContext(sae, feature_idxs):\n",
541
+ " features = sae(sae_in)\n",
542
+ " print(features.mean())\n",
543
+ "\n",
544
+ "\n",
545
+ "tokens = token_dataset[:64]\n",
546
+ "_, cache = model.run_with_cache(\n",
547
+ " tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=sae.cfg.hook_name\n",
548
+ ")\n",
549
+ "sae_in = cache[sae.cfg.hook_name]"
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "code",
554
+ "execution_count": null,
555
+ "metadata": {},
556
+ "outputs": [],
557
+ "source": [
558
+ "tokens.shape"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": null,
564
+ "metadata": {},
565
+ "outputs": [],
566
+ "source": [
567
+ "sae.W_dec.shape"
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": null,
573
+ "metadata": {},
574
+ "outputs": [],
575
+ "source": [
576
+ "%load_ext memray"
577
+ ]
578
+ },
579
+ {
580
+ "cell_type": "code",
581
+ "execution_count": null,
582
+ "metadata": {},
583
+ "outputs": [],
584
+ "source": [
585
+ "%%memray_flamegraph --trace-python-allocators --leaks\n",
586
+ "my_function(sae_in)"
587
+ ]
588
+ },
589
+ {
590
+ "cell_type": "code",
591
+ "execution_count": null,
592
+ "metadata": {},
593
+ "outputs": [],
594
+ "source": []
595
+ }
596
+ ],
597
+ "metadata": {
598
+ "kernelspec": {
599
+ "display_name": ".venv",
600
+ "language": "python",
601
+ "name": "python3"
602
+ },
603
+ "language_info": {
604
+ "codemirror_mode": {
605
+ "name": "ipython",
606
+ "version": 3
607
+ },
608
+ "file_extension": ".py",
609
+ "mimetype": "text/x-python",
610
+ "name": "python",
611
+ "nbconvert_exporter": "python",
612
+ "pygments_lexer": "ipython3",
613
+ "version": "3.11.7"
614
+ }
615
+ },
616
+ "nbformat": 4,
617
+ "nbformat_minor": 2
618
+ }
SAEDashboard/pyproject.toml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "sae-dashboard"
3
+ version = "0.7.3"
4
+ description = "Open-source SAE visualizer, based on Anthropic's published visualizer. Forked / Detached from sae_vis."
5
+ authors = ["Callum McDougall <cal.s.mcdougall@gmail.com>", "Joseph Bloom, <jbloomaus@gmail.com>"]
6
+ readme = "README.md"
7
+ license = "MIT"
8
+
9
+ [tool.poetry.dependencies]
10
+ python = "^3.10"
11
+ torch = "^2.0.0"
12
+ einops = ">=0.7.0"
13
+ datasets = "^3.0.0"
14
+ dataclasses-json = "^0.6.4"
15
+ jaxtyping = "^0.2.28"
16
+ transformer-lens = "^2.2.0"
17
+ transformers = "<4.57.0"
18
+ eindex-callum = "^0.1.0"
19
+ rich = "^13.7.1"
20
+ matplotlib = "^3.8.4"
21
+ safetensors = "^0.4.3"
22
+ typer = "^0.12.3"
23
+ sae-lens = "^6.8.0"
24
+ decode-clt = "^0.0.1"
25
+ hf-transfer = "^0.1.9"
26
+
27
+ [tool.poetry.group.dev.dependencies]
28
+ isort = "^5.13.2"
29
+ ruff = "^0.3.7"
30
+ pytest = "^8.1.1"
31
+ ipykernel = "^6.29.4"
32
+ pyright = "^1.1.359"
33
+ pytest-profiling = "^1.7.0"
34
+ memray = "^1.12.0"
35
+ syrupy = "^4.6.1"
36
+ flake8 = "^7.0.0"
37
+ pytest-cov = "^5.0.0"
38
+ black = "^24.4.2"
39
+ pytest-memray = "^1.7.0"
40
+
41
+ [tool.poetry.scripts]
42
+ neuronpedia-runner = "sae_dashboard.neuronpedia.neuronpedia_runner:main"
43
+
44
+ [tool.isort]
45
+ profile = "black"
46
+ src_paths = ["sae_dashboard", "tests"]
47
+
48
+ [tool.pyright]
49
+ typeCheckingMode = "strict"
50
+ reportMissingTypeStubs = "none"
51
+ reportUnknownMemberType = "none"
52
+ reportUnknownArgumentType = "none"
53
+ reportUnknownVariableType = "none"
54
+ reportUntypedFunctionDecorator = "none"
55
+ reportUnnecessaryIsInstance = "none"
56
+ reportUnnecessaryComparison = "none"
57
+ reportConstantRedefinition = "none"
58
+ reportUnknownLambdaType = "none"
59
+ reportPrivateUsage = "none"
60
+ reportPrivateImportUsage = "none"
61
+
62
+ [build-system]
63
+ requires = ["poetry-core"]
64
+ build-backend = "poetry.core.masonry.api"
65
+
66
+ [tool.semantic_release]
67
+ version_variables = ["sae_dashboard/__init__.py:__version__"]
68
+ version_toml = ["pyproject.toml:tool.poetry.version"]
69
+ build_command = "pip install poetry && poetry build"
70
+ branches = { main = { match = "main" } }
SAEDashboard/sae_dashboard/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.7.3"
2
+
3
+ # from .data_fetching_fns import *
4
+ # from .data_storing_fns import *
5
+ # from .html_fns import *
6
+ # from .transformer_lens_wrapper import *
7
+ # from .utils_fns import *
8
+
9
+
10
+ # from autoencoder import AutoEncoder, AutoEncoderConfig
SAEDashboard/sae_dashboard/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (184 Bytes). View file
 
SAEDashboard/sae_dashboard/__pycache__/components.cpython-313.pyc ADDED
Binary file (33.3 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/components_config.cpython-313.pyc ADDED
Binary file (10.9 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/data_parsing_fns.cpython-313.pyc ADDED
Binary file (16.7 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/data_writing_fns.cpython-313.pyc ADDED
Binary file (8 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/dfa_calculator.cpython-313.pyc ADDED
Binary file (6.87 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/feature_data.cpython-313.pyc ADDED
Binary file (10.2 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/feature_data_generator.cpython-313.pyc ADDED
Binary file (14.6 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/html_fns.cpython-313.pyc ADDED
Binary file (11.6 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/layout.cpython-313.pyc ADDED
Binary file (8.84 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/sae_vis_data.cpython-313.pyc ADDED
Binary file (9.15 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/sae_vis_runner.cpython-313.pyc ADDED
Binary file (14.3 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/sequence_data_generator.cpython-313.pyc ADDED
Binary file (13.8 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/transformer_lens_wrapper.cpython-313.pyc ADDED
Binary file (8.23 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/utils_fns.cpython-313.pyc ADDED
Binary file (49.6 kB). View file
 
SAEDashboard/sae_dashboard/__pycache__/vector_vis_data.cpython-313.pyc ADDED
Binary file (9.39 kB). View file
 
SAEDashboard/sae_dashboard/clt_layer_wrapper.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ # Added dataclass, field, asdict
4
+ from dataclasses import asdict, dataclass, field
5
+ from pathlib import Path
6
+
7
+ # import torch.nn as nn # Unused
8
+ # from torch.distributed import ProcessGroup # Unused
9
+ # from types import SimpleNamespace # Unused import
10
+ from typing import ( # Added Optional, Union and List
11
+ TYPE_CHECKING,
12
+ Any,
13
+ List,
14
+ Optional,
15
+ Union,
16
+ )
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from clt.models.activations import BatchTopK # type: ignore
21
+
22
+ if TYPE_CHECKING:
23
+ import torch.distributed # Import for ProcessGroup type hint
24
+ from clt.models.clt import CrossLayerTranscoder # type: ignore
25
+
26
+
27
+ # Placeholder for dist if torch.distributed is not available or initialized
28
+ class MockDist:
29
+ def is_initialized(self) -> bool:
30
+ return False
31
+
32
+ def get_world_size(
33
+ self, group: "Optional[torch.distributed.ProcessGroup]" = None
34
+ ) -> int:
35
+ return 1
36
+
37
+ def all_gather_into_tensor(
38
+ self,
39
+ output_tensor: torch.Tensor,
40
+ input_tensor: torch.Tensor,
41
+ group: "Optional[torch.distributed.ProcessGroup]" = None,
42
+ ) -> None:
43
+ # In non-distributed setting, just copy input to output (assuming output is sized correctly)
44
+ if output_tensor.shape[0] == 1 * input_tensor.shape[0]:
45
+ output_tensor.copy_(input_tensor)
46
+ else:
47
+ # This case shouldn't happen if called correctly, but handle defensively
48
+ raise ValueError(
49
+ "Output tensor size doesn't match input tensor size in mock all_gather"
50
+ )
51
+
52
+ def all_gather(
53
+ self,
54
+ tensor_list: List[torch.Tensor],
55
+ input_tensor: torch.Tensor,
56
+ group: "Optional[torch.distributed.ProcessGroup]" = None,
57
+ ) -> None:
58
+ """Mock all_gather for a list of tensors."""
59
+ if self.get_world_size(group) == 1:
60
+ if len(tensor_list) == 1:
61
+ tensor_list[0].copy_(input_tensor)
62
+ else:
63
+ raise ValueError(
64
+ "tensor_list size must be 1 in mock all_gather when world_size is 1"
65
+ )
66
+ else:
67
+ # This mock doesn't support actual gathering for world_size > 1.
68
+ # It's primarily for the dist.all_gather call in _gather_weight,
69
+ # which should ideally not proceed if world_size > 1 and dist is MockDist.
70
+ # However, _gather_weight checks dist.is_initialized() and dist.get_world_size() first.
71
+ raise NotImplementedError(
72
+ "MockDist.all_gather not implemented for world_size > 1"
73
+ )
74
+
75
+
76
+ try:
77
+ import torch.distributed as dist
78
+
79
+ if not dist.is_available():
80
+ dist = MockDist() # type: ignore
81
+ except ImportError:
82
+ dist = MockDist() # type: ignore
83
+
84
+
85
+ @dataclass
86
+ class CLTMetadata:
87
+ """Simple metadata class for CLT wrapper compatibility."""
88
+
89
+ hook_name: str
90
+ hook_layer: int
91
+ model_name: Optional[str] = None
92
+ context_size: Optional[int] = None
93
+ prepend_bos: bool = True
94
+ hook_head_index: Optional[int] = None
95
+ seqpos_slice: Optional[slice] = None
96
+
97
+
98
+ @dataclass
99
+ class CLTWrapperConfig:
100
+ """Configuration dataclass for the CLTLayerWrapper."""
101
+
102
+ # Fields without defaults first
103
+ d_sae: int
104
+ d_in: int
105
+ hook_name: str
106
+ hook_layer: int
107
+ dtype: str
108
+ device: str
109
+ # Fields with defaults last
110
+ architecture: str = "jumprelu"
111
+ hook_head_index: Optional[int] = None
112
+ model_name: Optional[str] = None
113
+ dataset_path: Optional[str] = None
114
+ context_size: Optional[int] = None
115
+ prepend_bos: bool = True
116
+ normalize_activations: bool = False
117
+ dataset_trust_remote_code: bool = False
118
+ seqpos_slice: Optional[slice] = None
119
+ model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
120
+ metadata: Optional[CLTMetadata] = None
121
+
122
+ def to_dict(self) -> dict[str, Any]:
123
+ """Convert config to dictionary for compatibility with SAE interface."""
124
+ return asdict(self)
125
+
126
+
127
+ class CLTLayerWrapper:
128
+ """Wraps a single layer of a CrossLayerTranscoder to mimic the SAE interface.
129
+
130
+ This allows reusing existing dashboard components that expect an SAE object.
131
+ It specifically provides access to the encoder and the *same-layer* decoder weights
132
+ for the specified layer index.
133
+ """
134
+
135
+ cfg: CLTWrapperConfig # Add type hint for the config attribute
136
+ threshold: Optional[torch.Tensor] = (
137
+ None # For JumpReLU, set by FeatureMaskingContext
138
+ )
139
+
140
+ def __init__(
141
+ self,
142
+ clt: "CrossLayerTranscoder",
143
+ layer_idx: int,
144
+ clt_model_dir_path: Optional[str] = None,
145
+ ):
146
+ self.clt = clt
147
+ self.layer_idx = layer_idx
148
+ self.device = clt.device
149
+ self.dtype = clt.dtype
150
+ self.hook_z_reshaping_mode = False # Added to satisfy SAE interface
151
+
152
+ # Validate layer index
153
+ if not (0 <= layer_idx < clt.config.num_layers):
154
+ raise ValueError(
155
+ f"Invalid layer_idx {layer_idx} for CLT with {clt.config.num_layers} layers."
156
+ )
157
+
158
+ # --- Create the Wrapper Config ---
159
+ # Try to get model_name from the underlying clt config if it exists
160
+ clt_model_name = getattr(clt.config, "model_name", None)
161
+ clt_dataset_path = getattr(clt.config, "dataset_path", None)
162
+ clt_context_size = getattr(
163
+ clt.config, "context_size", 128
164
+ ) # Default to 128 if not set
165
+ clt_prepend_bos = getattr(clt.config, "prepend_bos", True)
166
+ # Use the activation_fn from CLT config for the wrapper's architecture and encode method
167
+ self.activation_fn = getattr(clt.config, "activation_fn", "jumprelu")
168
+ clt_model_from_pretrained_kwargs = getattr(
169
+ clt.config, "model_from_pretrained_kwargs", {}
170
+ )
171
+
172
+ # --- Load CLT-specific normalization stats if applicable ---
173
+ self.clt_norm_mean: Optional[torch.Tensor] = None
174
+ self.clt_norm_std: Optional[torch.Tensor] = None
175
+ wrapper_will_normalize_specifically = False
176
+ clt_norm_method = getattr(clt.config, "normalization_method", "none")
177
+
178
+ if clt_norm_method in ["auto", "estimated_mean_std", "mean_std"]:
179
+ if clt_model_dir_path:
180
+ norm_stats_file = Path(clt_model_dir_path) / "norm_stats.json"
181
+ if norm_stats_file.exists():
182
+ try:
183
+ with open(norm_stats_file, "r") as f:
184
+ stats_data = json.load(f)
185
+
186
+ layer_stats = stats_data.get(str(self.layer_idx), {}).get(
187
+ "inputs", {}
188
+ )
189
+ mean_vals = layer_stats.get("mean")
190
+ std_vals = layer_stats.get("std")
191
+
192
+ if mean_vals is not None and std_vals is not None:
193
+ self.clt_norm_mean = torch.tensor(
194
+ mean_vals, device=self.device, dtype=torch.float32
195
+ ).unsqueeze(0)
196
+ self.clt_norm_std = (
197
+ torch.tensor(
198
+ std_vals, device=self.device, dtype=torch.float32
199
+ )
200
+ + 1e-6
201
+ ).unsqueeze(0)
202
+ if torch.any(self.clt_norm_std <= 0):
203
+ print(
204
+ f"Warning: Loaded std for layer {self.layer_idx} contains non-positive values after adding epsilon. Disabling specific normalization."
205
+ )
206
+ self.clt_norm_mean = None
207
+ self.clt_norm_std = None
208
+ else:
209
+ wrapper_will_normalize_specifically = True
210
+ print(
211
+ f"CLTLayerWrapper: Loaded norm_stats.json for layer {self.layer_idx}. Wrapper will apply specific normalization."
212
+ )
213
+ else:
214
+ print(
215
+ f"Warning: norm_stats.json found, but missing 'mean' or 'std' for layer {self.layer_idx} inputs. Wrapper will not normalize specifically."
216
+ )
217
+ except Exception as e:
218
+ print(
219
+ f"Warning: Error loading or parsing norm_stats.json from {norm_stats_file}: {e}. Wrapper will not normalize specifically."
220
+ )
221
+ else:
222
+ print(
223
+ f"Warning: normalization_method is '{clt_norm_method}' but norm_stats.json not found at {norm_stats_file}. Wrapper will not normalize specifically."
224
+ )
225
+ else:
226
+ print(
227
+ f"Warning: normalization_method is '{clt_norm_method}' but clt_model_dir_path not provided. Wrapper cannot load norm_stats.json and will not normalize specifically."
228
+ )
229
+
230
+ # Determine normalize_activations flag for ActivationsStore based on CLT config and wrapper's capability
231
+ # This flag in self.cfg controls ActivationsStore. ActivationsStore should only normalize if the wrapper *isn't* doing specific normalization AND the CLT expected some form of normalization.
232
+ clt_config_indicated_normalization = clt_norm_method != "none"
233
+ normalize_activations_for_store = clt_config_indicated_normalization and (
234
+ not wrapper_will_normalize_specifically
235
+ )
236
+ if normalize_activations_for_store:
237
+ print(
238
+ f"CLTLayerWrapper: Setting normalize_activations=True for ActivationsStore (CLT method: {clt_norm_method}, wrapper specific norm: False)."
239
+ )
240
+ elif clt_config_indicated_normalization and wrapper_will_normalize_specifically:
241
+ print(
242
+ f"CLTLayerWrapper: Setting normalize_activations=False for ActivationsStore (CLT method: {clt_norm_method}, wrapper specific norm: True)."
243
+ )
244
+ else: # not clt_config_indicated_normalization
245
+ print(
246
+ f"CLTLayerWrapper: Setting normalize_activations=False for ActivationsStore (CLT method: {clt_norm_method})."
247
+ )
248
+
249
+ # Initialize self.threshold if activation is jumprelu
250
+ # This must happen AFTER self.activation_fn, self.device, self.dtype, self.layer_idx, and self.clt are set.
251
+ if self.activation_fn == "jumprelu":
252
+ if (
253
+ hasattr(self.clt, "log_threshold")
254
+ and self.clt.log_threshold is not None
255
+ ):
256
+ if 0 <= self.layer_idx < self.clt.log_threshold.shape[0]:
257
+ # The log_threshold from CLT is [num_layers, num_features]
258
+ # We need the threshold for the current layer_idx
259
+ layer_thresholds = torch.exp(
260
+ self.clt.log_threshold[self.layer_idx].clone().detach()
261
+ )
262
+ self.threshold = layer_thresholds.to(
263
+ device=self.device, dtype=self.dtype
264
+ )
265
+ print(
266
+ f"CLTLayerWrapper: Initialized self.threshold for layer {self.layer_idx} from clt.log_threshold."
267
+ )
268
+ else:
269
+ print(
270
+ f"Warning: CLTLayerWrapper layer_idx {self.layer_idx} is out of bounds for clt.log_threshold "
271
+ f"(shape {self.clt.log_threshold.shape}). self.threshold will be None."
272
+ )
273
+ self.threshold = None
274
+ else:
275
+ print(
276
+ f"Warning: Underlying CLT model for layer {self.layer_idx} does not have 'log_threshold' or it's None, "
277
+ f"but activation_fn is 'jumprelu'. self.threshold will be None."
278
+ )
279
+ self.threshold = None
280
+ # else: self.threshold remains its default None, which is fine for other activation functions.
281
+
282
+ # Get the hook name using prioritized templates
283
+ hook_name_template = getattr(clt.config, "tl_input_template", None)
284
+ if hook_name_template:
285
+ hook_name = hook_name_template.format(layer_idx)
286
+ print(f"Using TL hook name template: {hook_name_template} -> {hook_name}")
287
+ else:
288
+ hook_name_template = getattr(clt.config, "mlp_input_template", None)
289
+ if hook_name_template:
290
+ hook_name = hook_name_template.format(layer_idx)
291
+ print(
292
+ f"Warning: tl_input_template not found. Using mlp_input_template: {hook_name_template} -> {hook_name}"
293
+ )
294
+ else:
295
+ # Fallback for older configs without any template
296
+ hook_name = f"blocks.{layer_idx}.hook_mlp_in"
297
+ print(
298
+ f"Warning: Neither tl_input_template nor mlp_input_template found. Falling back to hardcoded: {hook_name}"
299
+ )
300
+
301
+ self.cfg = CLTWrapperConfig(
302
+ d_sae=clt.config.num_features, # This is the d_sae of the *entire* CLT layer, not a sub-batch
303
+ d_in=clt.config.d_model,
304
+ hook_name=hook_name,
305
+ hook_layer=layer_idx,
306
+ hook_head_index=None,
307
+ dtype=str(self.dtype).replace("torch.", ""),
308
+ device=str(self.device),
309
+ architecture=self.activation_fn, # Use the determined activation_fn
310
+ model_name=clt_model_name,
311
+ dataset_path=clt_dataset_path,
312
+ context_size=clt_context_size,
313
+ prepend_bos=clt_prepend_bos,
314
+ normalize_activations=normalize_activations_for_store,
315
+ dataset_trust_remote_code=False,
316
+ seqpos_slice=None,
317
+ model_from_pretrained_kwargs=clt_model_from_pretrained_kwargs,
318
+ metadata=CLTMetadata(
319
+ hook_name=hook_name,
320
+ hook_layer=layer_idx,
321
+ model_name=clt_model_name,
322
+ context_size=clt_context_size,
323
+ prepend_bos=clt_prepend_bos,
324
+ hook_head_index=None,
325
+ seqpos_slice=None,
326
+ ),
327
+ )
328
+ # --- End Config Creation ---
329
+
330
+ # Extract and potentially gather weights
331
+ # Ensure weights are detached and cloned to avoid modifying the original CLT
332
+ # Original W_enc from CLT encoder module is [d_sae_layer, d_model]
333
+ # We transpose to match sae-lens W_enc convention: [d_model, d_sae_layer]
334
+ self.W_enc = (
335
+ self._gather_encoder_weight(clt.encoder_module.encoders[layer_idx].weight) # type: ignore
336
+ .t()
337
+ .contiguous()
338
+ )
339
+ # For W_dec, use the decoder from the same layer to itself
340
+ decoder_key = f"{layer_idx}->{layer_idx}"
341
+ if decoder_key not in clt.decoder_module.decoders: # type: ignore
342
+ raise KeyError(f"Decoder key {decoder_key} not found in CLT decoders.")
343
+ # Original W_dec from CLT decoder module is [d_model, d_sae_layer]
344
+ # We transpose to match sae-lens W_dec convention: [d_sae_layer, d_model]
345
+ self.W_dec = (
346
+ self._gather_decoder_weight(clt.decoder_module.decoders[decoder_key].weight) # type: ignore
347
+ .t()
348
+ .contiguous()
349
+ )
350
+
351
+ self.b_enc = self._gather_encoder_bias(
352
+ clt.encoder_module.encoders[layer_idx].bias_param # type: ignore
353
+ )
354
+ # For b_dec, use the bias from the same-layer decoder
355
+ self.b_dec = self._gather_decoder_bias(
356
+ clt.decoder_module.decoders[decoder_key].bias_param # type: ignore
357
+ )
358
+
359
+ # Cache for folded weights if needed
360
+ self._W_dec_folded = False
361
+ # Thresholds for JumpReLU will be handled by FeatureMaskingContext if architecture is 'jumprelu'
362
+ # by setting self.threshold directly on the wrapper instance.
363
+
364
+ # --- Façade methods mimicking SAE --- #
365
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode `x` into feature activations using this wrapper's own W_enc / b_enc.

    Respects any masking applied by FeatureMaskingContext: after masking,
    self.W_enc is [d_model, n_features_in_batch] and self.b_enc is
    [n_features_in_batch], so the output's last dimension is the number of
    currently selected features, not the full CLT feature count.

    Args:
        x: Input activations of shape [..., d_model].

    Returns:
        Feature activations of shape [..., n_features_in_batch].

    Raises:
        AttributeError: if activation_fn is 'jumprelu' but self.threshold is unset.
        ValueError: if self.activation_fn is not 'relu', 'jumprelu', or 'batchtopk'.
    """
    original_shape = x.shape
    # Flatten leading dims so F.linear sees [N, d_model]; cfg.d_in == d_model.
    if x.ndim > 2:
        x_reshaped = x.reshape(-1, self.cfg.d_in)
    else:
        x_reshaped = x

    x_to_process = x_reshaped
    # Apply CLT-specific input normalization if stats were loaded.
    # Done in float32 for precision, then cast back to the input's dtype.
    if self.clt_norm_mean is not None and self.clt_norm_std is not None:
        x_float32 = x_to_process.to(torch.float32)
        normalized_x = (x_float32 - self.clt_norm_mean) / self.clt_norm_std
        x_to_process = normalized_x.to(x.dtype)

    # F.linear(input, weight, bias) expects weight as [out_features, in_features].
    # self.W_enc is [d_model, n_features_in_batch], hence the transpose.
    hidden_pre = F.linear(
        x_to_process, self.W_enc.T, self.b_enc
    )  # Output: [N, n_features_in_batch]

    # Apply the activation function chosen at construction time.
    if self.activation_fn == "relu":
        encoded_acts = torch.relu(hidden_pre)
    elif self.activation_fn == "jumprelu":
        if not hasattr(self, "threshold") or self.threshold is None:
            raise AttributeError(
                "JumpReLU activation selected, but 'self.threshold' is not available on CLTLayerWrapper. "
                "FeatureMaskingContext should set this if architecture is 'jumprelu'."
            )
        # Keep pre-activations strictly above their per-feature threshold; zero the rest.
        encoded_acts = torch.where(
            hidden_pre > self.threshold, hidden_pre, torch.zeros_like(hidden_pre)
        )
    elif self.activation_fn == "batchtopk":
        k_val: float
        batchtopk_k_abs = getattr(self.clt.config, "batchtopk_k", None)
        batchtopk_k_frac = getattr(self.clt.config, "batchtopk_frac", None)

        if batchtopk_k_abs is not None:
            # 'batchtopk_k' is a global budget across all layers; approximate
            # the per-layer share by dividing evenly across num_layers.
            k_val = float(batchtopk_k_abs) / self.clt.config.num_layers
            # Ensure at least 1 feature is kept if k/num_layers is small.
            k_val = max(1.0, k_val)
        elif batchtopk_k_frac is not None:
            # A fraction applies directly to the current feature batch size.
            k_val = float(batchtopk_k_frac)
        else:
            # Fallback: keep all features currently being processed. This
            # matches the fallback in CrossLayerTranscoder.encode for its
            # per-layer batchtopk.
            print(
                f"Warning: CLTLayerWrapper using batchtopk, but neither 'batchtopk_k' nor 'batchtopk_frac' defined in CLTConfig. Defaulting to keeping all {hidden_pre.size(-1)} features in the current batch."
            )
            k_val = float(hidden_pre.size(-1))

        straight_through_flag = getattr(
            self.clt.config, "batchtopk_straight_through", False
        )
        # BatchTopK is a custom autograd function; k_val may be an absolute
        # count or a fraction — presumably disambiguated inside BatchTopK.apply
        # (TODO confirm against its implementation).
        encoded_acts = BatchTopK.apply(hidden_pre, k_val, straight_through_flag)
    else:
        raise ValueError(
            f"Unsupported activation function in CLTLayerWrapper: {self.activation_fn}"
        )

    if x.ndim > 2:
        # Restore the original leading dims; last dim is n_features_in_batch.
        encoded_acts = encoded_acts.reshape(*original_shape[:-1], -1)  # type: ignore

    return encoded_acts  # type: ignore
445
+
446
def turn_off_forward_pass_hook_z_reshaping(self) -> None:
    """No-op stub kept only to satisfy the SAE interface.

    hook_z reshaping is an attention-output-SAE concept and does not apply
    to CLTLayerWrapper, so there is nothing to turn off.
    """
    return None
450
+
451
+ # Note: CLTLayerWrapper does not have a separate `decode` method façade
452
+ # because the dashboard primarily uses W_dec directly for analysis (e.g., logits).
453
+ # The CLT's actual decode logic (summing across layers) isn't needed here.
454
+
455
def fold_W_dec_norm(self):
    """Fold the per-feature L2 norm of W_dec into W_enc and b_enc.

    Mirrors sae_lens.SAE.fold_W_dec_norm: after folding, encoder activations
    directly correspond to output norm when used with the wrapped W_dec.
    The norms are stashed on the instance so unfold_W_dec_norm can reverse
    the operation. Idempotent: a second call is a warning no-op.
    """
    if self._W_dec_folded:
        print("Warning: W_dec norm already folded.")
        return
    if self.W_dec is None or self.W_enc is None:
        print("Warning: Cannot fold W_dec norm, weights not available.")
        return

    # W_dec is [n_features, d_model] (post masking/init); take the per-feature
    # norm over the d_model dimension, detached to avoid gradient tracking.
    norms = torch.norm(self.W_dec.detach(), dim=1, keepdim=True)  # [n_features, 1]

    # A zero-norm feature would zero its encoder column; substitute 1.0.
    norms = torch.where(norms == 0, torch.ones_like(norms), norms)

    # W_enc is [d_model, n_features]: scale each column by its feature's norm.
    # Compute in the norms' dtype, then cast back to W_enc's original dtype.
    enc_dtype = self.W_enc.dtype
    self.W_enc.data = (self.W_enc.data.to(norms.dtype) * norms.t()).to(enc_dtype)

    if self.b_enc is not None:
        # b_enc is [n_features]; norms.squeeze() matches that shape.
        bias_dtype = self.b_enc.dtype
        self.b_enc.data = (
            self.b_enc.data.to(norms.dtype) * norms.squeeze()
        ).to(bias_dtype)

    # Keep the norms so unfold_W_dec_norm can invert this transform.
    self._w_dec_norms_backup = norms
    self._W_dec_folded = True
    print("Folded W_dec norm into W_enc and b_enc.")
505
+
506
def unfold_W_dec_norm(self):
    """Reverse fold_W_dec_norm: divide the stored norms back out of W_enc / b_enc.

    Requires the backup norms saved by fold_W_dec_norm; otherwise this is a
    warning no-op. Clears the folded flag and deletes the backup on success.
    """
    if not self._W_dec_folded or not hasattr(self, "_w_dec_norms_backup"):
        print("Warning: W_dec norm not folded or backup norms not found.")
        return
    if self.W_enc is None:
        print("Warning: Cannot unfold W_dec norm, W_enc not available.")
        return

    norms = self._w_dec_norms_backup
    # Defensive: fold already replaced zeros with ones, but re-guard anyway.
    norms = torch.where(norms == 0, torch.ones_like(norms), norms)

    # Inverse of the fold: divide each W_enc column by its feature's norm,
    # computing in the norms' dtype and casting back afterwards.
    enc_dtype = self.W_enc.dtype
    self.W_enc.data = (self.W_enc.data.to(norms.dtype) / norms.t()).to(enc_dtype)

    if self.b_enc is not None:
        bias_dtype = self.b_enc.dtype
        self.b_enc.data = (
            self.b_enc.data.to(norms.dtype) / norms.squeeze()
        ).to(bias_dtype)

    del self._w_dec_norms_backup
    self._W_dec_folded = False
    print("Unfolded W_dec norm from W_enc and b_enc.")
537
+
538
def to(self, device: Union[str, torch.device]):
    """Move the wrapper's cached tensors (and the wrapped CLT) to `device`.

    The CLT move is best-effort: a failure is logged and the wrapper's own
    tensors are still relocated. Returns self for chaining.
    """
    target = torch.device(device)

    # Best-effort move of the underlying CLT model.
    try:
        self.clt.to(target)
    except Exception as e:
        print(
            f"Warning: Failed to move underlying CLT model to {target}: {e}"
        )
        # Continue trying to move wrapper components.

    # Relocate the wrapper's cached weight/bias tensors.
    for attr in ("W_enc", "W_dec", "b_enc", "b_dec"):
        tensor = getattr(self, attr)
        if tensor is not None:
            setattr(self, attr, tensor.to(target))

    if getattr(self, "_w_dec_norms_backup", None) is not None:
        self._w_dec_norms_backup = self._w_dec_norms_backup.to(target)

    # Keep device bookkeeping (attribute + config string) in sync.
    self.device = target
    self.cfg.device = str(target)

    # JumpReLU threshold, if present, must follow the weights.
    if getattr(self, "threshold", None) is not None:
        self.threshold = self.threshold.to(target)

    # CLT normalization stats, if loaded, must also follow.
    if self.clt_norm_mean is not None:
        self.clt_norm_mean = self.clt_norm_mean.to(target)
    if self.clt_norm_std is not None:
        self.clt_norm_std = self.clt_norm_std.to(target)

    print(f"Moved CLTLayerWrapper to {target}")
    return self
581
+
582
+ # --- Helper methods for Tensor Parallelism --- #
583
+
584
+ def _gather_weight(
585
+ self,
586
+ weight_shard: torch.Tensor,
587
+ gather_dim: int = 0,
588
+ target_full_dim_size: Optional[int] = None,
589
+ ) -> torch.Tensor:
590
+ """Gather a weight tensor shard across TP ranks."""
591
+ if not dist.is_initialized() or dist.get_world_size() == 1:
592
+ return weight_shard.clone().detach()
593
+
594
+ world_size = dist.get_world_size()
595
+ # Create a list to hold all gathered tensors
596
+ tensor_list = [torch.empty_like(weight_shard) for _ in range(world_size)]
597
+ dist.all_gather(tensor_list, weight_shard)
598
+
599
+ # Concatenate along the specified dimension
600
+ full_weight = torch.cat(tensor_list, dim=gather_dim)
601
+
602
+ # Trim padding if necessary
603
+ if target_full_dim_size is not None:
604
+ if gather_dim == 0:
605
+ if full_weight.shape[0] > target_full_dim_size:
606
+ full_weight = full_weight[:target_full_dim_size, :]
607
+ elif gather_dim == 1:
608
+ if full_weight.shape[1] > target_full_dim_size:
609
+ full_weight = full_weight[:, :target_full_dim_size]
610
+ # Add other gather_dim cases if needed
611
+
612
+ return full_weight.detach()
613
+
614
def _gather_encoder_weight(self, weight_shard: torch.Tensor) -> torch.Tensor:
    """Gather a ColumnParallelLinear encoder weight shard across TP ranks.

    Encoder weights are sharded along dim 0 (the feature dimension): each
    rank holds [d_sae_local, d_model]; the gathered result is trimmed to
    [num_features, d_model] for this layer.
    """
    full_feature_count = self.clt.config.num_features
    return self._gather_weight(
        weight_shard,
        gather_dim=0,
        target_full_dim_size=full_feature_count,
    )
623
+
624
def _gather_decoder_weight(self, weight_shard: torch.Tensor) -> torch.Tensor:
    """Gather a RowParallelLinear decoder weight shard across TP ranks.

    Decoder weights are sharded along dim 1 (the feature dimension): each
    rank holds [d_model, d_sae_local]; the gathered result is trimmed to
    [d_model, num_features] for this layer.
    """
    full_feature_count = self.clt.config.num_features
    return self._gather_weight(
        weight_shard,
        gather_dim=1,
        target_full_dim_size=full_feature_count,
    )
633
+
634
def _gather_bias(
    self,
    bias_shard: Optional[torch.Tensor],
    gather_dim: int = 0,
    target_full_dim_size: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """Gather a bias tensor shard across TP ranks.

    Args:
        bias_shard: This rank's bias shard, or None when the layer has no bias.
        gather_dim: Dimension along which shards are concatenated.
        target_full_dim_size: If given, trim gathered padding to this size.

    Returns:
        The full gathered bias, or None when `bias_shard` is None.
    """
    if bias_shard is None:
        return None
    # Biases are typically sharded along the same dimension as the weight's
    # corresponding output dim, so the generic weight gather applies directly.
    return self._gather_weight(
        bias_shard, gather_dim=gather_dim, target_full_dim_size=target_full_dim_size
    )
647
+
648
def _gather_encoder_bias(
    self, bias_shard_candidate: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    """Gather a ColumnParallelLinear encoder bias shard (feature-dim sharded).

    Defensive: anything that is not a Tensor (None, a bool flag from a
    bias-less config, etc.) is treated as "no bias" and yields None.
    """
    if not isinstance(bias_shard_candidate, torch.Tensor):
        return None
    # Encoder bias has shape [d_sae_local]; gather along dim 0 and trim to
    # the layer's full feature count.
    return self._gather_bias(
        bias_shard_candidate,
        gather_dim=0,
        target_full_dim_size=self.clt.config.num_features,
    )
666
+
667
+ def _gather_decoder_bias(
668
+ self, bias_shard_candidate: Optional[torch.Tensor]
669
+ ) -> Optional[torch.Tensor]:
670
+ """Gather RowParallelLinear bias (NOT sharded, but might need broadcast/check).
671
+
672
+ Defensively checks if the provided candidate is actually a Tensor.
673
+ """
674
+ # Check if the provided object is a Tensor
675
+ if isinstance(bias_shard_candidate, torch.Tensor):
676
+ # RowParallelLinear bias is typically not sharded (added after all-reduce)
677
+ # However, let's check world size and return a clone if TP=1, or verify replication if TP>1
678
+ if not dist.is_initialized() or dist.get_world_size() == 1:
679
+ return bias_shard_candidate.clone().detach()
680
+
681
+ # In TP > 1, the bias should be identical across ranks. Verify this.
682
+ world_size = dist.get_world_size()
683
+ tensor_list = [
684
+ torch.empty_like(bias_shard_candidate) for _ in range(world_size)
685
+ ]
686
+ dist.all_gather(tensor_list, bias_shard_candidate)
687
+ # Check if all gathered biases are the same
688
+ for i in range(1, world_size):
689
+ if not torch.equal(tensor_list[0], tensor_list[i]):
690
+ raise RuntimeError(
691
+ "RowParallelLinear bias shards are not identical across TP ranks, which is unexpected."
692
+ )
693
+ # Return the bias from rank 0 (or any rank, as they are identical)
694
+ return tensor_list[0].clone().detach()
695
+ else:
696
+ # If it's None, bool, or anything else, treat as no bias
697
+ return None
SAEDashboard/sae_dashboard/components.py ADDED
@@ -0,0 +1,774 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from dataclasses import dataclass, field
3
+ from pathlib import Path
4
+ from typing import Any, Callable, List
5
+
6
+ import numpy as np
7
+ from dataclasses_json import dataclass_json
8
+
9
+ from sae_dashboard.components_config import (
10
+ ActsHistogramConfig,
11
+ FeatureTablesConfig,
12
+ LogitsHistogramConfig,
13
+ LogitsTableConfig,
14
+ PromptConfig,
15
+ SequencesConfig,
16
+ )
17
+ from sae_dashboard.html_fns import HTML, bgColorMap, uColorMap
18
+ from sae_dashboard.utils_fns import (
19
+ HistogramData,
20
+ max_or_1,
21
+ to_str_tokens,
22
+ unprocess_str_tok,
23
+ )
24
+
25
# Decimal places kept when serializing feature activations / logits into the
# JS payload (used by SequenceData._filter to shrink the output).
PRECISION = 4
26
+
27
+
28
@dataclass
class DecoderWeightsDistribution:
    """How a feature's decoder weight norm is split across heads.

    NOTE(review): presumably attention heads (for hook_z-style analysis) —
    confirm at the producer of this struct.
    """

    # Number of heads the allocation is split over.
    n_heads: int
    # Per-head share of the decoder weight norm; assumed to sum to ~1 — TODO confirm.
    allocation_by_head: List[float]
32
+
33
+
34
@dataclass_json
@dataclass
class FeatureTablesData:
    """
    This contains all the data necessary to make the left-hand tables in prompt-centric visualization. See diagram
    in readme:

    https://github.com/callummcdougall/sae_vis#data_storing_fnspy

    Inputs:
        neuron_alignment...
            The data for the neuron alignment table (each of its 3 columns). In other words, the data containing which
            neurons in the transformer the encoder feature is most aligned with.

        correlated_neurons...
            The data for the correlated neurons table (each of its 3 columns). In other words, the data containing which
            neurons in the transformer are most correlated with the encoder feature.

        correlated_features...
            The data for the correlated features table (each of its 3 columns). In other words, the data containing
            which features in this encoder are most correlated with each other.

        correlated_b_features...
            The data for the correlated features table (each of its 3 columns). In other words, the data containing
            which features in encoder-B are most correlated with those in the original encoder. Note, this one might be
            absent if we're not using a B-encoder.
    """

    # Parallel lists: index i across *_indices / *_values / *_l1 (or
    # *_pearson / *_cossim) describes one table row. Empty lists mean the
    # corresponding table is omitted from the output.
    neuron_alignment_indices: list[int] = field(default_factory=list)
    neuron_alignment_values: list[float] = field(default_factory=list)
    neuron_alignment_l1: list[float] = field(default_factory=list)
    correlated_neurons_indices: list[int] = field(default_factory=list)
    correlated_neurons_pearson: list[float] = field(default_factory=list)
    correlated_neurons_cossim: list[float] = field(default_factory=list)
    correlated_features_indices: list[int] = field(default_factory=list)
    correlated_features_pearson: list[float] = field(default_factory=list)
    correlated_features_cossim: list[float] = field(default_factory=list)
    correlated_b_features_indices: list[int] = field(default_factory=list)
    correlated_b_features_pearson: list[float] = field(default_factory=list)
    correlated_b_features_cossim: list[float] = field(default_factory=list)

    def _get_html_data(
        self,
        cfg: FeatureTablesConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """
        Returns the HTML for the left-hand tables, wrapped in a 'grid-column' div.

        Note, we only ever use this obj in the context of the left-hand column of the feature-centric vis, and it's
        always the same width & height, which is why there's no customization available for this function.
        """
        # Read HTML from file, and replace placeholders with real ID values
        html_str = (
            Path(__file__).parent / "html" / "feature_tables_template.html"
        ).read_text()
        html_str = html_str.replace("FEATURE_TABLES_ID", f"feature-tables-{id_suffix}")

        # Create dictionary storing the data
        data: dict[str, list[dict[str, str | float]]] = {}

        # Store the neuron alignment data, if it exists.
        # NOTE(review): this table asserts there are at least cfg.n_rows rows
        # but then emits ALL rows (no [: cfg.n_rows] crop), unlike the three
        # correlated tables below — confirm whether that is intentional.
        if len(self.neuron_alignment_indices) > 0:
            assert len(self.neuron_alignment_indices) >= cfg.n_rows, "Not enough rows!"
            data["neuronAlignment"] = [
                {
                    "index": index,
                    "value": f"{value:+.3f}",
                    "percentageL1": f"{percent_l1:.1%}",
                }
                for index, value, percent_l1 in zip(
                    self.neuron_alignment_indices,
                    self.neuron_alignment_values,
                    self.neuron_alignment_l1,
                )
            ]

        # Store the other 3, if they exist (they're all in the same format, so we can do it in a for loop).
        # Each row carries (index, pearson -> "value", cossim -> "percentageL1"),
        # cropped to cfg.n_rows.
        for name, js_name in zip(
            ["correlated_neurons", "correlated_features", "correlated_b_features"],
            ["correlatedNeurons", "correlatedFeatures", "correlatedFeaturesB"],
        ):
            if len(getattr(self, f"{name}_indices")) > 0:
                # assert len(getattr(self, f"{name}_indices")) >= cfg.n_rows, "Not enough rows!"
                data[js_name] = [
                    {
                        "index": index,
                        "value": f"{value:+.3f}",
                        "percentageL1": f"{percent_L1:+.3f}",
                    }
                    for index, value, percent_L1 in zip(
                        getattr(self, f"{name}_indices")[: cfg.n_rows],
                        getattr(self, f"{name}_pearson")[: cfg.n_rows],
                        getattr(self, f"{name}_cossim")[: cfg.n_rows],
                    )
                ]

        return HTML(
            html_data={column: html_str},
            js_data={"featureTablesData": {id_suffix: data}},
        )
138
+
139
+
140
@dataclass_json
@dataclass
class ActsHistogramData(HistogramData):
    """Histogram of this feature's activations over all sampled tokens."""

    def _get_html_data(
        self,
        cfg: ActsHistogramConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """Build the HTML + JS payload for the feature-activations histogram."""
        # Load the template and substitute the unique element id.
        template_path = (
            Path(__file__).parent / "html" / "acts_histogram_template.html"
        )
        html_str = template_path.read_text().replace(
            "HISTOGRAM_ACTS_ID", f"histogram-acts-{id_suffix}"
        )

        # Shade bars darker toward higher activation values. The 0.4/0.6 mix
        # keeps even the smallest bar visibly tinted; the max(..., 1e-6)
        # denominator avoids divide-by-zero.
        top = max(self.bar_values)
        denom = max(top, 1e-6)
        bar_colors = [
            bgColorMap((0.4 * top + 0.6 * value) / denom)
            for value in self.bar_values
        ]

        data: dict[str, Any] = {
            "y": self.bar_heights,
            "x": self.bar_values,
            "ticks": self.tick_vals,
            "colors": bar_colors,
        }
        if self.title is not None:
            data["title"] = self.title

        return HTML(
            html_data={column: html_str},
            js_data={"actsHistogramData": {id_suffix: data}},
        )
187
+
188
+
189
@dataclass_json
@dataclass
class LogitsHistogramData(HistogramData):
    """Histogram of this feature's direct logit effect across the vocab."""

    def _get_html_data(
        self,
        cfg: LogitsHistogramConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """Build the HTML + JS payload for the logits histogram."""
        # Load the template and substitute the unique element id.
        template_path = (
            Path(__file__).parent / "html" / "logits_histogram_template.html"
        )
        html_str = template_path.read_text().replace(
            "HISTOGRAM_LOGITS_ID", f"histogram-logits-{id_suffix}"
        )

        data: dict[str, Any] = {
            "y": self.bar_heights,
            "x": self.bar_values,
            "ticks": self.tick_vals,
        }
        if self.title is not None:
            data["title"] = self.title

        return HTML(
            html_data={column: html_str},
            js_data={"logitsHistogramData": {id_suffix: data}},
        )
228
+
229
+
230
@dataclass_json
@dataclass
class LogitsTableData:
    """Top / bottom tokens by this feature's direct logit effect."""

    bottom_token_ids: list[int] = field(default_factory=list)
    bottom_logits: list[float] = field(default_factory=list)
    top_token_ids: list[int] = field(default_factory=list)
    top_logits: list[float] = field(default_factory=list)

    def _get_html_data(
        self,
        cfg: LogitsTableConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """
        Converts data -> HTML object, for the logits table (i.e. the top and
        bottom affected tokens by this feature).

        Args:
            cfg: Table config; cfg.n_rows controls how many rows are shown.
            decode_fn: Token-id -> string decoder.
            id_suffix: Unique suffix used to build DOM element ids.
            column: Grid column (or (column, width)) this component lives in.
            component_specific_kwargs: Unused here; kept for interface parity.
        """
        # Crop the lists to `cfg.n_rows` (first checking the config doesn't ask
        # for more rows than we have).
        assert cfg.n_rows <= len(self.bottom_logits)
        bottom_token_ids = self.bottom_token_ids[: cfg.n_rows]
        bottom_logits = self.bottom_logits[: cfg.n_rows]
        top_token_ids = self.top_token_ids[: cfg.n_rows]
        top_logits = self.top_logits[: cfg.n_rows]

        # Background intensity scales with |logit| relative to the largest
        # magnitude shown. FIX: floor at 1e-6 so an all-zero table no longer
        # divides by zero (which produced NaN color values).
        max_value = max(max(top_logits), -min(bottom_logits), 1e-6)
        neg_bg_values = np.absolute(bottom_logits) / max_value
        pos_bg_values = np.absolute(top_logits) / max_value

        # Get the string tokens, using the decode function
        neg_str = to_str_tokens(decode_fn, bottom_token_ids)
        pos_str = to_str_tokens(decode_fn, top_token_ids)

        # Read HTML from file, and replace placeholders with real ID values
        html_str = (
            Path(__file__).parent / "html" / "logits_table_template.html"
        ).read_text()
        html_str = html_str.replace("LOGITS_TABLE_ID", f"logits-table-{id_suffix}")

        # Create object for storing JS data
        data: dict[str, list[dict[str, str | float]]] = {
            "negLogits": [],
            "posLogits": [],
        }

        # Fill the pos/neg tables row by row; darker background = larger |logit|.
        for i in range(len(neg_str)):
            data["negLogits"].append(
                {
                    "symbol": unprocess_str_tok(neg_str[i]),
                    "value": round(bottom_logits[i], 2),
                    "color": f"rgba(255,{int(255*(1-neg_bg_values[i]))},{int(255*(1-neg_bg_values[i]))},0.5)",
                }
            )
            data["posLogits"].append(
                {
                    "symbol": unprocess_str_tok(pos_str[i]),
                    "value": round(top_logits[i], 2),
                    "color": f"rgba({int(255*(1-pos_bg_values[i]))},{int(255*(1-pos_bg_values[i]))},255,0.5)",
                }
            )

        return HTML(
            html_data={column: html_str},
            js_data={"logitsTableData": {id_suffix: data}},
        )
300
+
301
+
302
+ @dataclass_json
303
+ @dataclass
304
+ class SequenceData:
305
+ """
306
+ This contains all the data necessary to make a sequence of tokens in the vis. See diagram in readme:
307
+
308
+ https://github.com/callummcdougall/sae_vis#data_storing_fnspy
309
+
310
+ Always-visible data:
311
+ token_ids: List of token IDs in the sequence
312
+ feat_acts: Sizes of activations on this sequence
313
+ loss_contribution: Effect on loss of this feature, for this particular token (neg = helpful)
314
+
315
+ Data which is visible on hover:
316
+ token_logits: The logits of the particular token in that sequence (used for line on logits histogram)
317
+ top_token_ids: List of the top 5 logit-boosted tokens by this feature
318
+ top_logits: List of the corresponding 5 changes in logits for those tokens
319
+ bottom_token_ids: List of the bottom 5 logit-boosted tokens by this feature
320
+ bottom_logits: List of the corresponding 5 changes in logits for those tokens
321
+ """
322
+
323
+ original_index: int = 0
324
+ qualifying_token_index: int = 0
325
+ token_ids: list[int] = field(default_factory=list)
326
+ feat_acts: list[float] = field(default_factory=list)
327
+ loss_contribution: list[float] = field(default_factory=list)
328
+
329
+ token_logits: list[float] = field(default_factory=list)
330
+ top_token_ids: list[list[int]] = field(default_factory=list)
331
+ top_logits: list[list[float]] = field(default_factory=list)
332
+ bottom_token_ids: list[list[int]] = field(default_factory=list)
333
+ bottom_logits: list[list[float]] = field(default_factory=list)
334
+
335
def __post_init__(self) -> None:
    """Record the sequence length and strip zero logit entries.

    Filtering the top/bottom logit lists (via _filter) removes zero values
    and truncates the matching token-id lists, shrinking the eventual
    JavaScript payload.
    """
    self.seq_len = len(self.token_ids)
    self.bottom_logits, self.bottom_token_ids = self._filter(
        self.bottom_logits, self.bottom_token_ids
    )
    self.top_logits, self.top_token_ids = self._filter(
        self.top_logits, self.top_token_ids
    )
347
+
348
+ def _filter(
349
+ self, float_list: list[list[float]], int_list: list[list[int]]
350
+ ) -> tuple[list[list[float]], list[list[int]]]:
351
+ """
352
+ Filters the list of floats and ints, by removing any elements which are zero. Note - the absolute values of the
353
+ floats are monotonic non-increasing, so we can assume that all the elements we keep will be the first elements
354
+ of their respective lists. Also reduces precisions of feature activations & logits.
355
+ """
356
+ # Next, filter out zero-elements and reduce precision
357
+ float_list = [
358
+ [round(f, PRECISION) for f in floats if abs(f) > 1e-6]
359
+ for floats in float_list
360
+ ]
361
+ int_list = [ints[: len(floats)] for ints, floats in zip(int_list, float_list)]
362
+ return float_list, int_list
363
+
364
+ def _get_html_data(
365
+ self,
366
+ cfg: PromptConfig | SequencesConfig,
367
+ decode_fn: Callable[[int | list[int]], str | list[str]],
368
+ id_suffix: str,
369
+ column: int | tuple[int, int],
370
+ component_specific_kwargs: dict[str, Any] = {},
371
+ ) -> HTML:
372
+ """
373
+ Args:
374
+
375
+ Returns:
376
+ js_data: list[dict[str, Any]]
377
+ The data for this sequence, in the form of a list of dicts for each token (where the dict stores things
378
+ like token, feature activations, etc).
379
+ """
380
+ assert isinstance(
381
+ cfg, (PromptConfig, SequencesConfig)
382
+ ), f"Invalid config type: {type(cfg)}"
383
+ seq_group_id = component_specific_kwargs.get("seq_group_id", None)
384
+ max_feat_act = component_specific_kwargs.get("max_feat_act", None)
385
+ max_loss_contribution = component_specific_kwargs.get(
386
+ "max_loss_contribution", None
387
+ )
388
+ bold_idx = component_specific_kwargs.get("bold_idx", None)
389
+ permanent_line = component_specific_kwargs.get("permanent_line", False)
390
+ first_in_group = component_specific_kwargs.get("first_in_group", True)
391
+ title = component_specific_kwargs.get("title", None)
392
+ hover_above = component_specific_kwargs.get("hover_above", False)
393
+
394
+ # If we didn't supply a sequence group ID, then we assume this sequence is on its own, and give it a unique ID
395
+ if seq_group_id is None:
396
+ seq_group_id = f"prompt-{column:03d}"
397
+
398
+ # If we didn't specify bold_idx, then set it to be the midpoint
399
+ if bold_idx is None:
400
+ bold_idx = self.seq_len // 2
401
+
402
+ # If we only have data for the bold token, we pad out everything with zeros or empty lists
403
+ only_bold = isinstance(cfg, SequencesConfig) and not (cfg.compute_buffer)
404
+ if only_bold:
405
+ assert bold_idx != "max", "Don't know how to deal with this case yet."
406
+ feat_acts = [
407
+ self.feat_acts[0] if (i == bold_idx) else 0.0
408
+ for i in range(self.seq_len)
409
+ ]
410
+ loss_contribution = [
411
+ self.loss_contribution[0] if (i == bold_idx) + 1 else 0.0
412
+ for i in range(self.seq_len)
413
+ ]
414
+ pos_ids = [
415
+ self.top_token_ids[0] if (i == bold_idx) + 1 else []
416
+ for i in range(self.seq_len)
417
+ ]
418
+ neg_ids = [
419
+ self.bottom_token_ids[0] if (i == bold_idx) + 1 else []
420
+ for i in range(self.seq_len)
421
+ ]
422
+ pos_val = [
423
+ self.top_logits[0] if (i == bold_idx) + 1 else []
424
+ for i in range(self.seq_len)
425
+ ]
426
+ neg_val = [
427
+ self.bottom_logits[0] if (i == bold_idx) + 1 else []
428
+ for i in range(self.seq_len)
429
+ ]
430
+ else:
431
+ feat_acts = deepcopy(self.feat_acts)
432
+ loss_contribution = deepcopy(self.loss_contribution)
433
+ pos_ids = deepcopy(self.top_token_ids)
434
+ neg_ids = deepcopy(self.bottom_token_ids)
435
+ pos_val = deepcopy(self.top_logits)
436
+ neg_val = deepcopy(self.bottom_logits)
437
+
438
+ # EXPERIMENT: let's just hardcode everything except feature acts to be 0's for now.
439
+ loss_contribution = [0.0 for _ in range(self.seq_len)]
440
+ pos_ids = [[] for _ in range(self.seq_len)]
441
+ neg_ids = [[] for _ in range(self.seq_len)]
442
+ pos_val = [[] for _ in range(self.seq_len)]
443
+ neg_val = [[] for _ in range(self.seq_len)]
444
+ ### END EXPERIMENT
445
+
446
+ # Get values for converting into colors later
447
+ bg_denom = max_feat_act or max_or_1(self.feat_acts)
448
+ u_denom = max_loss_contribution or max_or_1(self.loss_contribution, abs=True)
449
+ bg_values = (np.maximum(feat_acts, 0.0) / max(1e-4, bg_denom)).tolist()
450
+ u_values = (np.array(loss_contribution) / max(1e-4, u_denom)).tolist()
451
+
452
+ # If we sent in a prompt rather than this being sliced from a longer sequence, then the pos_ids etc will be shorter
453
+ # than the token list by 1, so we need to pad it at the first token
454
+ if isinstance(cfg, PromptConfig):
455
+ assert (
456
+ len(pos_ids)
457
+ == len(neg_ids)
458
+ == len(pos_val)
459
+ == len(neg_val)
460
+ == len(self.token_ids) - 1
461
+ ), "If this is a single prompt, these lists must be the same length as token_ids or 1 less"
462
+ pos_ids = [[]] + pos_ids
463
+ neg_ids = [[]] + neg_ids
464
+ pos_val = [[]] + pos_val
465
+ neg_val = [[]] + neg_val
466
+
467
+ assert (
468
+ len(pos_ids)
469
+ == len(neg_ids)
470
+ == len(pos_val)
471
+ == len(neg_val)
472
+ == len(self.token_ids)
473
+ ), "If this is part of a sequence group etc are given, they must be the same length as token_ids"
474
+
475
+ # Process the tokens to get str toks
476
+ toks = to_str_tokens(decode_fn, self.token_ids)
477
+ pos_toks = [to_str_tokens(decode_fn, pos) for pos in pos_ids]
478
+ neg_toks = [to_str_tokens(decode_fn, neg) for neg in neg_ids]
479
+
480
+ # Define the JavaScript object which will be used to populate the HTML string
481
+ js_data_list = []
482
+
483
+ for i in range(len(self.token_ids)):
484
+ # We might store a bunch of different case-specific data in the JavaScript object for each token. This is
485
+ # done in the form of a disjoint union over different dictionaries (which can each be empty or not), this
486
+ # minimizes the size of the overall JavaScript object. See function in `tokens_script.js` for more.
487
+ kwargs_bold: dict[str, bool] = {}
488
+ kwargs_hide: dict[str, bool] = {}
489
+ kwargs_this_token_active: dict[str, Any] = {}
490
+ kwargs_prev_token_active: dict[str, Any] = {}
491
+ kwargs_hover_above: dict[str, bool] = {}
492
+
493
+ # Get args if this is the bolded token (we make it bold, and maybe add permanent line to histograms)
494
+ if bold_idx is not None:
495
+ kwargs_bold["isBold"] = (bold_idx == i) or (
496
+ bold_idx == "max" and i == np.argmax(feat_acts).item()
497
+ )
498
+ if kwargs_bold["isBold"] and permanent_line:
499
+ kwargs_bold["permanentLine"] = True
500
+
501
+ # If we only have data for the bold token, we hide all other tokens' hoverdata (and skip other kwargs)
502
+ if (
503
+ only_bold
504
+ and isinstance(bold_idx, int)
505
+ and (i not in {bold_idx, bold_idx + 1})
506
+ ):
507
+ kwargs_hide["hide"] = True
508
+
509
+ else:
510
+ # Get args if we're making the tooltip hover above token (default is below)
511
+ if hover_above:
512
+ kwargs_hover_above["hoverAbove"] = True
513
+
514
+ # If feature active on this token, get background color and feature act (for hist line)
515
+ if abs(feat_acts[i]) > 1e-8:
516
+ kwargs_this_token_active = dict(
517
+ featAct=round(feat_acts[i], PRECISION),
518
+ bgColor=bgColorMap(bg_values[i]),
519
+ )
520
+
521
+ # If prev token active, get the top/bottom logits table, underline color, and loss effect (for hist line)
522
+ pos_toks_i, neg_toks_i, pos_val_i, neg_val_i = (
523
+ pos_toks[i],
524
+ neg_toks[i],
525
+ pos_val[i],
526
+ neg_val[i],
527
+ )
528
+ if len(pos_toks_i) + len(neg_toks_i) > 0:
529
+ # Create dictionary
530
+ kwargs_prev_token_active = dict(
531
+ posToks=pos_toks_i,
532
+ negToks=neg_toks_i,
533
+ posVal=pos_val_i,
534
+ negVal=neg_val_i,
535
+ lossEffect=round(loss_contribution[i], PRECISION),
536
+ uColor=uColorMap(u_values[i]),
537
+ )
538
+
539
+ js_data_list.append(
540
+ dict(
541
+ tok=unprocess_str_tok(toks[i]),
542
+ tokID=self.token_ids[i],
543
+ tokenLogit=round(self.token_logits[i], PRECISION),
544
+ **kwargs_bold,
545
+ **kwargs_this_token_active,
546
+ **kwargs_prev_token_active,
547
+ **kwargs_hover_above,
548
+ )
549
+ )
550
+
551
+ # Create HTML string (empty by default since sequences are added by JavaScript) and JS data
552
+ html_str = ""
553
+ js_seq_group_data: dict[str, Any] = {"data": [js_data_list]}
554
+
555
+ # Add group-specific stuff if this is the first sequence in the group
556
+ if first_in_group:
557
+ # Read HTML from file, replace placeholders with real ID values
558
+ html_str = (
559
+ Path(__file__).parent / "html" / "sequences_group_template.html"
560
+ ).read_text()
561
+ html_str = html_str.replace("SEQUENCE_GROUP_ID", seq_group_id)
562
+
563
+ # Get title of sequence group, and the idSuffix to match up with a histogram
564
+ js_seq_group_data["idSuffix"] = id_suffix
565
+ if title is not None:
566
+ js_seq_group_data["title"] = title
567
+
568
+ return HTML(
569
+ html_data={column: html_str},
570
+ js_data={"tokenData": {seq_group_id: js_seq_group_data}},
571
+ )
572
+
573
+
574
@dataclass_json
@dataclass
class SequenceGroupData:
    """
    This contains all the data necessary to make a single group of sequences (e.g. a quantile in prompt-centric
    visualization). See diagram in readme:

    https://github.com/callummcdougall/sae_vis#data_storing_fnspy

    Inputs:
        title: The title that this sequence group will have, if any. This is used in `_get_html_data`. The titles
            will actually be in the HTML strings, not in the JavaScript data.
        seq_data: The data for the sequences in this group.
    """

    title: str = ""
    seq_data: list[SequenceData] = field(default_factory=list)

    def __len__(self) -> int:
        return len(self.seq_data)

    @property
    def max_feat_act(self) -> float:
        """Returns maximum value of feature activation over all sequences in this group."""
        all_acts = [act for seq in self.seq_data for act in seq.feat_acts]
        return max_or_1(all_acts)

    @property
    def max_loss_contribution(self) -> float:
        """Returns maximum value of loss contribution over all sequences in this group."""
        all_losses = [loss for seq in self.seq_data for loss in seq.loss_contribution]
        return max_or_1(all_losses, abs=True)

    def _get_html_data(
        self,
        cfg: SequencesConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
        # These default values should be correct when we only have one sequence group, because when we call this from
        # a SequenceMultiGroupData we'll override them)
    ) -> HTML:
        """
        This creates a single group of sequences, i.e. title plus some number of vertically stacked sequences.

        Note, `column` is treated specially here, because the col might overflow (hence column could be a tuple).

        Args (from component-specific kwargs):
            seq_group_id: The id of the sequence group div. This will usually be passed as e.g. "seq-group-001".
            group_size: Max size of sequences in the group (i.e. we truncate after this many, if argument supplied).
            max_feat_act: If supplied, then we use this as the most extreme value (for coloring by feature act).

        Returns:
            html_obj: Object containing the HTML and JavaScript data for this seq group.
        """
        kwargs = component_specific_kwargs
        group_id = kwargs.get("seq_group_id", None)
        n_seqs = kwargs.get("group_size", None)
        act_ceiling = kwargs.get("max_feat_act", self.max_feat_act)
        loss_ceiling = kwargs.get("max_loss_contribution", self.max_loss_contribution)

        # If seq_group_id was not supplied, this group is alone in its column, so name it after the column
        if group_id is None:
            group_id = f"seq-group-{column:03d}"

        # Bold token: the max-activation token when there is no buffer, else the buffer's left edge
        bold = "max" if cfg.buffer is None else cfg.buffer[0]

        # Accumulate the HTML data for each sequence in this group (truncated to n_seqs)
        group_html = HTML()
        for seq_idx, sequence in enumerate(self.seq_data[:n_seqs]):
            group_html += sequence._get_html_data(
                cfg=cfg,
                decode_fn=decode_fn,
                id_suffix=id_suffix,
                column=column,
                component_specific_kwargs=dict(
                    bold_idx=bold,
                    permanent_line=False,  # in a group, we never show a permanent line (only for single seqs)
                    max_feat_act=act_ceiling,
                    max_loss_contribution=loss_ceiling,
                    seq_group_id=group_id,
                    first_in_group=(seq_idx == 0),
                    title=self.title,
                ),
            )

        return group_html
666
+
667
+
668
@dataclass_json
@dataclass
class SequenceMultiGroupData:
    """
    This contains all the data necessary to make multiple groups of sequences (e.g. the different quantiles in the
    prompt-centric visualization). See diagram in readme:

    https://github.com/callummcdougall/sae_vis#data_storing_fnspy
    """

    seq_group_data: list[SequenceGroupData] = field(default_factory=list)

    def __getitem__(self, idx: int) -> SequenceGroupData:
        return self.seq_group_data[idx]

    @property
    def max_feat_act(self) -> float:
        """Returns maximum value of feature activation over all sequences in this group."""
        return max_or_1([group.max_feat_act for group in self.seq_group_data])

    @property
    def max_loss_contribution(self) -> float:
        """Returns maximum value of loss contribution over all sequences in this group."""
        return max_or_1(
            [group.max_loss_contribution for group in self.seq_group_data]
        )

    def _get_html_data(
        self,
        cfg: SequencesConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """
        Args:
            decode_fn: Mapping from token IDs to string tokens.
            id_suffix: The suffix for the ID of the div containing the sequences.
            column: The index of this column. Note that this will be an int, but we might end up
                turning it into a tuple if we overflow into a new column.
            component_specific_kwargs: Contains any specific kwargs that could be used to customize this component.

        Returns:
            html_obj: Object containing the HTML and JavaScript data for these multiple seq groups.
        """
        assert isinstance(column, int)

        # Get max activation value & max loss contributions, over all sequences in all groups
        act_ceiling = component_specific_kwargs.get("max_feat_act", self.max_feat_act)
        loss_ceiling = component_specific_kwargs.get(
            "max_loss_contribution", self.max_loss_contribution
        )

        # Get the correct column indices for the sequence groups, depending on how group_wrap is configured. Note, we
        # deal with overflowing columns by extending the dictionary, i.e. our column argument isn't just `column`, but
        # is a tuple of `(column, x)` where `x` is the number of times we've overflowed. For instance, if we have mode
        # 'stack-none' then our columns are `(column, 0), (column, 1), (column, 1), (column, 1), (column, 2), ...`
        n_groups = len(self.seq_group_data)
        n_quantile_groups = n_groups - 1
        if cfg.stack_mode == "stack-all":
            # All groups stacked into the 1st column
            cols = [column] * n_groups
        elif cfg.stack_mode == "stack-quantiles":
            # First group gets its own column; every quantile group shares the second
            cols = [(column, 0)] + [(column, 1)] * n_quantile_groups
        elif cfg.stack_mode == "stack-none":
            # Groups stacked into columns as [1, 3, 3, ...]
            cols = [(column, 0)] + [
                (column, 1 + int(i / 3)) for i in range(n_quantile_groups)
            ]
        else:
            raise ValueError(
                f"Invalid stack_mode: {cfg.stack_mode}. Expected in 'stack-{{all,quantiles,none}}'."
            )

        # Create the HTML object, and add all the sequence groups to it, possibly across different columns
        multi_group_html = HTML()
        for group_idx, (target_col, size, group) in enumerate(
            zip(cols, cfg.group_sizes, self.seq_group_data)
        ):
            multi_group_html += group._get_html_data(
                cfg=cfg,
                decode_fn=decode_fn,
                id_suffix=id_suffix,
                column=target_col,
                component_specific_kwargs=dict(
                    group_size=size,
                    max_feat_act=act_ceiling,
                    max_loss_contribution=loss_ceiling,
                    seq_group_id=f"seq-group-{column}-{group_idx}",  # we label our sequence groups with (index, column)
                ),
            )

        return multi_group_html
765
+
766
+
767
# Union of every per-component data container a feature-vis layout can hold.
GenericData = (
    FeatureTablesData
    | ActsHistogramData
    | LogitsTableData
    | LogitsHistogramData
    | SequenceMultiGroupData
    | SequenceData
)
SAEDashboard/sae_dashboard/components_config.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Iterator, Literal
3
+
4
# Help strings for each config class's arguments, surfaced through the
# corresponding `help_dict` property on each config below.
SEQUENCES_CONFIG_HELP = dict(
    buffer="How many tokens to add as context to each sequence, on each side. The tokens chosen for the top acts / \
quantile groups can't be outside the buffer range. If None, we use the entire sequence as context.",
    compute_buffer="If False, then we don't compute the loss effect, activations, or any other data for tokens \
other than the bold tokens in our sequences (saving time).",
    n_quantiles="Number of quantile groups for the sequences. If zero, we only show top activations, no quantile \
groups.",
    top_acts_group_size="Number of sequences in the 'top activating sequences' group.",
    quantile_group_size="Number of sequences in each of the sequence quantile groups.",
    top_logits_hoverdata="Number of top/bottom logits to show in the hoverdata for each token.",
    stack_mode="How to stack the sequence groups.\n  'stack-all' = all groups are stacked in a single column \
(scrolls vertically if it overflows)\n  'stack-quantiles' = first col contains top acts, second col contains all \
quantile groups\n  'stack-none' = we stack in a way which ensures no vertical scrolling.",
    hover_below="Whether the hover information about a token appears below or above the token.",
)

ACTIVATIONS_HISTOGRAM_CONFIG_HELP = dict(
    n_bins="Number of bins for the histogram.",
)

LOGITS_HISTOGRAM_CONFIG_HELP = dict(
    n_bins="Number of bins for the histogram.",
)

LOGITS_TABLE_CONFIG_HELP = dict(
    n_rows="Number of top/bottom logits to show in the table.",
)

FEATURE_TABLES_CONFIG_HELP = dict(
    n_rows="Number of rows to show for each feature table.",
    neuron_alignment_table="Whether to show the neuron alignment table.",
    correlated_neurons_table="Whether to show the correlated neurons table.",
    correlated_features_table="Whether to show the (pairwise) correlated features table.",
    correlated_b_features_table="Whether to show the correlated encoder-B features table.",
)
39
+
40
+
41
@dataclass
class BaseComponentConfig:
    """Common interface for all component configs; subclasses override both members."""

    def data_is_contained_in(self, other: "BaseComponentConfig") -> bool:
        """
        Returns False only when the data that was computed based on `other` wouldn't be enough to show
        the data that was computed based on `self`. For instance, if `self` asked for 10 rows and `other`
        only had 5, this would be False. Less obvious: a 50-bin histogram needs `other` to have exactly
        50 bins (bins can't be changed after generation). The base class has no parameters to compare,
        so it always returns True.
        """
        return True

    @property
    def help_dict(self) -> dict[str, str]:
        """
        Maps each argument's name to a description of what it does, used when printing the help for a
        config object. Empty for the base class.
        """
        return {}
59
+
60
+
61
@dataclass
class PromptConfig(BaseComponentConfig):
    """Config for the prompt-centric component; adds nothing beyond the base defaults."""
64
+
65
+
66
@dataclass
class SequencesConfig(BaseComponentConfig):
    """Config for the sequence groups shown in the vis (see SEQUENCES_CONFIG_HELP for argument docs)."""

    buffer: tuple[int, int] | None = (5, 5)
    compute_buffer: bool = True
    n_quantiles: int = 10
    top_acts_group_size: int = 20
    quantile_group_size: int = 5
    top_logits_hoverdata: int = 5
    stack_mode: Literal["stack-all", "stack-quantiles", "stack-none"] = "stack-all"
    hover_below: bool = True

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        """
        True iff data computed with config `other` suffices to display what `self` asks for.
        Written as guard clauses: each check rejects one way `self` could need more than `other` has.
        """
        assert isinstance(other, self.__class__)
        # Our buffer must fit inside the other's buffer on both sides (None = full-sequence context)
        if self.buffer is not None:
            if other.buffer is None:
                return False
            if self.buffer[0] > other.buffer[0] or self.buffer[1] > other.buffer[1]:
                return False
        # We can't compute the buffer if we didn't in `other`
        if self.compute_buffer and not other.compute_buffer:
            return False
        # Quantiles must be identical (or ours must be zero)
        if self.n_quantiles not in (0, other.n_quantiles):
            return False
        # Group sizes and hoverdata rows must be no larger than what was computed
        if self.top_acts_group_size > other.top_acts_group_size:
            return False
        if self.quantile_group_size > other.quantile_group_size:
            return False
        return self.top_logits_hoverdata <= other.top_logits_hoverdata

    def __post_init__(self):
        # Group sizes: one group for the top activations, then one per quantile
        self.group_sizes = [self.top_acts_group_size] + self.n_quantiles * [
            self.quantile_group_size
        ]

    @property
    def help_dict(self) -> dict[str, str]:
        """Help strings for this config's arguments."""
        return SEQUENCES_CONFIG_HELP
114
+
115
+
116
@dataclass
class ActsHistogramConfig(BaseComponentConfig):
    """Config for the feature-activations histogram."""

    n_bins: int = 50

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        """Bins can't be re-derived after generation, so the counts must match exactly."""
        assert isinstance(other, self.__class__)
        return other.n_bins == self.n_bins

    @property
    def help_dict(self) -> dict[str, str]:
        """Help strings for this config's arguments."""
        return ACTIVATIONS_HISTOGRAM_CONFIG_HELP
127
+
128
+
129
@dataclass
class LogitsHistogramConfig(BaseComponentConfig):
    """Config for the logits histogram."""

    n_bins: int = 50

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        """Bins can't be re-derived after generation, so the counts must match exactly."""
        assert isinstance(other, self.__class__)
        return other.n_bins == self.n_bins

    @property
    def help_dict(self) -> dict[str, str]:
        """Help strings for this config's arguments."""
        return LOGITS_HISTOGRAM_CONFIG_HELP
140
+
141
+
142
@dataclass
class LogitsTableConfig(BaseComponentConfig):
    """Config for the top/bottom logits table."""

    n_rows: int = 10

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        """We can always show fewer rows than were computed, never more."""
        assert isinstance(other, self.__class__)
        return not (self.n_rows > other.n_rows)

    @property
    def help_dict(self) -> dict[str, str]:
        """Help strings for this config's arguments."""
        return LOGITS_TABLE_CONFIG_HELP
153
+
154
+
155
@dataclass
class FeatureTablesConfig(BaseComponentConfig):
    """Config for the left-hand feature tables (which tables to show, and how many rows each)."""

    n_rows: int = 3
    neuron_alignment_table: bool = True
    correlated_neurons_table: bool = True
    correlated_features_table: bool = True
    correlated_b_features_table: bool = False

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        """True iff `other` computed at least as many rows and at least the tables `self` wants shown."""
        assert isinstance(other, self.__class__)
        if self.n_rows > other.n_rows:
            return False
        # A table can only be shown if it was also computed for `other` (i.e. wanted implies computed)
        wanted_vs_computed = [
            (self.neuron_alignment_table, other.neuron_alignment_table),
            (self.correlated_neurons_table, other.correlated_neurons_table),
            (self.correlated_features_table, other.correlated_features_table),
            (self.correlated_b_features_table, other.correlated_b_features_table),
        ]
        return all((not wanted) or computed for wanted, computed in wanted_vs_computed)

    @property
    def help_dict(self) -> dict[str, str]:
        """Help strings for this config's arguments."""
        return FEATURE_TABLES_CONFIG_HELP
178
+
179
+
180
# Union of all component config types that can be placed inside a `Column`.
GenericComponentConfig = (
    PromptConfig
    | SequencesConfig
    | ActsHistogramConfig
    | LogitsHistogramConfig
    | LogitsTableConfig
    | FeatureTablesConfig
)
188
+
189
+
190
class Column:
    """An ordered collection of component configs laid out as a single vis column."""

    def __init__(
        self,
        *args: GenericComponentConfig,
        width: int | None = None,
    ):
        # `width` is an optional fixed width for the column; None leaves it unconstrained.
        self.components = list(args)
        self.width = width

    def __iter__(self) -> Iterator[Any]:
        yield from self.components

    def __getitem__(self, idx: int) -> Any:
        return self.components[idx]

    def __len__(self) -> int:
        return len(self.components)
SAEDashboard/sae_dashboard/css/dropdown.css ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Styles for the <select> dropdowns and their flex container. */

/* Styling of the dropdowns */
select {
    appearance: none;
    border: 0;
    flex: 1;
    padding: 0 1em;
    background-color: #eee;
    cursor: pointer;
}
.select {
    box-shadow: 0 5px 5px rgba(0, 0, 0, 0.25);
    cursor: pointer;
    display: flex;
    width: 100px;
    height: 25px;
    border-radius: .25em;
    overflow: hidden;
    position: relative;
    margin-right: 15px;
}
.select::after {
    position: absolute;
    content: '\25BC'; /* U+25BC black down-pointing triangle: the dropdown arrow */
    font-size: 9px;
    top: 0;
    right: 0;
    padding: 1em;
    background-color: #ddd;
    transition: .25s all ease;
    pointer-events: none; /* let clicks fall through to the underlying <select> */
}
.select:hover::after {
    color: black;
}
#dropdown-container {
    margin-left: 10px;
    margin-top: 20px;
    display: flex;
    flex-wrap: wrap;
}
SAEDashboard/sae_dashboard/css/general.css ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Top-level dashboard layout: the grid container, its scrollable columns, scrollbars, and headings. */

/* Styling of the top-level container */
.grid-container {
    font-family: 'system-ui';
    border: 1px solid #e6e6e6;
    background-color: #fff;
    margin: 30px 10px;
    box-shadow: 0 5px 5px rgba(0, 0, 0, 0.25);
    display: grid;
    justify-content: start;
    grid-template-columns: auto;
    overflow-x: auto;
    overflow-y: visible;
    grid-auto-flow: column; /* columns are appended horizontally */
    white-space: nowrap;
    padding-bottom: 12px;
    padding-top: 35px;
    padding-left: 20px;
}
/* Styling each grid column (note, the max-height controls height of grid-container) */
.grid-column {
    margin-left: 20px;
    padding-right: 20px;
    width: max-content;
    overflow-y: auto;
    max-height: 750px;
}
/* Styling the scrollbars */
::-webkit-scrollbar {
    height: 10px;
    width: 10px;
}
::-webkit-scrollbar-track {
    background: #f1f1f1;
}
::-webkit-scrollbar-thumb {
    background: #999;
}
::-webkit-scrollbar-thumb:hover {
    background: #555;
}
/* Margin at the bottom of each histogram */
.plotly-hist {
    margin-bottom: 25px;
}
/* Margins below the titles (most subtitles are h4, except for the prompt-centric view which has h2 titles) */
h4 {
    margin-top: 0px;
    margin-bottom: 10px;
}
/* Some space below the <hr> line in prompt-centric vis */
hr {
    margin-bottom: 35px;
}
SAEDashboard/sae_dashboard/css/sequences.css ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Token-sequence display: token styling, sequence spacing, and hover tooltips. */

/* Default font & appearance for the words in the sequence, before being hovered over */
code {
    font-family: Consolas, Menlo, Monaco;
}
/* Margin at the bottom of every sequence group, plus handle how overflow works (maybe not necessary) */
.seq-group {
    overflow-x: auto;
    overflow-y: visible;
    padding-top: 5px;
    padding-bottom: 10px;
    margin-bottom: 10px;
}
/* Margin between single sequences */
.seq {
    margin-bottom: 11px;
}
/* Styling for each token in a sequence */
.token {
    font-family: Consolas, Menlo, Monaco;
    font-size: 0.9em;
    border-top-left-radius: 3px;
    border-top-right-radius: 3px;
    padding: 1px;
    color: black;
    display: inline;
    white-space: pre-wrap;
}
/* All the messy hovering stuff! */
.hover-text {
    position: relative;
    cursor: pointer;
    display: inline-block; /* Needed to contain the tooltip */
    box-sizing: border-box;
}
.tooltip {
    background-color: #fff;
    color: #333;
    text-align: center;
    border-radius: 10px;
    padding: 5px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.25);
    align-items: center;
    justify-content: center;
    overflow: hidden;
    font-family: 'system-ui';
    font-size: 1.1em;
    display: none; /* hidden until the sibling .hover-text is hovered (final rule below) */
    position: fixed;
    z-index: 1000;
}
.token:hover {
    border-top: 3px solid black;
}
.tooltip-container {
    position: absolute;
    pointer-events: none;
}
/* Reveal the tooltip when its adjacent .hover-text is hovered */
.hover-text:hover + .tooltip-container .tooltip {
    display: block;
}
61
+
SAEDashboard/sae_dashboard/css/tables.css ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Shared table styling, plus layout for the feature tables and the pos/neg logits tables. */

table {
    border: unset;
    color: black;
    border-collapse: collapse;
    width: -moz-fit-content;
    width: -webkit-fit-content;
    width: fit-content; /* vendor-prefixed fallbacks above for older Firefox/WebKit */
    margin-left: auto;
    margin-right: auto;
    font-size: 0.8em;
}
table.table-left tr {
    border-bottom: 1px solid #eee;
    padding: 15px;
}
table.table-left td {
    padding: 3px 4px;
}
table.table-left {
    width: 100%;
}
table.table-left td.left-aligned {
    max-width: 120px;
    overflow-x: hidden;
}
td {
    border: none;
    padding: 2px 4px;
    white-space: nowrap;
}
.right-aligned {
    text-align: right;
}
.left-aligned {
    text-align: left;
}
.center-aligned {
    text-align: center;
    padding-bottom: 8px;
}
table code {
    background-color: #ddd;
    padding: 2px;
    border-radius: 3px;
}
.table-container {
    width: 100%;
}
.half-width-container {
    display: flex;
}
.half-width {
    width: 50%;
    margin-right: -4px;
}

/* Feature tables should have space below them, also they should have a min column width */
div.feature-tables table {
    margin-bottom: 25px;
    min-width: 250px;
}
/* Configure logits table container (i.e. the thing containing the smaller and larger tables) */
div.logits-table {
    min-width: 375px;
    display: flex;
    overflow-x: hidden;
    margin-bottom: 20px;
}
/* Code is always bold in this table (this is just the neg/pos string tokens) */
div.logits-table code {
    font-weight: bold;
}
/* Set width of the tables inside the container (so they can stack horizontally), also put a gap between them */
div.logits-table > div.positive {
    width: 47%;
}
div.logits-table > div.negative {
    width: 47%;
    margin-right: 5%;
}
81
+
SAEDashboard/sae_dashboard/data_parsing_fns.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ import numpy as np
3
+ import torch
4
+ from eindex import eindex
5
+ from jaxtyping import Float, Int
6
+ from sae_lens import SAE
7
+ from torch import Tensor
8
+ from transformer_lens import HookedTransformer, utils
9
+
10
+ from sae_dashboard.components import LogitsTableData, SequenceData
11
+ from sae_dashboard.sae_vis_data import SaeVisData
12
+ from sae_dashboard.transformer_lens_wrapper import (
13
+ ActivationConfig,
14
+ TransformerLensWrapper,
15
+ to_resid_direction,
16
+ )
17
+ from sae_dashboard.utils_fns import RollingCorrCoef, TopK
18
+
19
+ Arr = np.ndarray
20
+
21
+
22
def get_features_table_data(
    feature_out_dir: Float[Tensor, "feats d_out"],
    n_rows: int,
    corrcoef_neurons: RollingCorrCoef | None = None,
    corrcoef_encoder: RollingCorrCoef | None = None,
) -> dict[str, list[list[int]] | list[list[float]]]:
    """Assemble the data for the three left-hand-column feature tables.

    The neuron-alignment table (based on decoder weights) is always produced;
    the two correlation tables are produced only when the corresponding
    rolling-correlation accumulator is supplied.

    Args:
        feature_out_dir: Decoder directions, one row per feature.
        n_rows: Number of rows to show in each table.
        corrcoef_neurons: Optional rolling correlation between features and neurons.
        corrcoef_encoder: Optional rolling correlation between features of the
            primary encoder.

    Returns:
        Dict of table-column keyword arguments, keyed by column name.
    """
    tables: dict[str, list[list[int]] | list[list[float]]] = {}

    # Table 1: neuron alignment — computed unconditionally.
    add_neuron_alignment_data(
        feature_out_dir=feature_out_dir,
        feature_tables_data=tables,
        n_rows=n_rows,
    )

    # Table 2: neurons correlated with this feature, based on their activations.
    if corrcoef_neurons is not None:
        add_feature_neuron_correlations(
            corrcoef_neurons=corrcoef_neurons,
            feature_tables_data=tables,
            n_rows=n_rows,
        )

    # Table 3: primary-encoder features correlated with this feature.
    if corrcoef_encoder is not None:
        add_intra_encoder_correlations(
            corrcoef_encoder=corrcoef_encoder,
            feature_tables_data=tables,
            n_rows=n_rows,
        )

    return tables
58
+
59
+
60
def add_intra_encoder_correlations(
    corrcoef_encoder: RollingCorrCoef,
    feature_tables_data: dict[str, list[list[int]] | list[list[float]]],
    n_rows: int,
):
    """Fill in the intra-encoder correlation table.

    Takes the top-`n_rows` features by Pearson correlation (with cosine
    similarity alongside) and writes the three columns into
    `feature_tables_data` in place.
    """
    indices, pearson, cossim = corrcoef_encoder.topk_pearson(k=n_rows)
    for column, values in (
        ("correlated_features_indices", indices),
        ("correlated_features_pearson", pearson),
        ("correlated_features_cossim", cossim),
    ):
        feature_tables_data[column] = values
71
+
72
+
73
def add_neuron_alignment_data(
    feature_out_dir: Float[Tensor, "feats d_out"],
    feature_tables_data: dict[str, list[list[int]] | list[list[float]]],
    n_rows: int,
):
    """Fill in the neuron-alignment table from the decoder weight matrix.

    For each feature, record the `n_rows` neurons with the largest decoder
    weight: their indices, the weights themselves, and each weight's share
    of the feature's total L1 norm. Mutates `feature_tables_data` in place.
    """
    aligned = TopK(tensor=feature_out_dir.float(), k=n_rows, largest=True)
    l1_norm = feature_out_dir.abs().sum(dim=-1, keepdim=True)
    # Fraction of the feature's total L1 mass carried by each top neuron.
    l1_fraction: Arr = np.absolute(aligned.values) / utils.to_numpy(
        l1_norm.float()
    )
    feature_tables_data["neuron_alignment_indices"] = aligned.indices.tolist()
    feature_tables_data["neuron_alignment_values"] = aligned.values.tolist()
    feature_tables_data["neuron_alignment_l1"] = l1_fraction.tolist()
90
+
91
+
92
def add_feature_neuron_correlations(
    corrcoef_neurons: RollingCorrCoef,
    feature_tables_data: dict[str, list[list[int]] | list[list[float]]],
    n_rows: int,
):
    """Fill in the feature/neuron correlation table.

    Takes the top-`n_rows` neurons by Pearson correlation (with cosine
    similarity alongside) and writes the three columns into
    `feature_tables_data` in place.
    """
    indices, pearson, cossim = corrcoef_neurons.topk_pearson(k=n_rows)
    for column, values in (
        ("correlated_neurons_indices", indices),
        ("correlated_neurons_pearson", pearson),
        ("correlated_neurons_cossim", cossim),
    ):
        feature_tables_data[column] = values
104
+
105
+
106
def get_logits_table_data(
    logit_vector: Float[Tensor, "d_vocab"], n_rows: int  # noqa: F821
):
    """Build the logits-table payload for one feature.

    Returns a LogitsTableData holding the `n_rows` most positive and the
    `n_rows` most negative entries of `logit_vector`, together with the
    corresponding token ids.
    """
    as_float = logit_vector.float()
    top = TopK(as_float, k=n_rows, largest=True)
    bottom = TopK(as_float, k=n_rows, largest=False)

    return LogitsTableData(
        bottom_logits=bottom.values.tolist(),
        bottom_token_ids=bottom.indices.tolist(),
        top_logits=top.values.tolist(),
        top_token_ids=top.indices.tolist(),
    )
127
+
128
+
129
+ # @torch.inference_mode()
130
+ # def get_feature_data(
131
+ # encoder: AutoEncoder,
132
+ # model: HookedTransformer,
133
+ # tokens: Int[Tensor, "batch seq"],
134
+ # cfg: SaeVisConfig,
135
+ # ) -> SaeVisData:
136
+ # """
137
+ # This is the main function which users will run to generate the feature visualization data. It batches this
138
+ # computation over features, in accordance with the arguments in the SaeVisConfig object (we don't want to compute all
139
+ # the features at once, since might give OOMs).
140
+
141
+ # See the `_get_feature_data` function for an explanation of the arguments, as well as a more detailed explanation
142
+ # of what this function is doing.
143
+
144
+ # The return object is the merged SaeVisData objects returned by the `_get_feature_data` function.
145
+ # """
146
+ # pass
147
+
148
+ # # return sae_vis_data
149
+
150
+
151
@torch.inference_mode()
def parse_prompt_data(
    tokens: Int[Tensor, "batch seq"],
    str_toks: list[str],
    sae_vis_data: SaeVisData,
    feat_acts: Float[Tensor, "seq feats"],
    feature_resid_dir: Float[Tensor, "feats d_model"],
    resid_post: Float[Tensor, "seq d_model"],
    W_U: Float[Tensor, "d_model d_vocab"],
    feature_idx: list[int] | None = None,
    num_top_features: int = 10,
) -> dict[str, tuple[list[int], list[str]]]:
    """
    Gets data needed to create the sequences in the prompt-centric vis (displaying dashboards for the most relevant
    features on a prompt).

    This function exists so that prompt dashboards can be generated without using our AutoEncoder or
    TransformerLens(Wrapper) classes.

    Args:
        tokens: Int[Tensor, "batch seq"]
            The tokens we'll be using to get the feature activations. Note that we might not be using all of them; the
            number used is determined by `fvp.total_batch_size`.

        str_toks: list[str]
            The tokens as a list of strings, so that they can be visualized in HTML.

        sae_vis_data: SaeVisData
            The object storing all data for each feature. We'll set each `feature_data.prompt_data` to the
            data we get from `prompt`.

        feat_acts: Float[Tensor, "seq feats"]
            The activations values of the features across the sequence.

        feature_resid_dir: Float[Tensor, "feats d_model"]
            The directions that each feature writes to the residual stream.

        resid_post: Float[Tensor, "seq d_model"]
            The activations of the final layer of the model before the unembed.

        W_U: Float[Tensor, "d_model d_vocab"]
            The model's unembed weights for the logit lens.

        feature_idx: list[int] or None
            The features we're actually computing. These might just be a subset of the model's full features.

        num_top_features: int
            The number of top features to display in this view, for any given metric.

    Returns:
        scores_dict: dict[str, tuple[list[int], list[str]]]
            A dictionary mapping keys like "act_quantile|'django' (0)" to a tuple of lists, where the first list is the
            feature indices, and the second list is the string-formatted values of the scores.

    As well as returning this dictionary, this function will also set `FeatureData.prompt_data` for each feature in
    `sae_vis_data` (this is necessary for getting the prompts in the prompt-centric vis). Note this design choice could
    have been done differently (i.e. have this function return a list of the prompt data for each feature). I chose this
    way because it means the FeatureData._get_html_data_prompt_centric can work fundamentally the same way as
    FeatureData._get_html_data_feature_centric, rather than treating the prompt data object as a different kind of
    component in the vis.
    """

    device = sae_vis_data.cfg.device

    # Default to all features stored in the data dict.
    if feature_idx is None:
        feature_idx = list(sae_vis_data.feature_data_dict.keys())
    n_feats = len(feature_idx)
    assert (
        feature_resid_dir.shape[0] == n_feats
    ), f"The number of features in feature_resid_dir ({feature_resid_dir.shape[0]}) does not match the number of feature indices ({n_feats})"

    assert (
        feat_acts.shape[1] == n_feats
    ), f"The number of features in feat_acts ({feat_acts.shape[1]}) does not match the number of feature indices ({n_feats})"

    # One row per feature, one column per predicted-token position (seq - 1).
    feats_loss_contribution = torch.empty(
        size=(n_feats, tokens.shape[1] - 1), device=device
    )
    # Some logit computations which we only need to do once
    # correct_token_unembeddings = model_wrapped.W_U[:, tokens[0, 1:]] # [d_model seq]
    # NOTE(review): the division by std is an approximation of the final LayerNorm
    # (no mean-subtraction / learned scale visible here) — confirm intended.
    orig_logits = (
        resid_post / resid_post.std(dim=-1, keepdim=True)
    ) @ W_U  # [seq d_vocab]
    raw_logits = feature_resid_dir @ W_U  # [feats d_vocab]

    for i, feat in enumerate(feature_idx):
        # ! Calculate the sequence data for each feature, and store it as FeatureData.prompt_data

        # Get this feature's output vector, using an outer product over the feature activations for all tokens
        resid_post_feature_effect = einops.einsum(
            feat_acts[:, i], feature_resid_dir[i], "seq, d_model -> seq d_model"
        )

        # Ablate the output vector from the residual stream, and get logits post-ablation
        new_resid_post = resid_post - resid_post_feature_effect
        new_logits = (new_resid_post / new_resid_post.std(dim=-1, keepdim=True)) @ W_U

        # Get the top5 & bottom5 changes in logits (don't bother with `efficient_topk` cause it's small)
        contribution_to_logprobs = orig_logits.log_softmax(
            dim=-1
        ) - new_logits.log_softmax(dim=-1)
        top_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5)
        bottom_contribution_to_logits = TopK(
            contribution_to_logprobs[:-1], k=5, largest=False
        )

        # Get the change in loss (which is negative of change of logprobs for correct token)
        loss_contribution = eindex(
            -contribution_to_logprobs[:-1], tokens[0, 1:], "seq [seq]"
        )
        feats_loss_contribution[i, :] = loss_contribution

        # Store the sequence data
        sae_vis_data.feature_data_dict[feat].prompt_data = SequenceData(
            token_ids=tokens.squeeze(0).tolist(),
            feat_acts=[round(f, 4) for f in feat_acts[:, i].tolist()],
            # First token has no loss contribution (nothing was predicted for it).
            loss_contribution=[0.0] + loss_contribution.tolist(),
            token_logits=raw_logits[i, tokens.squeeze(0)].tolist(),
            top_token_ids=top_contribution_to_logits.indices.tolist(),
            top_logits=top_contribution_to_logits.values.tolist(),
            bottom_token_ids=bottom_contribution_to_logits.indices.tolist(),
            bottom_logits=bottom_contribution_to_logits.values.tolist(),
        )

    # ! Lastly, return a dictionary mapping each key like 'act_quantile|"django" (0)' to a list of feature indices & scores

    # Get a dict with keys like f"act_quantile|'My' (1)" and values (feature indices list, feature score values list)
    scores_dict: dict[str, tuple[list[int], list[str]]] = {}

    for seq_pos, seq_key in enumerate([f"{t!r} ({i})" for i, t in enumerate(str_toks)]):
        # Filter the feature activations, since we only need the ones that are non-zero
        feat_acts_nonzero_filter = utils.to_numpy(feat_acts[seq_pos] > 0)
        feat_acts_nonzero_locations = np.nonzero(feat_acts_nonzero_filter)[0].tolist()
        _feat_acts = feat_acts[seq_pos, feat_acts_nonzero_filter]  # [feats_filtered,]
        _feature_idx = np.array(feature_idx)[feat_acts_nonzero_filter]

        if feat_acts_nonzero_filter.sum() > 0:
            k = min(num_top_features, _feat_acts.numel())

            # Get the top features by activation size. This is just applying a TopK function to the feat acts (which
            # were stored by the code before this). The feat acts are formatted to 3dp.
            act_size_topk = TopK(_feat_acts, k=k, largest=True)
            top_features = _feature_idx[act_size_topk.indices].tolist()
            formatted_scores = [f"{v:.3f}" for v in act_size_topk.values]
            scores_dict[f"act_size|{seq_key}"] = (top_features, formatted_scores)

            # Get the top features by activation quantile. We do this using the `feature_act_quantiles` object, which
            # was stored `sae_vis_data`. This quantiles object has a method to return quantiles for a given set of
            # data, as well as the precision (we make the precision higher for quantiles closer to 100%, because these
            # are usually the quantiles we're interested in, and it lets us to save space in `feature_act_quantiles`).
            act_quantile, act_precision = sae_vis_data.feature_stats.get_quantile(
                _feat_acts, feat_acts_nonzero_locations
            )
            act_quantile_topk = TopK(act_quantile, k=k, largest=True)
            act_formatting = [
                f".{act_precision[i]-2}%" for i in act_quantile_topk.indices
            ]
            top_features = _feature_idx[act_quantile_topk.indices].tolist()
            formatted_scores = [
                f"{v:{f}}" for v, f in zip(act_quantile_topk.values, act_formatting)
            ]
            scores_dict[f"act_quantile|{seq_key}"] = (top_features, formatted_scores)

        # We don't measure loss effect on the first token
        if seq_pos == 0:
            continue

        # Filter the loss effects, since we only need the ones which have non-zero feature acts on the tokens before them
        prev_feat_acts_nonzero_filter = utils.to_numpy(feat_acts[seq_pos - 1] > 0)
        _loss_contribution = feats_loss_contribution[
            prev_feat_acts_nonzero_filter, seq_pos - 1
        ]  # [feats_filtered,]
        _feature_idx_prev = np.array(feature_idx)[prev_feat_acts_nonzero_filter]

        if prev_feat_acts_nonzero_filter.sum() > 0:
            k = min(num_top_features, _loss_contribution.numel())

            # Get the top features by loss effect. This is just applying a TopK function to the loss effects (which were
            # stored by the code before this). The loss effects are formatted to 3dp. We look for the most negative
            # values, i.e. the most loss-reducing features.
            loss_contribution_topk = TopK(_loss_contribution, k=k, largest=False)
            top_features = _feature_idx_prev[loss_contribution_topk.indices].tolist()
            formatted_scores = [f"{v:+.3f}" for v in loss_contribution_topk.values]
            scores_dict[f"loss_effect|{seq_key}"] = (top_features, formatted_scores)
    return scores_dict
336
+
337
+
338
@torch.inference_mode()
def get_prompt_data(
    sae_vis_data: SaeVisData,
    prompt: str,
    num_top_features: int,
) -> dict[str, tuple[list[int], list[str]]]:
    """
    Gets data that will be used to create the sequences in the prompt-centric HTML visualisation, i.e. an object of
    type SequenceData for each of our features.

    Args:
        sae_vis_data: The object storing all data for each feature. We'll set each `feature_data.prompt_data` to the
            data we get from `prompt`.
        prompt: The prompt we'll be using to get the feature activations.#
        num_top_features: The number of top features we'll be getting data for.

    Returns:
        scores_dict: A dictionary mapping keys like "act_quantile|0" to a tuple of lists, where the first list is
            the feature indices, and the second list is the string-formatted values of the scores.

    As well as returning this dictionary, this function will also set `FeatureData.prompt_data` for each feature in
    `sae_vis_data`. This is because the prompt-centric vis will call `FeatureData._get_html_data_prompt_centric` on each
    feature data object, so it's useful to have all the data in once place! Even if this will get overwritten next
    time we call `get_prompt_data` for this same `sae_vis_data` object.
    """

    # ! Boring setup code
    feature_idx = list(sae_vis_data.feature_data_dict.keys())
    encoder = sae_vis_data.encoder
    assert isinstance(encoder, SAE)
    model = sae_vis_data.model
    assert isinstance(model, HookedTransformer)
    cfg = sae_vis_data.cfg
    assert isinstance(cfg.hook_point, str), f"{cfg.hook_point=}, expected a string"

    # Tokenize the prompt both as display strings and as a [1, seq] id tensor.
    str_toks: list[str] = model.tokenizer.tokenize(prompt)  # type: ignore
    tokens = model.tokenizer.encode(prompt, return_tensors="pt").to(  # type: ignore
        sae_vis_data.cfg.device
    )
    assert isinstance(tokens, torch.Tensor)

    # Wrapper gives us a forward pass returning (resid_post, hook-point acts).
    model_wrapped = TransformerLensWrapper(model, ActivationConfig(cfg.hook_point, []))  # type: ignore

    feature_act_dir = encoder.W_enc[:, feature_idx]  # [d_in feats]
    feature_out_dir = encoder.W_dec[feature_idx]  # [feats d_in]
    feature_resid_dir = to_resid_direction(
        feature_out_dir, model_wrapped
    )  # [feats d_model]
    # Sanity check: encoder/decoder slices agree on (n_features, d_in).
    assert (
        feature_act_dir.T.shape
        == feature_out_dir.shape
        == (len(feature_idx), encoder.cfg.d_in)
    )

    # ! Define hook functions to cache all the info required for feature ablation, then run those hook fns
    resid_post, act_post = model_wrapped(tokens, return_logits=False)
    resid_post: Tensor = resid_post.squeeze(0)  # type: ignore
    feat_acts = encoder.get_feature_acts_subset(act_post, feature_idx).squeeze(  # type: ignore
        0
    )  # [seq feats] # type: ignore

    # ! Use the data we've collected to make the scores_dict and update the sae_vis_data
    scores_dict = parse_prompt_data(
        tokens=tokens,
        str_toks=str_toks,
        sae_vis_data=sae_vis_data,
        feat_acts=feat_acts,
        feature_resid_dir=feature_resid_dir,
        resid_post=resid_post,
        W_U=model.W_U,
        feature_idx=feature_idx,
        num_top_features=num_top_features,
    )

    return scores_dict
SAEDashboard/sae_dashboard/data_writing_fns.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from copy import deepcopy
3
+ from pathlib import Path
4
+
5
+ from tqdm.auto import tqdm
6
+
7
+ from sae_dashboard.data_parsing_fns import get_prompt_data
8
+ from sae_dashboard.html_fns import HTML
9
+ from sae_dashboard.sae_vis_data import SaeVisData
10
+ from sae_dashboard.utils_fns import get_decode_html_safe_fn
11
+
12
+ METRIC_TITLES = {
13
+ "act_size": "Activation Size",
14
+ "act_quantile": "Activation Quantile",
15
+ "loss_effect": "Loss Effect",
16
+ }
17
+
18
+
19
def save_feature_centric_vis(
    sae_vis_data: SaeVisData,
    filename: str | Path,
    feature_idx: int | None = None,
    include_only: list[int] | None = None,
    separate_files: bool = False,
) -> None:
    """
    Returns the HTML string for the view which lets you navigate between different features.

    Args:
        sae_vis_data: Object containing visualization data.
        filename: The HTML filepath we'll save the visualization to. If separate_files is True, this is used as a base name.
        feature_idx: This is the default feature index we'll start on. If None, we use the first feature.
        include_only: Optional list of specific features to include.
        separate_files: If True, saves each feature to a separate HTML file.
    """
    # Set the default argument for the dropdown (i.e. when the page first loads)
    first_feature = (
        next(iter(sae_vis_data.feature_data_dict))
        if (feature_idx is None)
        else feature_idx
    )

    # Get tokenize function (we only need to define it once)
    assert sae_vis_data.model is not None
    assert sae_vis_data.model.tokenizer is not None
    decode_fn = get_decode_html_safe_fn(sae_vis_data.model.tokenizer)

    # Create iterator over (feature_index, FeatureData) pairs, optionally filtered.
    if include_only is not None:
        iterator = [(i, sae_vis_data.feature_data_dict[i]) for i in include_only]
    else:
        iterator = list(sae_vis_data.feature_data_dict.items())
    if sae_vis_data.cfg.verbose:
        iterator = tqdm(iterator, desc="Saving feature-centric vis")

    HTML_OBJ = HTML()  # Initialize HTML object for combined file

    # For each FeatureData object, we get the html_obj for its feature-centric vis
    for feature, feature_data in iterator:
        html_obj = feature_data._get_html_data_feature_centric(
            sae_vis_data.cfg.feature_centric_layout, decode_fn
        )

        if separate_files:
            feature_HTML_OBJ = HTML()  # Initialize a new HTML object for each feature
            feature_HTML_OBJ.js_data[str(feature)] = deepcopy(html_obj.js_data)
            feature_HTML_OBJ.html_data = deepcopy(html_obj.html_data)

            # Add the aggdata
            feature_HTML_OBJ.js_data = {
                "AGGDATA": sae_vis_data.feature_stats.aggdata,
                "DASHBOARD_DATA": feature_HTML_OBJ.js_data,
            }

            # Generate filename for this feature, e.g. "base_feature_123.html"
            feature_filename = Path(filename).with_stem(
                f"{Path(filename).stem}_feature_{feature}"
            )

            # Save the HTML for this feature
            feature_HTML_OBJ.get_html(
                layout_columns=sae_vis_data.cfg.feature_centric_layout.columns,
                layout_height=sae_vis_data.cfg.feature_centric_layout.height,
                filename=feature_filename,
                first_key=str(feature),
            )
        else:
            # Original behavior: accumulate all features in one HTML object
            HTML_OBJ.js_data[str(feature)] = deepcopy(html_obj.js_data)
            # The page's static HTML skeleton comes from the default feature.
            if feature == first_feature:
                HTML_OBJ.html_data = deepcopy(html_obj.html_data)

    if not separate_files:
        # Add the aggdata
        HTML_OBJ.js_data = {
            "AGGDATA": sae_vis_data.feature_stats.aggdata,
            "DASHBOARD_DATA": HTML_OBJ.js_data,
        }

        # Save our full HTML
        HTML_OBJ.get_html(
            layout_columns=sae_vis_data.cfg.feature_centric_layout.columns,
            layout_height=sae_vis_data.cfg.feature_centric_layout.height,
            filename=filename,
            first_key=str(first_feature),
        )
107
+
108
+
109
def save_prompt_centric_vis(
    sae_vis_data: SaeVisData,
    prompt: str,
    filename: str | Path,
    metric: str | None = None,
    seq_pos: int | None = None,
    num_top_features: int = 10,
):
    """
    Saves the HTML for the prompt-centric view, which lets you navigate between different
    (metric, sequence-position) combinations for a single prompt.

    Args:
        sae_vis_data: The object storing all data for each feature.
        prompt: The user-input prompt.
        filename: The HTML filepath we'll save the visualization to.
        metric: This is the default scoring metric we'll start on. If None, we use 'act_quantile'.
        seq_pos: This is the default seq pos we'll start on. If None, we use 0.
        num_top_features: The number of top features to show per (metric, position) key.

    Raises:
        AssertionError: if no feature has a nontrivial score for the default (metric, seq_pos) key.
    """
    # Initialize the object we'll eventually get_html from
    HTML_OBJ = HTML()

    # Run forward passes on our prompt, and store the data within each FeatureData object as `self.prompt_data` as
    # well as returning the scores_dict (which maps from score hash to a list of feature indices & formatted scores)
    scores_dict = get_prompt_data(
        sae_vis_data=sae_vis_data,
        prompt=prompt,
        num_top_features=num_top_features,
    )

    # Get all possible values for dropdowns
    str_toks = sae_vis_data.model.tokenizer.tokenize(prompt)  # type: ignore
    str_toks = [
        t.replace("|", "│") for t in str_toks
    ]  # vertical line -> pipe (hacky, so key splitting on | works)
    str_toks_list = [f"{t!r} ({i})" for i, t in enumerate(str_toks)]
    metric_list = ["act_quantile", "act_size", "loss_effect"]

    # Get default values for dropdowns.
    # BUGFIX: this was previously `"act_quantile" or metric`, which always evaluates to
    # "act_quantile" and silently ignored the caller's `metric` argument.
    first_metric = metric or "act_quantile"
    first_seq_pos = str_toks_list[0 if seq_pos is None else seq_pos]
    first_key = f"{first_metric}|{first_seq_pos}"

    # Get tokenize function (we only need to define it once)
    assert sae_vis_data.model is not None
    assert sae_vis_data.model.tokenizer is not None
    decode_fn = get_decode_html_safe_fn(sae_vis_data.model.tokenizer)

    # For each (metric, seqpos) object, we merge the prompt-centric views of each of the top features, then we merge
    # these all together into our HTML_OBJ
    for _metric, _seq_pos in itertools.product(metric_list, range(len(str_toks))):
        # Create the key for this given combination of metric & seqpos, and get our top features & scores
        key = f"{_metric}|{str_toks_list[_seq_pos]}"
        if key not in scores_dict:
            continue
        feature_idx_list, scores_formatted = scores_dict[key]

        # Create HTML object, to store each feature column for all the top features for this particular key
        html_obj = HTML()

        for i, (feature_idx, score_formatted) in enumerate(
            zip(feature_idx_list, scores_formatted)
        ):
            # Get HTML object at this column (which includes JavaScript to dynamically set the title)
            html_obj += sae_vis_data.feature_data_dict[
                feature_idx
            ]._get_html_data_prompt_centric(
                layout=sae_vis_data.cfg.prompt_centric_layout,
                decode_fn=decode_fn,
                column_idx=i,
                bold_idx=_seq_pos,
                title=f"<h3>#{feature_idx}<br>{METRIC_TITLES[_metric]} = {score_formatted}</h3><hr>",
            )

        # Add the JavaScript (which includes the titles for each column)
        HTML_OBJ.js_data[key] = deepcopy(html_obj.js_data)

        # Set the HTML data to be the one with the most columns (since different options might have fewer cols)
        if len(HTML_OBJ.html_data) < len(html_obj.html_data):
            HTML_OBJ.html_data = deepcopy(html_obj.html_data)

    # Check our first key is in the scores_dict (if not, we should pick a different key)
    assert first_key in scores_dict, "\n".join(
        [
            f"Key {first_key} not found in {scores_dict.keys()=}.",
            "This means that there are no features with a nontrivial score for this choice of key & metric.",
        ]
    )

    # Add the aggdata
    HTML_OBJ.js_data = {
        "AGGDATA": sae_vis_data.feature_stats.aggdata,
        "DASHBOARD_DATA": HTML_OBJ.js_data,
    }

    # Save our full HTML
    HTML_OBJ.get_html(
        layout_columns=sae_vis_data.cfg.prompt_centric_layout.columns,
        layout_height=sae_vis_data.cfg.prompt_centric_layout.height,
        filename=filename,
        first_key=first_key,
    )
SAEDashboard/sae_dashboard/dfa_calculator.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Union
2
+
3
+ import einops
4
+ import numpy as np
5
+ import torch
6
+ from sae_lens import SAE
7
+ from transformer_lens import ActivationCache, HookedTransformer
8
+
9
+
10
class DFACalculator:
    """Calculate DFA values for a given layer and set of feature indices."""

    def __init__(self, model: HookedTransformer, sae: SAE[Any]):
        self.model = model
        self.sae = sae
        # Detect grouped-query attention: strictly fewer K/V heads than query heads.
        # GQA needs the memory-chunked code path in `calculate`.
        if (
            hasattr(model.cfg, "n_key_value_heads")
            and model.cfg.n_key_value_heads is not None
            and model.cfg.n_key_value_heads < model.cfg.n_heads
        ):
            print("Using GQA")
            self.use_gqa = True
        else:
            self.use_gqa = False

    def calculate(
        self,
        activations: Union[Dict[str, torch.Tensor], ActivationCache],
        layer_num: int,
        feature_indices: List[int],
        max_value_indices: torch.Tensor,
    ) -> Dict[int, Any]:  # type: ignore
        """Calculate DFA values for a given layer and set of feature indices.

        Args:
            activations: cache containing `blocks.{layer_num}.attn.hook_v` and
                `blocks.{layer_num}.attn.hook_pattern` for the target layer.
            layer_num: index of the attention layer to attribute.
            feature_indices: SAE feature indices to compute DFA for.
            max_value_indices: per (prompt, feature) destination position at which
                the feature activation is maximal
                # assumes shape (n_prompts, n_features) — TODO confirm against caller

        Returns:
            Dict mapping each feature index to a structured numpy array of shape
            (n_prompts,) with fields `dfa_values` (per-source-position values),
            `dfa_target_index`, and `dfa_max_value`. Empty dict if no features.
        """
        if not feature_indices:
            return {}

        v = activations[f"blocks.{layer_num}.attn.hook_v"]
        attn_weights = activations[f"blocks.{layer_num}.attn.hook_pattern"]

        # per_src_pos_dfa: [batch, dest_pos, src_pos, num_features]
        if self.use_gqa:
            per_src_pos_dfa = self.calculate_gqa_intermediate_tensor(
                attn_weights, v, feature_indices
            )
        else:
            per_src_pos_dfa = self.calculate_standard_intermediate_tensor(
                attn_weights, v, feature_indices
            )

        n_prompts, seq_len, _, n_features = per_src_pos_dfa.shape

        # Use advanced indexing to get per_src_dfa: for each (prompt, feature), select
        # the dest_pos given by max_value_indices, keeping all source positions.
        prompt_indices = torch.arange(n_prompts)[:, None, None]
        src_pos_indices = torch.arange(seq_len)[None, :, None]
        feature_indices_tensor = torch.arange(n_features)[None, None, :]
        max_value_indices_expanded = max_value_indices[:, None, :]

        per_src_dfa = per_src_pos_dfa[
            prompt_indices,
            max_value_indices_expanded,
            src_pos_indices,
            feature_indices_tensor,
        ]

        # Max attribution over source positions, per (prompt, feature).
        max_values, _ = per_src_dfa.max(dim=1)

        # Create a structured numpy array to hold all the data
        dtype = np.dtype(
            [
                ("dfa_values", np.float32, (seq_len,)),
                ("dfa_target_index", np.int32),
                ("dfa_max_value", np.float32),
            ]
        )
        results = np.zeros((len(feature_indices), n_prompts), dtype=dtype)

        # Fill the numpy array with data (transpose to feature-major layout)
        results["dfa_values"] = per_src_dfa.detach().cpu().numpy().transpose(2, 0, 1)
        results["dfa_target_index"] = max_value_indices.detach().cpu().numpy().T
        results["dfa_max_value"] = max_values.detach().cpu().numpy().T

        # Create a dictionary mapping feature indices to their respective data
        final_results = {
            feat_idx: results[i] for i, feat_idx in enumerate(feature_indices)
        }

        return final_results

    def calculate_standard_intermediate_tensor(
        self,
        attn_weights: torch.Tensor,
        v: torch.Tensor,
        feature_indices: List[int],
    ) -> torch.Tensor:
        """Build per-source-position DFA for standard multi-head attention.

        Returns a tensor of shape [batch, dest_pos, src_pos, num_features]:
        the contribution of each source position (through each destination
        position's attention) to each selected SAE feature's pre-activation.
        """
        # Concatenate heads: [batch, src_pos, n_heads * d_head]
        v_cat = einops.rearrange(
            v, "batch src_pos n_heads d_head -> batch src_pos (n_heads d_head)"
        )

        # Broadcast each head's attention weight across that head's d_head slots.
        attn_weights_bcast = einops.repeat(
            attn_weights,
            "batch n_heads dest_pos src_pos -> batch dest_pos src_pos (n_heads d_head)",
            d_head=self.model.cfg.d_head,
        )

        # Decompose z (pre-W_O attention output) by source position.
        decomposed_z_cat = attn_weights_bcast * v_cat.unsqueeze(1)

        W_enc_selected = self.sae.W_enc[:, feature_indices]  # [d_model, num_indices]

        # Project each source-position contribution onto the selected encoder directions.
        per_src_pos_dfa = einops.einsum(
            decomposed_z_cat,
            W_enc_selected,
            "batch dest_pos src_pos d_model, d_model num_features -> batch dest_pos src_pos num_features",
        )

        return per_src_pos_dfa

    def calculate_gqa_intermediate_tensor(
        self, attn_weights: torch.Tensor, v: torch.Tensor, feature_indices: List[int]
    ) -> torch.Tensor:
        """Build per-source-position DFA for grouped-query attention.

        Same output shape/semantics as the standard path, but V is first
        expanded so each K/V head is shared by its group of query heads, and
        the dest_pos axis is processed in chunks to bound peak memory.
        """
        n_query_heads = attn_weights.shape[1]
        n_kv_heads = v.shape[2]
        expansion_factor = n_query_heads // n_kv_heads
        # Duplicate each K/V head for every query head in its group.
        v = v.repeat_interleave(expansion_factor, dim=2)

        v_cat = einops.rearrange(
            v, "batch src_pos n_heads d_head -> batch src_pos (n_heads d_head)"
        )

        W_enc_selected = self.sae.W_enc[:, feature_indices]  # [d_model, num_indices]

        # Initialize the result tensor
        n_prompts, seq_len, _ = v_cat.shape
        n_features = len(feature_indices)
        per_src_pos_dfa = torch.zeros(
            (n_prompts, seq_len, seq_len, n_features), device=v_cat.device
        )

        # Process in chunks along dest_pos to avoid materialising the full
        # [batch, dest, src, d_model] broadcast at once.
        chunk_size = 16  # Adjust this based on your memory constraints
        for i in range(0, seq_len, chunk_size):
            chunk_end = min(i + chunk_size, seq_len)

            # Process a chunk of destination positions
            attn_weights_chunk = attn_weights[:, :, i:chunk_end, :]
            attn_weights_bcast_chunk = einops.repeat(
                attn_weights_chunk,
                "batch n_heads dest_pos src_pos -> batch dest_pos src_pos (n_heads d_head)",
                d_head=self.model.cfg.d_head,
            )
            decomposed_z_cat_chunk = attn_weights_bcast_chunk * v_cat.unsqueeze(1)

            per_src_pos_dfa_chunk = einops.einsum(
                decomposed_z_cat_chunk,
                W_enc_selected,
                "batch dest_pos src_pos d_model, d_model num_features -> batch dest_pos src_pos num_features",
            )

            per_src_pos_dfa[:, i:chunk_end, :, :] = per_src_pos_dfa_chunk

        return per_src_pos_dfa
+ return per_src_pos_dfa
SAEDashboard/sae_dashboard/feature_data.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Callable, List, Literal, Optional
3
+
4
+ from sae_dashboard.components import (
5
+ ActsHistogramData,
6
+ DecoderWeightsDistribution,
7
+ FeatureTablesData,
8
+ GenericData,
9
+ LogitsHistogramData,
10
+ LogitsTableData,
11
+ SequenceData,
12
+ SequenceMultiGroupData,
13
+ )
14
+ from sae_dashboard.components_config import (
15
+ ActsHistogramConfig,
16
+ FeatureTablesConfig,
17
+ GenericComponentConfig,
18
+ LogitsHistogramConfig,
19
+ LogitsTableConfig,
20
+ PromptConfig,
21
+ SequencesConfig,
22
+ )
23
+ from sae_dashboard.html_fns import HTML
24
+ from sae_dashboard.layout import SaeVisLayoutConfig
25
+
26
+
27
@dataclass
class DFAData:
    """Direct feature attribution (DFA) results for a single feature.

    Field names are deliberately camelCase: they mirror the keys of the
    per-prompt dicts built by the feature data generator
    ("dfaValues" / "dfaTargetIndex" / "dfaMaxValue") — do not rename.
    """

    # Per-prompt lists of DFA values (one float per source position).
    dfaValues: List[List[float]] = field(default_factory=list)
    # Per-prompt index of the target (destination) token the DFA is taken at.
    dfaTargetIndex: List[int] = field(default_factory=list)
    # Maximum DFA value observed (used for color scaling).
    dfaMaxValue: float = 0.0
32
+
33
+
34
@dataclass
class FeatureData:
    """
    This contains all the data necessary to make the feature-centric visualization, for a single feature. See
    diagram in readme:

    https://github.com/callummcdougall/sae_vis#data_storing_fnspy

    Args:
        feature_idx: Index of the feature in question (not used within this class's methods, but used elsewhere).
        cfg: Contains layout parameters which are important in the `get_html` function.

    The other args are the 6 possible components we might have in the feature-centric vis, i.e. this is where we
    store the actual data. Note that one of these arguments is `prompt_data` which is only applicable in the prompt-
    centric view.

    This is used in both the feature-centric and prompt-centric views. In the feature-centric view, a single one
    of these objects creates the HTML for a single feature (i.e. a full screen). In the prompt-centric view, a single
    one of these objects will create one column of the full screen vis.
    """

    # Each component defaults to an empty instance so a FeatureData can be
    # built incrementally, one component at a time.
    feature_tables_data: FeatureTablesData = field(
        default_factory=lambda: FeatureTablesData()
    )
    acts_histogram_data: ActsHistogramData = field(
        default_factory=lambda: ActsHistogramData()
    )
    logits_table_data: LogitsTableData = field(
        default_factory=lambda: LogitsTableData()
    )
    logits_histogram_data: LogitsHistogramData = field(
        default_factory=lambda: LogitsHistogramData()
    )
    sequence_data: SequenceMultiGroupData = field(
        default_factory=lambda: SequenceMultiGroupData()
    )
    prompt_data: SequenceData = field(default_factory=lambda: SequenceData())
    # Per-prompt DFA results keyed by prompt index; normalised to {} in
    # __post_init__ so callers can index it without None checks.
    dfa_data: Optional[dict[int, dict[str, Any]]] = None
    decoder_weights_data: Optional[DecoderWeightsDistribution] = None

    def __post_init__(self) -> None:
        # Replace the None sentinel with an empty dict (a mutable default
        # cannot be used directly as a dataclass field default).
        if self.dfa_data is None:
            self.dfa_data = {}

    def get_component_from_config(self, config: GenericComponentConfig) -> GenericData:
        """
        Given a config object, returns the corresponding data object stored by this instance. For instance, if the input
        is an `FeatureTablesConfig` instance, then this function returns `self.feature_tables_data`.
        """
        # Dispatch on the config's class name; rebuilt per call so it always
        # reflects this instance's current component objects.
        CONFIG_CLASS_MAP = {
            FeatureTablesConfig.__name__: self.feature_tables_data,
            ActsHistogramConfig.__name__: self.acts_histogram_data,
            LogitsTableConfig.__name__: self.logits_table_data,
            LogitsHistogramConfig.__name__: self.logits_histogram_data,
            SequencesConfig.__name__: self.sequence_data,
            PromptConfig.__name__: self.prompt_data,
            # Add DFA config here if we create a specific config for it
        }
        config_class_name = config.__class__.__name__
        assert (
            config_class_name in CONFIG_CLASS_MAP
        ), f"Invalid component config: {config_class_name}"
        return CONFIG_CLASS_MAP[config_class_name]

    def _get_html_data_feature_centric(
        self,
        layout: SaeVisLayoutConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
    ) -> HTML:
        """
        Returns the HTML object for a single feature-centric view. These are assembled together into the full feature-
        centric view.

        Args:
            decode_fn: We use this function to decode the token IDs into string tokens.

        Returns:
            html_obj.html_data:
                Contains a dictionary with keys equal to columns, and values equal to the HTML strings. These will be
                turned into grid-column elements, and concatenated.
            html_obj.js_data:
                Contains a dictionary with keys = component names, and values = JavaScript data that will be used by the
                scripts we'll eventually dump in.
        """
        # Create object to store all HTML
        html_obj = HTML()

        # For every column in this feature-centric layout, we add all the components in that column
        for column_idx, column_components in layout.columns.items():
            for component_config in column_components:
                component = self.get_component_from_config(component_config)

                # HTML.__iadd__ merges the component's html/js into html_obj.
                html_obj += component._get_html_data(
                    cfg=component_config,
                    decode_fn=decode_fn,
                    column=column_idx,
                    id_suffix="0",  # we only use this if we have >1 set of histograms, i.e. prompt-centric vis
                )

        return html_obj

    def _get_html_data_prompt_centric(
        self,
        layout: SaeVisLayoutConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        column_idx: int,
        bold_idx: int | Literal["max"],
        title: str,
    ) -> HTML:
        """
        Returns the HTML object for a single column of the prompt-centric view. These are assembled together into a full
        screen of a prompt-centric view, and then they're further assembled together into the full prompt-centric view.

        Args:
            decode_fn: We use this function to decode the token IDs into string tokens.
            column_idx: This method only gives us a single column (of the prompt-centric vis), so we need to know which
                column this is (for the JavaScript data).
            bold_idx: Which index should be bolded in the sequence data. If "max", we default to bolding the max-act
                token in each sequence.
            title: The title for this column, which will be used in the JavaScript data.

        Returns:
            html_obj.html_data:
                Contains a dictionary with the single key `str(column_idx)`, representing the single column. This will
                become a single grid-column element, and will get concatenated with others of these.
            html_obj.js_data:
                Contains a dictionary with keys = component names, and values = JavaScript data that will be used by the
                scripts we'll eventually dump in.
        """
        # Create object to store all HTML
        html_obj = HTML()

        # Verify that we only have a single column
        assert layout.columns.keys() == {
            0
        }, f"prompt_centric_layout should only have 1 column, instead found cols {layout.columns.keys()}"
        assert (
            layout.prompt_cfg is not None
        ), "prompt_centric_cfg should include a PromptConfig, but found None"
        if layout.seq_cfg is not None:
            assert (layout.seq_cfg.n_quantiles == 0) or (
                layout.seq_cfg.stack_mode == "stack-all"
            ), "prompt_centric_layout should have stack_mode='stack-all' if n_quantiles > 0, so that it fits in 1 col"

        # Get the maximum color over both the prompt and the sequences
        # NOTE(review): max() over prompt_data.feat_acts raises on an empty
        # prompt — presumably prompt_data is always populated here; confirm.
        max_feat_act = max(
            max(self.prompt_data.feat_acts), self.sequence_data.max_feat_act
        )
        max_loss_contribution = max(
            max(self.prompt_data.loss_contribution),
            self.sequence_data.max_loss_contribution,
        )

        # For every component in the single column of this prompt-centric layout, add all the components in that column
        for component_config in layout.columns[0]:
            component = self.get_component_from_config(component_config)

            html_obj += component._get_html_data(
                cfg=component_config,
                decode_fn=decode_fn,
                column=column_idx,
                id_suffix=str(column_idx),
                component_specific_kwargs=dict(  # only used by SequenceData (the prompt)
                    bold_idx=bold_idx,
                    permanent_line=True,
                    hover_above=True,
                    max_feat_act=max_feat_act,
                    max_loss_contribution=max_loss_contribution,
                ),
            )

        # Add the title in JavaScript, and the empty title element in HTML
        html_obj.html_data[column_idx] = (
            f"<div id='column-{column_idx}-title'></div>\n{html_obj.html_data[column_idx]}"
        )
        html_obj.js_data["gridColumnTitlesData"] = {str(column_idx): title}

        return html_obj
SAEDashboard/sae_dashboard/feature_data_generator.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Any, Dict, List
3
+
4
+ import einops
5
+ import numpy as np
6
+ import torch
7
+ from jaxtyping import Float, Int
8
+ from sae_lens import SAE, HookedSAETransformer
9
+ from sae_lens.config import DTYPE_MAP as DTYPES
10
+ from sae_lens.saes.topk_sae import TopK
11
+ from torch import Tensor, nn
12
+ from tqdm.auto import tqdm
13
+
14
+ from sae_dashboard.dfa_calculator import DFACalculator
15
+ from sae_dashboard.sae_vis_data import SaeVisConfig
16
+ from sae_dashboard.transformer_lens_wrapper import to_resid_direction
17
+ from sae_dashboard.utils_fns import RollingCorrCoef
18
+
19
+ Arr = np.ndarray
20
+
21
+
22
class FeatureDataGenerator:
    """Computes per-feature data for the dashboard from a batch of tokens.

    Splits the tokens into minibatches, runs the wrapped model forward over
    each one (optionally caching the activation dict on disk), encodes the
    primary hook point's activations with the SAE, and accumulates feature
    activations, rolling correlation statistics, and (optionally) DFA results.
    """

    def __init__(
        self,
        cfg: SaeVisConfig,
        tokens: Int[Tensor, "batch seq"],
        model: HookedSAETransformer,
        encoder: SAE[Any],
    ):
        self.cfg = cfg
        self.model = model
        self.encoder = encoder
        self.token_minibatches = self.batch_tokens(tokens)
        # DFA (direct feature attribution) is optional; the calculator needs
        # the underlying transformer, not the wrapper.
        self.dfa_calculator = (
            DFACalculator(model.model, encoder) if cfg.use_dfa else None  # type: ignore
        )

        if cfg.use_dfa:
            # DFA decomposes attention z-vectors, so it is only meaningful for
            # SAEs trained on a hook_z hook point.
            assert (
                "hook_z" in encoder.cfg.hook_name
            ), f"DFAs are only supported for hook_z, but got {encoder.cfg.hook_name}"

    @torch.inference_mode()
    def batch_tokens(
        self, tokens: Int[Tensor, "batch seq"]
    ) -> list[Int[Tensor, "batch seq"]]:
        """Split `tokens` into minibatches (moved to cfg.device) for the fwd pass."""
        token_minibatches = (
            (tokens,)
            if self.cfg.minibatch_size_tokens is None
            else tokens.split(self.cfg.minibatch_size_tokens)
        )
        token_minibatches = [tok.to(self.cfg.device) for tok in token_minibatches]

        return token_minibatches

    @torch.inference_mode()
    def get_feature_data(  # type: ignore
        self,
        feature_indices: list[int],
        progress: list[tqdm] | None = None,  # type: ignore
    ):  # type: ignore
        """Run all minibatches and gather data for `feature_indices`.

        Returns a 7-tuple:
            (all_feat_acts, empty tensor [resid_post is no longer used],
             feature_resid_dir, feature_out_dir, corrcoef_neurons,
             corrcoef_encoder, all_dfa_results)
        """
        # Create lists to store the feature activations & final values of the residual stream
        all_feat_acts = []
        all_dfa_results = {feature_idx: {} for feature_idx in feature_indices}
        total_prompts = 0

        # Create objects to store the data for computing rolling stats
        corrcoef_neurons = RollingCorrCoef()
        corrcoef_encoder = RollingCorrCoef(indices=feature_indices, with_self=True)

        # Get encoder & decoder directions
        feature_out_dir = self.encoder.W_dec[feature_indices]  # [feats d_autoencoder]
        feature_resid_dir = to_resid_direction(
            feature_out_dir, self.model  # type: ignore
        )  # [feats d_model]

        # ! Compute & concatenate together all feature activations & post-activation function values
        for i, minibatch in enumerate(self.token_minibatches):
            # FIX: Tensor.to is out-of-place — the original code discarded the
            # result, silently relying on batch_tokens having already moved
            # the minibatches. Keep the returned tensor explicitly.
            minibatch = minibatch.to(self.cfg.device)
            model_activation_dict = self.get_model_acts(i, minibatch)
            primary_acts = model_activation_dict[
                self.model.activation_config.primary_hook_point  # type: ignore
            ].to(
                self.encoder.device
            )  # make sure acts are on the correct device

            # For TopK, compute all activations first, then select features:
            # TopK competes across the full feature dim, so masking the
            # encoder first would change which features win.
            if isinstance(self.encoder.activation_fn, TopK):
                # Get all features' activations
                all_features_acts = self.encoder.encode(primary_acts)
                # Then select only the features we're interested in
                feature_acts = all_features_acts[:, :, feature_indices].to(
                    DTYPES[self.cfg.dtype]
                )
            else:
                # For other activation functions, use the masking context
                with FeatureMaskingContext(self.encoder, feature_indices):
                    feature_acts = self.encoder.encode(primary_acts).to(
                        DTYPES[self.cfg.dtype]
                    )

            self.update_rolling_coefficients(
                model_acts=primary_acts,
                feature_acts=feature_acts,
                corrcoef_neurons=corrcoef_neurons,
                corrcoef_encoder=corrcoef_encoder,
            )

            # Add these to the lists (we'll eventually concat)
            all_feat_acts.append(feature_acts)

            # Calculate DFA, keyed by feature then by global prompt index.
            if self.cfg.use_dfa and self.dfa_calculator:
                max_value_indices = torch.argmax(feature_acts, dim=1)
                batch_dfa_results = self.dfa_calculator.calculate(
                    model_activation_dict,
                    self.model.hook_layer,  # type: ignore
                    feature_indices,
                    max_value_indices,
                )
                for feature_idx, feature_data in batch_dfa_results.items():
                    for prompt_idx in range(feature_data.shape[0]):
                        global_prompt_idx = total_prompts + prompt_idx
                        all_dfa_results[feature_idx][global_prompt_idx] = {
                            "dfaValues": feature_data[prompt_idx][
                                "dfa_values"
                            ].tolist(),
                            "dfaTargetIndex": int(
                                feature_data[prompt_idx]["dfa_target_index"]
                            ),
                            "dfaMaxValue": float(
                                feature_data[prompt_idx]["dfa_max_value"]
                            ),
                        }

            total_prompts += len(minibatch)

            # Update the 1st progress bar (fwd passes & getting sequence data dominates the runtime of these computations)
            if progress is not None:
                progress[0].update(1)

        all_feat_acts = torch.cat(all_feat_acts, dim=0)

        return (
            all_feat_acts,
            torch.tensor([]),  # all_resid_post, no longer used
            feature_resid_dir,
            feature_out_dir,
            corrcoef_neurons,
            corrcoef_encoder,
            all_dfa_results,
        )

    @torch.inference_mode()
    def get_model_acts(
        self,
        minibatch_index: int,
        minibatch_tokens: torch.Tensor,
        use_cache: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        Gets the model activations for a given minibatch of tokens.

        When `cfg.cache_dir` is set, the activation dict is cached on disk via
        torch.save / torch.load (keyed by minibatch index) and reused on later
        calls if `use_cache` is True. (The previous docstring claimed
        np.memmap was used; the implementation has always used torch.save.)
        """
        if self.cfg.cache_dir is not None:
            cache_path = self.cfg.cache_dir / f"model_activations_{minibatch_index}.pt"
            if use_cache and cache_path.exists():
                activation_dict = load_tensor_dict_torch(cache_path, self.cfg.device)
            else:
                # NOTE(review): tokens are moved to CPU before the forward
                # call; presumably the model wrapper handles device placement
                # internally — confirm.
                activation_dict = self.model.forward(
                    minibatch_tokens.to("cpu"), return_logits=False  # type: ignore
                )
                save_tensor_dict_torch(activation_dict, cache_path)
        else:
            activation_dict = self.model.forward(
                minibatch_tokens.to("cpu"), return_logits=False  # type: ignore
            )

        return activation_dict

    @torch.inference_mode()
    def update_rolling_coefficients(
        self,
        model_acts: Float[Tensor, "batch seq d_in"],
        feature_acts: Float[Tensor, "batch seq feats"],
        corrcoef_neurons: RollingCorrCoef | None,
        corrcoef_encoder: RollingCorrCoef | None,
    ) -> None:
        """Fold one minibatch of activations into the rolling correlation stats.

        Args:
            model_acts: Float[Tensor, "batch seq d_in"]
                The activations of the model, which the SAE was trained on.
            feature_acts: Float[Tensor, "batch seq feats"]
                The SAE feature activations for the selected features.
            corrcoef_neurons: Optional[RollingCorrCoef]
                The object storing the minimal data necessary to compute corrcoef between feature activations & neurons.
            corrcoef_encoder: Optional[RollingCorrCoef]
                The object storing the minimal data necessary to compute corrcoef between pairwise feature activations.
        """
        # Update the CorrCoef object between feature activation & neurons
        if corrcoef_neurons is not None:
            corrcoef_neurons.update(
                einops.rearrange(feature_acts, "batch seq feats -> feats (batch seq)"),
                einops.rearrange(model_acts, "batch seq d_in -> d_in (batch seq)"),
            )

        # Update the CorrCoef object between pairwise feature activations
        if corrcoef_encoder is not None:
            corrcoef_encoder.update(
                einops.rearrange(feature_acts, "batch seq feats -> feats (batch seq)"),
                einops.rearrange(feature_acts, "batch seq feats -> feats (batch seq)"),
            )
215
+
216
+
217
def save_tensor_dict_torch(tensor_dict: Dict[str, torch.Tensor], filename: Path):
    """Serialize a name->tensor mapping to `filename` using torch's pickler."""
    torch.save(tensor_dict, filename)
219
+
220
+
221
def load_tensor_dict_torch(filename: Path, device: str) -> Dict[str, torch.Tensor]:
    """Deserialize a name->tensor mapping, mapping tensors straight onto `device`."""
    target_device = torch.device(device)
    return torch.load(filename, map_location=target_device)
225
+
226
+
227
class FeatureMaskingContext:
    """Context manager that temporarily restricts an SAE to a subset of features.

    On entry, every feature-indexed parameter (decoder rows, encoder columns,
    and the architecture-specific bias/threshold vectors) is replaced by a
    masked copy containing only `feature_idxs`; on exit the original
    parameters are restored. This lets callers run `sae.encode` for a small
    batch of features without paying for the full dictionary.
    """

    def __init__(self, sae: "SAE[Any]", feature_idxs: List[int]):
        self.sae = sae
        self.feature_idxs = feature_idxs
        # Full (pre-masking) parameter data, keyed by attribute name.
        self.original_weight = {}

    def _mask_param(self, name: str, masked_weight: torch.Tensor) -> None:
        """Save a clone of `name`'s full data, then install the masked slice."""
        self.original_weight[name] = getattr(self.sae, name).data.clone()
        setattr(self.sae, name, nn.Parameter(masked_weight))

    def __enter__(self):
        # Decoder: keep only the selected feature rows.
        self._mask_param("W_dec", self.sae.W_dec[self.feature_idxs])
        # Encoder: keep only the selected feature columns.
        self._mask_param("W_enc", self.sae.W_enc[:, self.feature_idxs])

        # Handle architecture as either attribute or method.
        architecture = self.sae.cfg.architecture
        if callable(architecture):
            architecture = architecture()

        # Each architecture stores its feature-indexed 1-D parameters under
        # different names; mask whichever set applies.
        if architecture in [
            "standard",
            "standard_transcoder",
            "transcoder",
            "skip_transcoder",
        ]:
            self._mask_param("b_enc", self.sae.b_enc[self.feature_idxs])  # type: ignore
        elif architecture in ["jumprelu", "jumprelu_transcoder"]:
            self._mask_param("b_enc", self.sae.b_enc[self.feature_idxs])  # type: ignore
            self._mask_param("threshold", self.sae.threshold[self.feature_idxs])  # type: ignore
        elif architecture in ["gated", "gated_transcoder"]:
            self._mask_param("b_gate", self.sae.b_gate[self.feature_idxs])  # type: ignore
            self._mask_param("r_mag", self.sae.r_mag[self.feature_idxs])  # type: ignore
            self._mask_param("b_mag", self.sae.b_mag[self.feature_idxs])  # type: ignore
        else:
            raise ValueError(f"Invalid architecture: {architecture}")

        return self

    def __exit__(self, exc_type, exc_value, traceback):  # type: ignore
        # Restore every saved parameter, whether or not an exception occurred.
        for key, value in self.original_weight.items():
            setattr(self.sae, key, nn.Parameter(value))
SAEDashboard/sae_dashboard/html/acts_histogram_template.html ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ <!-- Activation densities histogram -->
2
+ <div class="plotly-hist" id="HISTOGRAM_ACTS_ID" style="height: 150px; margin-top: 0px;"></div>
SAEDashboard/sae_dashboard/html/feature_tables_template.html ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ <!-- Feature Info Tables -->
2
+ <div class="feature-tables" id="FEATURE_TABLES_ID"></div>
SAEDashboard/sae_dashboard/html/logits_histogram_template.html ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ <!-- Logits histogram -->
2
+ <div class="plotly-hist" id="HISTOGRAM_LOGITS_ID" style="height: 150px; margin-top: 0px;"></div>
SAEDashboard/sae_dashboard/html/logits_table_template.html ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ <!-- Logits table -->
2
+ <div class="logits-table" id="LOGITS_TABLE_ID"></div>
SAEDashboard/sae_dashboard/html/sequences_group_template.html ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ <!-- Sequence group -->
2
+ <div class="seq-group" id="SEQUENCE_GROUP_ID"></div>