Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- SAEDashboard/.dockerignore +184 -0
- SAEDashboard/.flake8 +8 -0
- SAEDashboard/.github/workflows/ci.yaml +117 -0
- SAEDashboard/.gitignore +204 -0
- SAEDashboard/.vscode/settings.json +18 -0
- SAEDashboard/CHANGELOG.md +1263 -0
- SAEDashboard/Dockerfile +45 -0
- SAEDashboard/LICENSE +21 -0
- SAEDashboard/Makefile +27 -0
- SAEDashboard/README.md +221 -0
- SAEDashboard/docker/docker-entrypoint.sh +11 -0
- SAEDashboard/docker/docker-hub.yaml +57 -0
- SAEDashboard/neuronpedia_vector_pipeline_demo.ipynb +282 -0
- SAEDashboard/notebooks/experiment_gemma_2_9b_dashboard_generation_np.py +52 -0
- SAEDashboard/notebooks/sae_dashboard_demo_gemma_2_9b.ipynb +618 -0
- SAEDashboard/pyproject.toml +70 -0
- SAEDashboard/sae_dashboard/__init__.py +10 -0
- SAEDashboard/sae_dashboard/__pycache__/__init__.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/components.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/components_config.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/data_parsing_fns.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/data_writing_fns.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/dfa_calculator.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/feature_data.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/feature_data_generator.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/html_fns.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/layout.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/sae_vis_data.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/sae_vis_runner.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/sequence_data_generator.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/transformer_lens_wrapper.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/utils_fns.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/__pycache__/vector_vis_data.cpython-313.pyc +0 -0
- SAEDashboard/sae_dashboard/clt_layer_wrapper.py +697 -0
- SAEDashboard/sae_dashboard/components.py +774 -0
- SAEDashboard/sae_dashboard/components_config.py +206 -0
- SAEDashboard/sae_dashboard/css/dropdown.css +40 -0
- SAEDashboard/sae_dashboard/css/general.css +53 -0
- SAEDashboard/sae_dashboard/css/sequences.css +61 -0
- SAEDashboard/sae_dashboard/css/tables.css +81 -0
- SAEDashboard/sae_dashboard/data_parsing_fns.py +412 -0
- SAEDashboard/sae_dashboard/data_writing_fns.py +210 -0
- SAEDashboard/sae_dashboard/dfa_calculator.py +159 -0
- SAEDashboard/sae_dashboard/feature_data.py +211 -0
- SAEDashboard/sae_dashboard/feature_data_generator.py +313 -0
- SAEDashboard/sae_dashboard/html/acts_histogram_template.html +2 -0
- SAEDashboard/sae_dashboard/html/feature_tables_template.html +2 -0
- SAEDashboard/sae_dashboard/html/logits_histogram_template.html +2 -0
- SAEDashboard/sae_dashboard/html/logits_table_template.html +2 -0
- SAEDashboard/sae_dashboard/html/sequences_group_template.html +2 -0
SAEDashboard/.dockerignore
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# ruff
|
| 147 |
+
.ruff_cache
|
| 148 |
+
|
| 149 |
+
# Pyre type checker
|
| 150 |
+
.pyre/
|
| 151 |
+
|
| 152 |
+
# pytype static type analyzer
|
| 153 |
+
.pytype/
|
| 154 |
+
|
| 155 |
+
# Cython debug symbols
|
| 156 |
+
cython_debug/
|
| 157 |
+
|
| 158 |
+
# PyCharm
|
| 159 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 160 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 161 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 162 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 163 |
+
#.idea/
|
| 164 |
+
|
| 165 |
+
*.pkl
|
| 166 |
+
*.pt
|
| 167 |
+
sae_vis/archive_fns.py
|
| 168 |
+
*__pycache__
|
| 169 |
+
mats_sae_training
|
| 170 |
+
callum_instructions.md
|
| 171 |
+
april-fools
|
| 172 |
+
*large.html
|
| 173 |
+
requirements.txt
|
| 174 |
+
tests/fixtures/cache_benchmark/
|
| 175 |
+
tests/fixtures/cache_unit/
|
| 176 |
+
|
| 177 |
+
neuronpedia_outputs/
|
| 178 |
+
cached_activations/
|
| 179 |
+
wandb/
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
**.safetensors
|
| 183 |
+
**flamegraph.html
|
| 184 |
+
artifacts/
|
SAEDashboard/.flake8
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[flake8]
|
| 2 |
+
extend-ignore = E203, E266, E501, W503, E721, F722, E731, E402, F821
|
| 3 |
+
max-line-length = 79
|
| 4 |
+
max-complexity = 25
|
| 5 |
+
extend-select = E9, F63, F7, F82
|
| 6 |
+
show-source = true
|
| 7 |
+
statistics = true
|
| 8 |
+
exclude = ./wandb/*, ./research/wandb/*, .venv/*
|
SAEDashboard/.github/workflows/ci.yaml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: "ci"
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
pull_request:
|
| 5 |
+
branches: ["**"]
|
| 6 |
+
push:
|
| 7 |
+
branches: ["**"]
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
build:
|
| 11 |
+
runs-on: ubuntu-latest
|
| 12 |
+
strategy:
|
| 13 |
+
matrix:
|
| 14 |
+
python-version: ["3.10", "3.11", "3.12"]
|
| 15 |
+
steps:
|
| 16 |
+
- uses: actions/checkout@v4
|
| 17 |
+
|
| 18 |
+
- name: Set up Python ${{ matrix.python-version }}
|
| 19 |
+
uses: actions/setup-python@v5
|
| 20 |
+
with:
|
| 21 |
+
python-version: ${{ matrix.python-version }}
|
| 22 |
+
|
| 23 |
+
- name: Cache Huggingface assets
|
| 24 |
+
uses: actions/cache@v4
|
| 25 |
+
with:
|
| 26 |
+
key: huggingface-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
|
| 27 |
+
path: ~/.cache/huggingface
|
| 28 |
+
restore-keys: |
|
| 29 |
+
huggingface-${{ runner.os }}-${{ matrix.python-version }}-
|
| 30 |
+
|
| 31 |
+
- name: Load cached Poetry installation
|
| 32 |
+
id: cached-poetry
|
| 33 |
+
uses: actions/cache@v4
|
| 34 |
+
with:
|
| 35 |
+
path: ~/.local
|
| 36 |
+
key: poetry-${{ runner.os }}-${{ matrix.python-version }}-1 # Incremented to reset cache
|
| 37 |
+
|
| 38 |
+
- name: Install Poetry
|
| 39 |
+
if: steps.cached-poetry.outputs.cache-hit != 'true'
|
| 40 |
+
uses: snok/install-poetry@v1
|
| 41 |
+
with:
|
| 42 |
+
version: 1.5.1 # Specify a version explicitly
|
| 43 |
+
virtualenvs-create: true
|
| 44 |
+
virtualenvs-in-project: true
|
| 45 |
+
installer-parallel: true
|
| 46 |
+
|
| 47 |
+
- name: Check Poetry Version
|
| 48 |
+
run: poetry --version
|
| 49 |
+
|
| 50 |
+
- name: Load cached venv
|
| 51 |
+
id: cached-poetry-dependencies
|
| 52 |
+
uses: actions/cache@v4
|
| 53 |
+
with:
|
| 54 |
+
path: .venv
|
| 55 |
+
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}-1 # Incremented to reset cache
|
| 56 |
+
restore-keys: |
|
| 57 |
+
venv-${{ runner.os }}-${{ matrix.python-version }}-
|
| 58 |
+
|
| 59 |
+
- name: Install dependencies
|
| 60 |
+
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
|
| 61 |
+
run: poetry install --no-interaction
|
| 62 |
+
|
| 63 |
+
- name: List installed packages
|
| 64 |
+
run: poetry run pip list
|
| 65 |
+
|
| 66 |
+
- name: Check flake8 installation
|
| 67 |
+
run: poetry run which flake8
|
| 68 |
+
|
| 69 |
+
- name: check linting
|
| 70 |
+
run: poetry run flake8 .
|
| 71 |
+
|
| 72 |
+
- name: check formatting
|
| 73 |
+
run: poetry run black --check .
|
| 74 |
+
|
| 75 |
+
- name: check types
|
| 76 |
+
run: poetry run pyright .
|
| 77 |
+
|
| 78 |
+
- name: test
|
| 79 |
+
run: poetry run pytest --cov=sae_dashboard --cov-report=term-missing tests/unit
|
| 80 |
+
|
| 81 |
+
- name: build
|
| 82 |
+
run: poetry build
|
| 83 |
+
|
| 84 |
+
release:
|
| 85 |
+
needs: build
|
| 86 |
+
permissions:
|
| 87 |
+
contents: write
|
| 88 |
+
id-token: write
|
| 89 |
+
if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):')
|
| 90 |
+
runs-on: ubuntu-latest
|
| 91 |
+
concurrency: release
|
| 92 |
+
environment:
|
| 93 |
+
name: pypi
|
| 94 |
+
steps:
|
| 95 |
+
- uses: actions/checkout@v4
|
| 96 |
+
with:
|
| 97 |
+
fetch-depth: 0
|
| 98 |
+
|
| 99 |
+
- uses: actions/setup-python@v5
|
| 100 |
+
with:
|
| 101 |
+
python-version: "3.11"
|
| 102 |
+
|
| 103 |
+
- name: Semantic Release
|
| 104 |
+
id: release
|
| 105 |
+
uses: python-semantic-release/python-semantic-release@v9.8.8
|
| 106 |
+
with:
|
| 107 |
+
github_token: ${{ secrets.GITHUB_TOKEN }}
|
| 108 |
+
|
| 109 |
+
- name: Publish package distributions to PyPI
|
| 110 |
+
uses: pypa/gh-action-pypi-publish@release/v1
|
| 111 |
+
if: steps.release.outputs.released == 'true'
|
| 112 |
+
|
| 113 |
+
- name: Publish package distributions to GitHub Releases
|
| 114 |
+
uses: python-semantic-release/upload-to-gh-release@main
|
| 115 |
+
if: steps.release.outputs.released == 'true'
|
| 116 |
+
with:
|
| 117 |
+
github_token: ${{ secrets.GITHUB_TOKEN }}
|
SAEDashboard/.gitignore
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# ruff
|
| 147 |
+
.ruff_cache
|
| 148 |
+
|
| 149 |
+
# Pyre type checker
|
| 150 |
+
.pyre/
|
| 151 |
+
|
| 152 |
+
# pytype static type analyzer
|
| 153 |
+
.pytype/
|
| 154 |
+
|
| 155 |
+
# Cython debug symbols
|
| 156 |
+
cython_debug/
|
| 157 |
+
|
| 158 |
+
# PyCharm
|
| 159 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 160 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 161 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 162 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 163 |
+
#.idea/
|
| 164 |
+
|
| 165 |
+
*.pkl
|
| 166 |
+
*.pt
|
| 167 |
+
sae_vis/archive_fns.py
|
| 168 |
+
*__pycache__
|
| 169 |
+
mats_sae_training
|
| 170 |
+
callum_instructions.md
|
| 171 |
+
april-fools
|
| 172 |
+
*large.html
|
| 173 |
+
requirements.txt
|
| 174 |
+
tests/fixtures/cache_benchmark/
|
| 175 |
+
tests/fixtures/cache_unit/
|
| 176 |
+
|
| 177 |
+
neuronpedia_outputs/
|
| 178 |
+
cached_activations/
|
| 179 |
+
wandb/
|
| 180 |
+
demo_activations_cache/
|
| 181 |
+
test_activations_cache/
|
| 182 |
+
demo_feature_dashboards.html
|
| 183 |
+
|
| 184 |
+
**.safetensors
|
| 185 |
+
**flamegraph.html
|
| 186 |
+
artifacts/
|
| 187 |
+
prof/
|
| 188 |
+
|
| 189 |
+
.vscode/settings.json
|
| 190 |
+
dfa_tests.ipynb
|
| 191 |
+
|
| 192 |
+
.DS_Store
|
| 193 |
+
|
| 194 |
+
# Test and temporary directories
|
| 195 |
+
crosslayer-coding/
|
| 196 |
+
SAELens/
|
| 197 |
+
clt_test*/
|
| 198 |
+
test_output/
|
| 199 |
+
test_outputs/
|
| 200 |
+
clt-technical-description.md
|
| 201 |
+
|
| 202 |
+
ignore_data/
|
| 203 |
+
|
| 204 |
+
outputs/
|
SAEDashboard/.vscode/settings.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python.testing.pytestArgs": [
|
| 3 |
+
"tests"
|
| 4 |
+
],
|
| 5 |
+
"python.testing.unittestEnabled": false,
|
| 6 |
+
"python.testing.pytestEnabled": true,
|
| 7 |
+
|
| 8 |
+
"[python]": {
|
| 9 |
+
"editor.defaultFormatter": "ms-python.black-formatter",
|
| 10 |
+
"editor.formatOnSave": true,
|
| 11 |
+
"editor.codeActionsOnSave": {
|
| 12 |
+
"source.organizeImports": "explicit"
|
| 13 |
+
}
|
| 14 |
+
},
|
| 15 |
+
"isort.args": ["--profile", "black"],
|
| 16 |
+
"editor.defaultFormatter": "mikoz.black-py",
|
| 17 |
+
"liveServer.settings.port": 5501
|
| 18 |
+
}
|
SAEDashboard/CHANGELOG.md
ADDED
|
@@ -0,0 +1,1263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CHANGELOG
|
| 2 |
+
|
| 3 |
+
## v0.7.3 (2025-10-11)
|
| 4 |
+
|
| 5 |
+
### Fix
|
| 6 |
+
|
| 7 |
+
* fix: broken dependencies ([`25ce6e8`](https://github.com/jbloomAus/SAEDashboard/commit/25ce6e8ae2debe232b0eff4ba910fa10fc816480))
|
| 8 |
+
|
| 9 |
+
### Unknown
|
| 10 |
+
|
| 11 |
+
* Merge pull request #70 from jbloomAus/fix_deps
|
| 12 |
+
|
| 13 |
+
fix: broken dependencies ([`352b9b2`](https://github.com/jbloomAus/SAEDashboard/commit/352b9b2148b62c8fabb7adcf6bf0cbacfa345a74))
|
| 14 |
+
|
| 15 |
+
* update .gitignore ([`6133ca6`](https://github.com/jbloomAus/SAEDashboard/commit/6133ca67b39bf6e033e8b4792f5eba9850668821))
|
| 16 |
+
|
| 17 |
+
## v0.7.2 (2025-09-01)
|
| 18 |
+
|
| 19 |
+
### Fix
|
| 20 |
+
|
| 21 |
+
* fix: use clt pypi library ([`a3197b8`](https://github.com/jbloomAus/SAEDashboard/commit/a3197b8870d43107bba0356af11c64ddc054392d))
|
| 22 |
+
|
| 23 |
+
## v0.7.1 (2025-08-31)
|
| 24 |
+
|
| 25 |
+
### Fix
|
| 26 |
+
|
| 27 |
+
* fix: force build ([`f432ba2`](https://github.com/jbloomAus/SAEDashboard/commit/f432ba2d14edd4c11ac71e947fbb1e97790e753e))
|
| 28 |
+
|
| 29 |
+
## v0.7.0 (2025-08-31)
|
| 30 |
+
|
| 31 |
+
### Feature
|
| 32 |
+
|
| 33 |
+
* feat: Merge pull request #69 from jbloomAus/qwen-transcoder
|
| 34 |
+
|
| 35 |
+
Transcoder Support + SAELens v6 ([`8f8651e`](https://github.com/jbloomAus/SAEDashboard/commit/8f8651edaf8c20bc8eaff09e05de238a4ce780fb))
|
| 36 |
+
|
| 37 |
+
### Fix
|
| 38 |
+
|
| 39 |
+
* fix: relax saelens to not break saelens demo project ([`e38d140`](https://github.com/jbloomAus/SAEDashboard/commit/e38d1408f30f47f9e91646fa8646300750b23fd3))
|
| 40 |
+
|
| 41 |
+
### Unknown
|
| 42 |
+
|
| 43 |
+
* Merge branch 'main' into qwen-transcoder ([`ae9dede`](https://github.com/jbloomAus/SAEDashboard/commit/ae9dedead671996ebf4f7de84eb4252b2af85fc2))
|
| 44 |
+
|
| 45 |
+
* Merge pull request #67 from jbloomAus/relax_saelens
|
| 46 |
+
|
| 47 |
+
fix: relax saelens to not break saelens demo project ([`fa1691a`](https://github.com/jbloomAus/SAEDashboard/commit/fa1691ab224e684618b2e800b8fda8af741eb81b))
|
| 48 |
+
|
| 49 |
+
## v0.6.11 (2025-08-05)
|
| 50 |
+
|
| 51 |
+
### Fix
|
| 52 |
+
|
| 53 |
+
* fix: fixes tool.semantic_release subtable (#66) ([`eb36157`](https://github.com/jbloomAus/SAEDashboard/commit/eb361571550a4653f7fbcc5a9cc2c98c329aaf41))
|
| 54 |
+
|
| 55 |
+
* fix: fixes tool.semantic_release subtable (#66) ([`725d76d`](https://github.com/jbloomAus/SAEDashboard/commit/725d76d9ac00e3b295c6b11f4657f4432c925e9e))
|
| 56 |
+
|
| 57 |
+
### Unknown
|
| 58 |
+
|
| 59 |
+
* upgrades python-semantic-release (#65) ([`fcdae8b`](https://github.com/jbloomAus/SAEDashboard/commit/fcdae8b0e18bbf2c5184977bdf14cc14280fb6bc))
|
| 60 |
+
|
| 61 |
+
* fix CI ([`d7dffdd`](https://github.com/jbloomAus/SAEDashboard/commit/d7dffdd79b0165cbdfb360f13622385f884a2158))
|
| 62 |
+
|
| 63 |
+
* upgrades python-semantic-release (#65) ([`721d683`](https://github.com/jbloomAus/SAEDashboard/commit/721d683b437f378f22c3e713cb4e3f16bdc82e1a))
|
| 64 |
+
|
| 65 |
+
* converter ([`30ff988`](https://github.com/jbloomAus/SAEDashboard/commit/30ff9881f4ea1adc3305675088f5fb808367d5eb))
|
| 66 |
+
|
| 67 |
+
* bos override ([`bd04133`](https://github.com/jbloomAus/SAEDashboard/commit/bd04133033ce5b422f62fc12099eb6527d6f8070))
|
| 68 |
+
|
| 69 |
+
* add prefix tokens to cli ([`4560dd7`](https://github.com/jbloomAus/SAEDashboard/commit/4560dd7301a713b3b45d836df3e43bf25ba52c64))
|
| 70 |
+
|
| 71 |
+
* add prefix tokens to cli ([`50636fb`](https://github.com/jbloomAus/SAEDashboard/commit/50636fb667384fbf1df8d7809cfe3c5ebc44beab))
|
| 72 |
+
|
| 73 |
+
* top acts group 20 ([`1e3d3d4`](https://github.com/jbloomAus/SAEDashboard/commit/1e3d3d4228b5f526f9a2e1d90cb51ff74d8e60b8))
|
| 74 |
+
|
| 75 |
+
* temp updates for qwen transcoder ([`5727ac9`](https://github.com/jbloomAus/SAEDashboard/commit/5727ac944feed9eb78f60f6553feb40c7b6622d8))
|
| 76 |
+
|
| 77 |
+
* some config fixes ([`51d903a`](https://github.com/jbloomAus/SAEDashboard/commit/51d903a98ebc2402f6d4863f1efc1062420ea3eb))
|
| 78 |
+
|
| 79 |
+
* Solved double normalization ([`5293073`](https://github.com/jbloomAus/SAEDashboard/commit/5293073adef90efc0804c61d9dbe3ab55430f62b))
|
| 80 |
+
|
| 81 |
+
* updated readme ([`6ea06f2`](https://github.com/jbloomAus/SAEDashboard/commit/6ea06f2371e80191d8a29cca0c5e134943db02d0))
|
| 82 |
+
|
| 83 |
+
* formatting changes ([`ad7422e`](https://github.com/jbloomAus/SAEDashboard/commit/ad7422e11a5525d8584991e278a3342ebf4ff892))
|
| 84 |
+
|
| 85 |
+
* Update CLT test script parameters ([`9961859`](https://github.com/jbloomAus/SAEDashboard/commit/996185947217ccd13d6763462d764cbd2977a28e))
|
| 86 |
+
|
| 87 |
+
* Merge pull request #64 from jbloomAus/clt-support
|
| 88 |
+
|
| 89 |
+
CLT Support ([`57971a9`](https://github.com/jbloomAus/SAEDashboard/commit/57971a9e85b62cc2a9e7bf04ee8d8c26fac9cecc))
|
| 90 |
+
|
| 91 |
+
* Add Cross-Layer Transcoder (CLT) support to SAEDashboard
|
| 92 |
+
|
| 93 |
+
- Add CLTLayerWrapper to provide SAE-compatible interface for CLTs
|
| 94 |
+
- Integrate CLT loading into NeuronpediaRunner with --use-clt flag
|
| 95 |
+
- Add CLT-specific configuration parameters (clt_layer_idx, clt_weights_filename)
|
| 96 |
+
- Support JumpReLU activation with learned thresholds
|
| 97 |
+
- Add normalization statistics loading from norm_stats.json
|
| 98 |
+
- Handle CLT-specific hook naming conventions (tl_input_template)
|
| 99 |
+
- Add comprehensive unit tests for CLT functionality
|
| 100 |
+
- Fix existing unit tests to use StandardSAE/StandardSAEConfig
|
| 101 |
+
|
| 102 |
+
🤖 Generated with [Claude Code](https://claude.ai/code)
|
| 103 |
+
|
| 104 |
+
Co-Authored-By: Claude <noreply@anthropic.com> ([`fab4c6c`](https://github.com/jbloomAus/SAEDashboard/commit/fab4c6cc7a026bbf9beed229734a34323f22c158))
|
| 105 |
+
|
| 106 |
+
* Add CLT (Cross-Layer Transcoder) support
|
| 107 |
+
|
| 108 |
+
- Add CLTLayerWrapper to wrap CLT models for SAE-compatible interface
|
| 109 |
+
- Add CLT loading logic in neuronpedia_runner with local file support
|
| 110 |
+
- Add conditional logic to skip fold_W_dec_norm() for CLT wrappers
|
| 111 |
+
- Add conditional logic to skip hook_z_reshaping_mode for CLT wrappers
|
| 112 |
+
- Add support for additional hook types (hook_mlp_out, hook_attn_out, etc.)
|
| 113 |
+
- Add CLI arguments for CLT configuration (--use-clt, --clt-layer-idx, etc.)
|
| 114 |
+
- Ensure set_use_hook_mlp_in is called for CLT models ([`0c58760`](https://github.com/jbloomAus/SAEDashboard/commit/0c587608d2da05aa5ed3656afc0ca6b001fbf79b))
|
| 115 |
+
|
| 116 |
+
* script for CLT dashboard generation ([`31e7154`](https://github.com/jbloomAus/SAEDashboard/commit/31e7154363af98af2d6ec269ee9a78f9723048bb))
|
| 117 |
+
|
| 118 |
+
* formatting ([`210fdc4`](https://github.com/jbloomAus/SAEDashboard/commit/210fdc4bd3d8e2dc957c2925719a7a7f7bca1de3))
|
| 119 |
+
|
| 120 |
+
* simplified init function ([`883deb4`](https://github.com/jbloomAus/SAEDashboard/commit/883deb46d75df68af96c2366f59a41dd4a6db964))
|
| 121 |
+
|
| 122 |
+
* formatted tests ([`676b0f1`](https://github.com/jbloomAus/SAEDashboard/commit/676b0f17db4ad919c951414e29c1eb224efa8cef))
|
| 123 |
+
|
| 124 |
+
* added tests and formatting ([`70cbad3`](https://github.com/jbloomAus/SAEDashboard/commit/70cbad3074d2c0a86929921d78d20954e1caf6d8))
|
| 125 |
+
|
| 126 |
+
* Fix compatibility with new SAELens API structure
|
| 127 |
+
|
| 128 |
+
- Fix hook_layer extraction from hook_name when not in config
|
| 129 |
+
- Remove deprecated unpacking of SAE.from_pretrained() return values
|
| 130 |
+
- Handle prepend_bos in both config and metadata locations
|
| 131 |
+
- Add support for extracting layer number from hook_name pattern
|
| 132 |
+
|
| 133 |
+
Co-Authored-By: Claude <noreply@anthropic.com> ([`9018991`](https://github.com/jbloomAus/SAEDashboard/commit/9018991a40b62b7f4ece4f91aa8f07b9f2119d9f))
|
| 134 |
+
|
| 135 |
+
* Fix indentation and hook_name access for transcoders
|
| 136 |
+
|
| 137 |
+
- Fix indentation errors in neuronpedia_runner.py
|
| 138 |
+
- Fix hook_name access - it's always in metadata for both SAEs and transcoders
|
| 139 |
+
- Add test scripts for transcoder functionality
|
| 140 |
+
- Successfully tested transcoder dashboard generation
|
| 141 |
+
|
| 142 |
+
Co-Authored-By: Claude <noreply@anthropic.com> ([`31368e4`](https://github.com/jbloomAus/SAEDashboard/commit/31368e4387757063df87e3b04bd512e13a8cc7d5))
|
| 143 |
+
|
| 144 |
+
* Update .gitignore to exclude test directories and submodules ([`5bf67c5`](https://github.com/jbloomAus/SAEDashboard/commit/5bf67c56e242f496b849957675dd610863485aeb))
|
| 145 |
+
|
| 146 |
+
* Add transcoder support to SAEDashboard
|
| 147 |
+
|
| 148 |
+
- Update imports from sae_lens to use new API structure
|
| 149 |
+
- Add support for loading Transcoder and SkipTranscoder
|
| 150 |
+
- Handle differences between SAE and Transcoder configs
|
| 151 |
+
- Add support for normalized hooks in transformer_lens_wrapper
|
| 152 |
+
- Fix architecture handling in FeatureMaskingContext
|
| 153 |
+
- Update ActivationsStore.from_sae() to include dataset parameter
|
| 154 |
+
|
| 155 |
+
Co-Authored-By: Claude <noreply@anthropic.com> ([`02f78a0`](https://github.com/jbloomAus/SAEDashboard/commit/02f78a0b2a6d08729e60caf06649fca4dfc38ec7))
|
| 156 |
+
|
| 157 |
+
## v0.6.10 (2025-07-16)
|
| 158 |
+
|
| 159 |
+
### Fix
|
| 160 |
+
|
| 161 |
+
* fix: relax SAELens requirement ([`a83147e`](https://github.com/jbloomAus/SAEDashboard/commit/a83147efbf30ef4c4380f306a03468a0c8d41be0))
|
| 162 |
+
|
| 163 |
+
* fix: Merge pull request #45 from Hzfinfdu/main
|
| 164 |
+
|
| 165 |
+
fix: reading model_from_pretrained_kwargs from SAELens config with th… ([`0a509fe`](https://github.com/jbloomAus/SAEDashboard/commit/0a509fede04737b8087140fb4fe5f7addc259806))
|
| 166 |
+
|
| 167 |
+
* fix: reading model_from_pretrained_kwargs from SAELens config with the correct key ([`9938812`](https://github.com/jbloomAus/SAEDashboard/commit/9938812ad209764ceb021eedffb08c0fc5a31c89))
|
| 168 |
+
|
| 169 |
+
### Unknown
|
| 170 |
+
|
| 171 |
+
* Merge pull request #60 from jbloomAus/fix-unit-tests
|
| 172 |
+
|
| 173 |
+
fixes unit tests ([`1a3975d`](https://github.com/jbloomAus/SAEDashboard/commit/1a3975df60198dc169dfcf3354a8e8da5383029f))
|
| 174 |
+
|
| 175 |
+
* fixes unit tests ([`2a35d5d`](https://github.com/jbloomAus/SAEDashboard/commit/2a35d5dd4ad08e3d3532158030d86dbef93ad309))
|
| 176 |
+
|
| 177 |
+
* dedupes get_tokens() (#55)
|
| 178 |
+
|
| 179 |
+
* dedupes get_tokens()
|
| 180 |
+
|
| 181 |
+
* adds newline ([`faeb6f1`](https://github.com/jbloomAus/SAEDashboard/commit/faeb6f119d35a275a304d39c6e8cc9c7c40d31ce))
|
| 182 |
+
|
| 183 |
+
* fixes make commands (#57) ([`cb74411`](https://github.com/jbloomAus/SAEDashboard/commit/cb74411039d0c9d0b0883c85434f27601cf940a5))
|
| 184 |
+
|
| 185 |
+
* deletes print statements in tests (#56) ([`026ba30`](https://github.com/jbloomAus/SAEDashboard/commit/026ba305f4e31f5b47d4f9ada04c7cb0c3aae7f0))
|
| 186 |
+
|
| 187 |
+
* deletes unused direct_effect_feature_ablation_experiment() (#52) ([`391ff94`](https://github.com/jbloomAus/SAEDashboard/commit/391ff949a997a99b605bd55a706d6fed2892249c))
|
| 188 |
+
|
| 189 |
+
* removes unused files (#54) ([`5381cc7`](https://github.com/jbloomAus/SAEDashboard/commit/5381cc7118c7655c6c14cdbbd12e1f6c00278fc2))
|
| 190 |
+
|
| 191 |
+
* Merge pull request #47 from Marlon154/main
|
| 192 |
+
|
| 193 |
+
Fixing deprecated fn call for SAE Lens ([`61c9bd4`](https://github.com/jbloomAus/SAEDashboard/commit/61c9bd4ad8ccd5d96cb5c89eb961db0e7fbc2ab0))
|
| 194 |
+
|
| 195 |
+
* Merge branch 'main' into main ([`50b202a`](https://github.com/jbloomAus/SAEDashboard/commit/50b202a0b2fef413dd46b4ce2838bae27c0ac252))
|
| 196 |
+
|
| 197 |
+
* Merge pull request #35 from chanind/relax-saelens-dep
|
| 198 |
+
|
| 199 |
+
fix: relax SAELens and einops requirements ([`6c71bbf`](https://github.com/jbloomAus/SAEDashboard/commit/6c71bbfd7b6f1562093f1192616a7a55188631d3))
|
| 200 |
+
|
| 201 |
+
* fixing type checking ([`42a9845`](https://github.com/jbloomAus/SAEDashboard/commit/42a9845bba856c942f6d70182429cdb49e0ea917))
|
| 202 |
+
|
| 203 |
+
* Merge branch 'main' into relax-saelens-dep ([`3e6c870`](https://github.com/jbloomAus/SAEDashboard/commit/3e6c8703afd5ce80c29ec1ed0fc729def3f7f8fa))
|
| 204 |
+
|
| 205 |
+
* also relax einops ([`62614ac`](https://github.com/jbloomAus/SAEDashboard/commit/62614ac27ca50527556cc7c891e589e63a14e9bc))
|
| 206 |
+
|
| 207 |
+
* fix type checks ([`5a2cca0`](https://github.com/jbloomAus/SAEDashboard/commit/5a2cca0334a0907e7685cbef798cda71cd249ba4))
|
| 208 |
+
|
| 209 |
+
* Fixing deprecated fn call ([`f1da0e6`](https://github.com/jbloomAus/SAEDashboard/commit/f1da0e6ea7d663e5ff54612d7979d1b1ed9a6b77))
|
| 210 |
+
|
| 211 |
+
## v0.6.9 (2025-02-25)
|
| 212 |
+
|
| 213 |
+
### Fix
|
| 214 |
+
|
| 215 |
+
* fix: Merge pull request #44 from jbloomAus/update_saelens
|
| 216 |
+
|
| 217 |
+
fix: don't use sparsity ([`f30a19b`](https://github.com/jbloomAus/SAEDashboard/commit/f30a19b9cb42f15302848c31ddf1d14462209a42))
|
| 218 |
+
|
| 219 |
+
* fix: don't use sparsity ([`d5ba79b`](https://github.com/jbloomAus/SAEDashboard/commit/d5ba79bbf3d51cbc67e276297d18d85add9d33e7))
|
| 220 |
+
|
| 221 |
+
* fix: update SAELens version and remove unsupported load_sparsity ([`63192ba`](https://github.com/jbloomAus/SAEDashboard/commit/63192ba7d9de7afae9cb67f65db2e79a39b898c6))
|
| 222 |
+
|
| 223 |
+
### Unknown
|
| 224 |
+
|
| 225 |
+
* Merge pull request #43 from jbloomAus/update_saelens
|
| 226 |
+
|
| 227 |
+
fix: update SAELens version and remove unsupported load_sparsity ([`c083723`](https://github.com/jbloomAus/SAEDashboard/commit/c083723237090165725e587f8bdb8f01338394b4))
|
| 228 |
+
|
| 229 |
+
## v0.6.8 (2025-02-15)
|
| 230 |
+
|
| 231 |
+
### Fix
|
| 232 |
+
|
| 233 |
+
* fix: prepended chat template text should not be in activations ([`f3c20ee`](https://github.com/jbloomAus/SAEDashboard/commit/f3c20eec31976db48c1f1d37aabd077e068f66ac))
|
| 234 |
+
|
| 235 |
+
### Unknown
|
| 236 |
+
|
| 237 |
+
* Merge pull request #42 from jbloomAus/prepend_text_fix
|
| 238 |
+
|
| 239 |
+
fix: prepended chat template text should not be in activations ([`eea0b83`](https://github.com/jbloomAus/SAEDashboard/commit/eea0b830e97e791571986cf1ccae1605606ddb4f))
|
| 240 |
+
|
| 241 |
+
## v0.6.7 (2025-02-13)
|
| 242 |
+
|
| 243 |
+
### Fix
|
| 244 |
+
|
| 245 |
+
* fix: force build ([`9b96ac5`](https://github.com/jbloomAus/SAEDashboard/commit/9b96ac57e0b23c1a4cc73fbd9fd855ab0961cce7))
|
| 246 |
+
|
| 247 |
+
### Unknown
|
| 248 |
+
|
| 249 |
+
* Merge pull request #41 from jbloomAus/prepend_chat_template
|
| 250 |
+
|
| 251 |
+
feat: Prepend chat template and activation threshold ([`c7347fa`](https://github.com/jbloomAus/SAEDashboard/commit/c7347faa7d1c800dae398ee8dbded53afced9aa4))
|
| 252 |
+
|
| 253 |
+
* add example ([`9ab42d6`](https://github.com/jbloomAus/SAEDashboard/commit/9ab42d66a2a3e445b300bf83f677f143da3fecd9))
|
| 254 |
+
|
| 255 |
+
* proper 'activation threshold' ([`a6d7c1c`](https://github.com/jbloomAus/SAEDashboard/commit/a6d7c1c8ca4a1ec46676298c4661f623bada9049))
|
| 256 |
+
|
| 257 |
+
* prepend chat template text ([`c8829a1`](https://github.com/jbloomAus/SAEDashboard/commit/c8829a14d6351b5cfd52af927942af5c6897db60))
|
| 258 |
+
|
| 259 |
+
## v0.6.6 (2025-02-11)
|
| 260 |
+
|
| 261 |
+
### Fix
|
| 262 |
+
|
| 263 |
+
* fix: run_settings.json should properly log model_id and layer ([`2e661d9`](https://github.com/jbloomAus/SAEDashboard/commit/2e661d95f30bc28e7d818bbd67de931a334d837f))
|
| 264 |
+
|
| 265 |
+
### Unknown
|
| 266 |
+
|
| 267 |
+
* Merge pull request #40 from jbloomAus/run_settings_fix
|
| 268 |
+
|
| 269 |
+
fix: run_settings.json should properly log model_id and layer ([`f3bde39`](https://github.com/jbloomAus/SAEDashboard/commit/f3bde395843720674d4c60e21bc2453d958ff402))
|
| 270 |
+
|
| 271 |
+
## v0.6.5 (2025-02-11)
|
| 272 |
+
|
| 273 |
+
### Fix
|
| 274 |
+
|
| 275 |
+
* fix: Force Build ([`2e4979c`](https://github.com/jbloomAus/SAEDashboard/commit/2e4979c07ad8bcd2760ee0981ee415d17fef2e5a))
|
| 276 |
+
|
| 277 |
+
### Unknown
|
| 278 |
+
|
| 279 |
+
* Merge pull request #39 from jbloomAus/allow_vector_output
|
| 280 |
+
|
| 281 |
+
feat: allow outputting raw vector in neuronpedia outputs ([`1444786`](https://github.com/jbloomAus/SAEDashboard/commit/14447862418b18d112053a7af8810b049400089a))
|
| 282 |
+
|
| 283 |
+
* remove debug log ([`6efeb6c`](https://github.com/jbloomAus/SAEDashboard/commit/6efeb6c1976dea8374800e7625d63a14a3b6438d))
|
| 284 |
+
|
| 285 |
+
* allow outputting vector ([`4c6cb35`](https://github.com/jbloomAus/SAEDashboard/commit/4c6cb35752317db9f22476d26dc1bab7e4d6e511))
|
| 286 |
+
|
| 287 |
+
* Merge pull request #37 from jbloomAus/feature/vector-dashboards
|
| 288 |
+
|
| 289 |
+
Feature/vector dashboards ([`64c44a9`](https://github.com/jbloomAus/SAEDashboard/commit/64c44a9c11b2dce26b030d5e7bbf782ef90a2985))
|
| 290 |
+
|
| 291 |
+
* typing ([`09aeeab`](https://github.com/jbloomAus/SAEDashboard/commit/09aeeabb4c0f45f4bdabb884f683208eb7073142))
|
| 292 |
+
|
| 293 |
+
* Fixed missing parameter ([`a91d9f5`](https://github.com/jbloomAus/SAEDashboard/commit/a91d9f5dc8c65249c032dc4088aead4364bc42e9))
|
| 294 |
+
|
| 295 |
+
* Fixed parameterization and formatting ([`1956fbc`](https://github.com/jbloomAus/SAEDashboard/commit/1956fbc0ab6e4200d771ad4b504946ff81707969))
|
| 296 |
+
|
| 297 |
+
* Renamed demo notebook, some cleanup ([`6a486a5`](https://github.com/jbloomAus/SAEDashboard/commit/6a486a5834377823f380f59cafcbc3debbdcc3ed))
|
| 298 |
+
|
| 299 |
+
* Working pipeline flow ([`fdb2292`](https://github.com/jbloomAus/SAEDashboard/commit/fdb2292ad84b083246fdf3be2820e0b31168dce2))
|
| 300 |
+
|
| 301 |
+
* First draft of vector vis pipeline ([`4351ef9`](https://github.com/jbloomAus/SAEDashboard/commit/4351ef938a82d2b5e5a37391c236791cc23b41e5))
|
| 302 |
+
|
| 303 |
+
* Merge pull request #38 from jbloomAus/feature/hf-model-override
|
| 304 |
+
|
| 305 |
+
enable passing custom HF model to replace model weights ([`5d98417`](https://github.com/jbloomAus/SAEDashboard/commit/5d98417877c2cfe52bb09ddada0b4b53849b344a))
|
| 306 |
+
|
| 307 |
+
* enable passing custom HF model to replace model weights ([`b2d6ae5`](https://github.com/jbloomAus/SAEDashboard/commit/b2d6ae5446fb79f4662bcdac6030cb6072b09b60))
|
| 308 |
+
|
| 309 |
+
* Don't copy to output folder by default ([`4dbde12`](https://github.com/jbloomAus/SAEDashboard/commit/4dbde1214d49eaaf9b591f083f34e57c8c0c1dbd))
|
| 310 |
+
|
| 311 |
+
* Don't save html file for NP outputs ([`a160bff`](https://github.com/jbloomAus/SAEDashboard/commit/a160bff204b7464d2de00e3f80c255123d11171b))
|
| 312 |
+
|
| 313 |
+
## v0.6.4 (2024-10-24)
|
| 314 |
+
|
| 315 |
+
### Fix
|
| 316 |
+
|
| 317 |
+
* fix: Merge pull request #33 from jbloomAus/fix/topk-selection-purview
|
| 318 |
+
|
| 319 |
+
Fix/topk selection purview ([`afccd5a`](https://github.com/jbloomAus/SAEDashboard/commit/afccd5aaa00d00672eb1270b258b69f0e51c046a))
|
| 320 |
+
|
| 321 |
+
### Unknown
|
| 322 |
+
|
| 323 |
+
* updated formatting/typing ([`fb141ae`](https://github.com/jbloomAus/SAEDashboard/commit/fb141ae991261408d296286bf6777b2ec5f1f319))
|
| 324 |
+
|
| 325 |
+
* TopK will now select from all latents regardless of feature batch size ([`c1f0e14`](https://github.com/jbloomAus/SAEDashboard/commit/c1f0e14dda7aa3364bfd78ca2b8c04c95b2d14b3))
|
| 326 |
+
|
| 327 |
+
* Update README.md ([`8235a9e`](https://github.com/jbloomAus/SAEDashboard/commit/8235a9e3adaea50b6b9f26f575e25a254d67a135))
|
| 328 |
+
|
| 329 |
+
* Merge pull request #32 from jbloomAus/docs/readme-update
|
| 330 |
+
|
| 331 |
+
docs: updated readme ([`b5e5480`](https://github.com/jbloomAus/SAEDashboard/commit/b5e54808ee05fc75e68d74ec319bf49826b45508))
|
| 332 |
+
|
| 333 |
+
* Update README.md ([`a1546fd`](https://github.com/jbloomAus/SAEDashboard/commit/a1546fdef32745cdc862a5a2dd0478e57e45320d))
|
| 334 |
+
|
| 335 |
+
* Removed outdated vis type ([`b0676af`](https://github.com/jbloomAus/SAEDashboard/commit/b0676afcca0845b73a54d983eaa9d72b0e9dff05))
|
| 336 |
+
|
| 337 |
+
* Update README.md ([`9b8446a`](https://github.com/jbloomAus/SAEDashboard/commit/9b8446aa47f287ba80bf0ac4a39f7c77f0492990))
|
| 338 |
+
|
| 339 |
+
* Updated format ([`90e4a09`](https://github.com/jbloomAus/SAEDashboard/commit/90e4a09eedd7f428b64e58d5ca2fd1cfa658b0da))
|
| 340 |
+
|
| 341 |
+
* Updated readme ([`f6819a6`](https://github.com/jbloomAus/SAEDashboard/commit/f6819a6da594673cad65c9ccd3a4f67746de796d))
|
| 342 |
+
|
| 343 |
+
## v0.6.3 (2024-10-23)
|
| 344 |
+
|
| 345 |
+
### Fix
|
| 346 |
+
|
| 347 |
+
* fix: update cached_activations directory to include number of prompts ([`0308cb1`](https://github.com/jbloomAus/SAEDashboard/commit/0308cb146bf2eb9cee26f03d3098511d03022485))
|
| 348 |
+
|
| 349 |
+
## v0.6.2 (2024-10-23)
|
| 350 |
+
|
| 351 |
+
### Fix
|
| 352 |
+
|
| 353 |
+
* fix: lint ([`3fc0e2c`](https://github.com/jbloomAus/SAEDashboard/commit/3fc0e2ccb39ed1d3e31d66ae0aba2b2b367d46aa))
|
| 354 |
+
|
| 355 |
+
### Unknown
|
| 356 |
+
|
| 357 |
+
* Merge branch 'main' of https://github.com/jbloomAus/SAEDashboard ([`8f74a96`](https://github.com/jbloomAus/SAEDashboard/commit/8f74a969f48a7e0fd8de17cc983acf3886db95ef))
|
| 358 |
+
|
| 359 |
+
## v0.6.1 (2024-10-22)
|
| 360 |
+
|
| 361 |
+
### Unknown
|
| 362 |
+
|
| 363 |
+
* Fix: divide by zero, cached_activations folder name ([`1792298`](https://github.com/jbloomAus/SAEDashboard/commit/179229805ae6489d86e235240c65d26db64b5cd7))
|
| 364 |
+
|
| 365 |
+
* Merge branch 'main' of https://github.com/jbloomAus/SAEDashboard ([`508a74d`](https://github.com/jbloomAus/SAEDashboard/commit/508a74df8ff279716501e4179c501b5089a8d706))
|
| 366 |
+
|
| 367 |
+
## v0.6.0 (2024-10-21)
|
| 368 |
+
|
| 369 |
+
### Feature
|
| 370 |
+
|
| 371 |
+
* feat: np sae id suffix ([`448b14e`](https://github.com/jbloomAus/SAEDashboard/commit/448b14e0b3aea8ff854a5365f164b6ce5f419f0d))
|
| 372 |
+
|
| 373 |
+
### Fix
|
| 374 |
+
|
| 375 |
+
* fix: update saelens to v4 ([`ef1a330`](https://github.com/jbloomAus/SAEDashboard/commit/ef1a3302d0483eddb247defab5c88816850f7f63))
|
| 376 |
+
|
| 377 |
+
### Unknown
|
| 378 |
+
|
| 379 |
+
* Merge pull request #31 from jbloomAus/fix/reduce-mem
|
| 380 |
+
|
| 381 |
+
fix: added mem cleanup ([`60bd716`](https://github.com/jbloomAus/SAEDashboard/commit/60bd716c7b52bb0eaea0937e097eb77ed78bd33d))
|
| 382 |
+
|
| 383 |
+
* Fixed formatting ([`f1fab0c`](https://github.com/jbloomAus/SAEDashboard/commit/f1fab0c1fd5be281e2162ab3f54ffc7f4c09a1ce))
|
| 384 |
+
|
| 385 |
+
* Added cleanup ([`305c46d`](https://github.com/jbloomAus/SAEDashboard/commit/305c46d7a30330bbae6893b83cb6d498c2c975f1))
|
| 386 |
+
|
| 387 |
+
* Merge pull request #30 from jbloomAus/feat-mask-via-position
|
| 388 |
+
|
| 389 |
+
feat: prepending/appending tokens for prompt template + feat mask via Position ([`4c60e4c`](https://github.com/jbloomAus/SAEDashboard/commit/4c60e4c834dfb5759ce55dc90d1f88768abfea0d))
|
| 390 |
+
|
| 391 |
+
* add a few tests ([`96247d5`](https://github.com/jbloomAus/SAEDashboard/commit/96247d5afaf141b8b1279c17fd135240b0d8e869))
|
| 392 |
+
|
| 393 |
+
* handle prefixes / suffixes and ignored positions ([`bff7fd9`](https://github.com/jbloomAus/SAEDashboard/commit/bff7fd98b09318a1b01d2bc4a06467f8afa156f9))
|
| 394 |
+
|
| 395 |
+
* simplify masking ([`385b6e1`](https://github.com/jbloomAus/SAEDashboard/commit/385b6e116ecac53ad4df8585f7513c3416707d8b))
|
| 396 |
+
|
| 397 |
+
* add option for ignoring tokens at particular positions ([`ed3426d`](https://github.com/jbloomAus/SAEDashboard/commit/ed3426de5cb1495c138f770eefa5f941408aa390))
|
| 398 |
+
|
| 399 |
+
* Merge pull request #29 from jbloomAus/refactor/optimize-dfa-speed
|
| 400 |
+
|
| 401 |
+
Sped up DFA calculation 60x ([`f992e3c`](https://github.com/jbloomAus/SAEDashboard/commit/f992e3cf116189625b3a92529cf68d6226a1221c))
|
| 402 |
+
|
| 403 |
+
* Sped up DFA calculation ([`be11cd5`](https://github.com/jbloomAus/SAEDashboard/commit/be11cd5652f0f8a8ae425555666b747b9b99314e))
|
| 404 |
+
|
| 405 |
+
* Added test to check for decoder weight dist (head dist) ([`f147696`](https://github.com/jbloomAus/SAEDashboard/commit/f1476967af5fee95313264ccaee668605d23b9ad))
|
| 406 |
+
|
| 407 |
+
* Merge pull request #28 from jbloomAus/feature/np-topk-size-arg
|
| 408 |
+
|
| 409 |
+
Feature/np topk size arg ([`c5c1365`](https://github.com/jbloomAus/SAEDashboard/commit/c5c136576609991177d3a8924b5bf75a42b66399))
|
| 410 |
+
|
| 411 |
+
* Simply updated default value for top K ([`5c855fe`](https://github.com/jbloomAus/SAEDashboard/commit/5c855fec0e58a114a537590d1400eaa42dd3610c))
|
| 412 |
+
|
| 413 |
+
* Testing variable topk sizes ([`79fe14b`](https://github.com/jbloomAus/SAEDashboard/commit/79fe14b840991bd1f8ada8462aeb65d72821c4aa))
|
| 414 |
+
|
| 415 |
+
* Merge pull request #25 from jbloomAus/fix/dfa-for-gqa
|
| 416 |
+
|
| 417 |
+
Fix/dfa for gqa ([`85c345f`](https://github.com/jbloomAus/SAEDashboard/commit/85c345f3ad8069a59be8d495242395c50381ab01))
|
| 418 |
+
|
| 419 |
+
* Fixed formatting ([`48a67c7`](https://github.com/jbloomAus/SAEDashboard/commit/48a67c79247d05745d355e6a4bf380e9df20474e))
|
| 420 |
+
|
| 421 |
+
* Removed redundant code from rebase ([`a71fb9d`](https://github.com/jbloomAus/SAEDashboard/commit/a71fb9dde6e880b0f4297277d27696c9d524d052))
|
| 422 |
+
|
| 423 |
+
* fixed rebase ([`57ee280`](https://github.com/jbloomAus/SAEDashboard/commit/57ee28021efd3678bcd9d12d55e048c14a2f2d47))
|
| 424 |
+
|
| 425 |
+
* Added tests for DFA for GQA ([`3b99e36`](https://github.com/jbloomAus/SAEDashboard/commit/3b99e36c74d2c61617cfed107bee3b0eb3b63294))
|
| 426 |
+
|
| 427 |
+
* Removed duplicate code ([`7093773`](https://github.com/jbloomAus/SAEDashboard/commit/7093773d079cd235aea99273a1365363a5bf8b6d))
|
| 428 |
+
|
| 429 |
+
* More rebasing stuff ([`59c6cd8`](https://github.com/jbloomAus/SAEDashboard/commit/59c6cd85ead287b2774aa591463d131840c7f270))
|
| 430 |
+
|
| 431 |
+
* Fixed formatting ([`ed7d3b1`](https://github.com/jbloomAus/SAEDashboard/commit/ed7d3b16a99e3e3a272e73356cc0509b2c59a292))
|
| 432 |
+
|
| 433 |
+
* Removed debugging statements ([`6489d1c`](https://github.com/jbloomAus/SAEDashboard/commit/6489d1c5b52ed86cb280c237c08e10238e0d0564))
|
| 434 |
+
|
| 435 |
+
* more debug prints x3 ([`5ba2b8a`](https://github.com/jbloomAus/SAEDashboard/commit/5ba2b8a69f1881b901131976c7d52f142068dbd2))
|
| 436 |
+
|
| 437 |
+
* more debug prints x2 ([`e124ff9`](https://github.com/jbloomAus/SAEDashboard/commit/e124ff906ec7b37083af4e4721b9e33902146e47))
|
| 438 |
+
|
| 439 |
+
* more debug prints ([`e2b0c35`](https://github.com/jbloomAus/SAEDashboard/commit/e2b0c35467e5d405abd3cca664dfd1960dbba0eb))
|
| 440 |
+
|
| 441 |
+
* temp print statements ([`95df55b`](https://github.com/jbloomAus/SAEDashboard/commit/95df55b29f9250f67c5b986216e587c37f72aa9e))
|
| 442 |
+
|
| 443 |
+
* Lowered default threshold ([`dc1f31a`](https://github.com/jbloomAus/SAEDashboard/commit/dc1f31a55400231e46feb58a8c100f66472baa1b))
|
| 444 |
+
|
| 445 |
+
* updated ignore ([`eb0d56a`](https://github.com/jbloomAus/SAEDashboard/commit/eb0d56a9f813b9cf82742093fae00bb0ccfdac45))
|
| 446 |
+
|
| 447 |
+
* Reduced memory load of GQA DFA ([`05867f1`](https://github.com/jbloomAus/SAEDashboard/commit/05867f1d0c8b5f2a5b76f3ea45ab9c87eaae9c09))
|
| 448 |
+
|
| 449 |
+
* DFA will now work for models with grouped query attention ([`91a5dd1`](https://github.com/jbloomAus/SAEDashboard/commit/91a5dd17a2e567efa7d8a89d228eb7de47ae6766))
|
| 450 |
+
|
| 451 |
+
* Added head attr weights functionality for when DFA is used ([`03a615f`](https://github.com/jbloomAus/SAEDashboard/commit/03a615f7c70a6f6e634845dab4051874698fac5b))
|
| 452 |
+
|
| 453 |
+
* Edited default chunk size ([`7d68f9e`](https://github.com/jbloomAus/SAEDashboard/commit/7d68f9e7131b8c5558e886022625dac267f20aab))
|
| 454 |
+
|
| 455 |
+
* Fixed formatting ([`4d5f38b`](https://github.com/jbloomAus/SAEDashboard/commit/4d5f38beca15f2ce05c89f83eb3e955c291f9687))
|
| 456 |
+
|
| 457 |
+
* Removed debugging statements and added device changes ([`76e17c9`](https://github.com/jbloomAus/SAEDashboard/commit/76e17c91a41b5df6047baa5bcfa33d253b029d29))
|
| 458 |
+
|
| 459 |
+
* more debug prints x3 ([`06535d3`](https://github.com/jbloomAus/SAEDashboard/commit/06535d3df168d92ac79d2f5a14b345c757dfd9de))
|
| 460 |
+
|
| 461 |
+
* more debug prints x2 ([`26e8297`](https://github.com/jbloomAus/SAEDashboard/commit/26e8297888de066f0097e3b73245eb149bfb327f))
|
| 462 |
+
|
| 463 |
+
* more debug prints ([`9ded356`](https://github.com/jbloomAus/SAEDashboard/commit/9ded356ea8c3c5dd841bf5a45ea65ae8c67935f5))
|
| 464 |
+
|
| 465 |
+
* temp print statements ([`024ad57`](https://github.com/jbloomAus/SAEDashboard/commit/024ad578b65b8f3592b42b66dc6a56aeae2a3116))
|
| 466 |
+
|
| 467 |
+
* Lowered default threshold ([`a3b5977`](https://github.com/jbloomAus/SAEDashboard/commit/a3b5977c0f1bb7a865f7349304a5dd8092f7c2e8))
|
| 468 |
+
|
| 469 |
+
* updated ignore ([`d5d325a`](https://github.com/jbloomAus/SAEDashboard/commit/d5d325a63b3b26b890c2bab512f2a8473bdc926a))
|
| 470 |
+
|
| 471 |
+
* Reduced memory load of GQA DFA ([`93eb1a9`](https://github.com/jbloomAus/SAEDashboard/commit/93eb1a9a92320d9f4645b500e22a566135918e3d))
|
| 472 |
+
|
| 473 |
+
* DFA will now work for models with grouped query attention ([`6594155`](https://github.com/jbloomAus/SAEDashboard/commit/65941559bac03a3e4fb128d5327033e01f19c18d))
|
| 474 |
+
|
| 475 |
+
* Added head attr weights functionality for when DFA is used ([`9312d90`](https://github.com/jbloomAus/SAEDashboard/commit/9312d901bf17e14400199c86e0284be6c750162a))
|
| 476 |
+
|
| 477 |
+
* Added tests for DFA for GQA ([`fcfac37`](https://github.com/jbloomAus/SAEDashboard/commit/fcfac37e148461e585f38fddf868ad2a32d908a8))
|
| 478 |
+
|
| 479 |
+
* Removed duplicate code ([`cc00944`](https://github.com/jbloomAus/SAEDashboard/commit/cc00944855720d5b8139d4267b44c1a230ef5319))
|
| 480 |
+
|
| 481 |
+
* Fixed formatting ([`50b08b4`](https://github.com/jbloomAus/SAEDashboard/commit/50b08b4eb50734afe0f085274ccaee71ec4017a4))
|
| 482 |
+
|
| 483 |
+
* Removed debugging statements ([`f7b949b`](https://github.com/jbloomAus/SAEDashboard/commit/f7b949b4af6bc8ca7557bfa5fa2441fbaa0284a0))
|
| 484 |
+
|
| 485 |
+
* more debug prints x3 ([`53536b0`](https://github.com/jbloomAus/SAEDashboard/commit/53536b03d624783b6b2f95b07b9318139ef0c49e))
|
| 486 |
+
|
| 487 |
+
* more debug prints x2 ([`6f2c504`](https://github.com/jbloomAus/SAEDashboard/commit/6f2c504a355f9071e766fc7fa3b6aad9890572a8))
|
| 488 |
+
|
| 489 |
+
* more debug prints ([`e1bef90`](https://github.com/jbloomAus/SAEDashboard/commit/e1bef90d16e8c73c9532b19a08c842757828c7ed))
|
| 490 |
+
|
| 491 |
+
* temp print statements ([`fd75714`](https://github.com/jbloomAus/SAEDashboard/commit/fd75714ee4631463c1f754d68f83b9ef75eb2285))
|
| 492 |
+
|
| 493 |
+
* updated ignore ([`c01062f`](https://github.com/jbloomAus/SAEDashboard/commit/c01062faecfaa132d87c56a7ba7add573c6b0f4e))
|
| 494 |
+
|
| 495 |
+
* Reduced memory load of GQA DFA ([`1ae40e9`](https://github.com/jbloomAus/SAEDashboard/commit/1ae40e9d487af7e8a7b148629588ef87fdd0a6e5))
|
| 496 |
+
|
| 497 |
+
* DFA will now work for models with grouped query attention ([`c66c90f`](https://github.com/jbloomAus/SAEDashboard/commit/c66c90f5d51961cafd5f13c26a94193ee38f828a))
|
| 498 |
+
|
| 499 |
+
* Edited default chunk size ([`3c78bdc`](https://github.com/jbloomAus/SAEDashboard/commit/3c78bdcfda12e5873de082a7f1e631a801bd9407))
|
| 500 |
+
|
| 501 |
+
* Fixed formatting ([`10a36e3`](https://github.com/jbloomAus/SAEDashboard/commit/10a36e3e8da3c7593058d3638ac3b7a32953b1b0))
|
| 502 |
+
|
| 503 |
+
* Removed debugging statements and added device changes ([`0f51dd9`](https://github.com/jbloomAus/SAEDashboard/commit/0f51dd953cd214244c71e8b9156b90483ceaa2be))
|
| 504 |
+
|
| 505 |
+
* more debug prints x3 ([`112ef42`](https://github.com/jbloomAus/SAEDashboard/commit/112ef4292b81a64f6168e7527ec583faa9ba20a4))
|
| 506 |
+
|
| 507 |
+
* more debug prints x2 ([`ef154d6`](https://github.com/jbloomAus/SAEDashboard/commit/ef154d6044bb67d17a2aa225ddf4099ccfc16b55))
|
| 508 |
+
|
| 509 |
+
* more debug prints ([`1b18d14`](https://github.com/jbloomAus/SAEDashboard/commit/1b18d141dd33e3a99c2abd5a6d195ab5142890d8))
|
| 510 |
+
|
| 511 |
+
* temp print statements ([`2194d2c`](https://github.com/jbloomAus/SAEDashboard/commit/2194d2cea16856c96ace47ad5ac560f088e769b0))
|
| 512 |
+
|
| 513 |
+
* Lowered default threshold ([`a49d1e5`](https://github.com/jbloomAus/SAEDashboard/commit/a49d1e5b94c8ef680448f20ded849c7752fb5131))
|
| 514 |
+
|
| 515 |
+
* updated ignore ([`2067655`](https://github.com/jbloomAus/SAEDashboard/commit/20676554541d29fddd87215a47e8e94891e342ac))
|
| 516 |
+
|
| 517 |
+
* Reduced memory load of GQA DFA ([`8ec1956`](https://github.com/jbloomAus/SAEDashboard/commit/8ec19566e8898413d349fe3f2e43fbff232ffa62))
|
| 518 |
+
|
| 519 |
+
* DFA will now work for models with grouped query attention ([`8f3cf55`](https://github.com/jbloomAus/SAEDashboard/commit/8f3cf5532e57abc6e694fb11c5f9c7c2915215c0))
|
| 520 |
+
|
| 521 |
+
* Added head attr weights functionality for when DFA is used ([`234ea32`](https://github.com/jbloomAus/SAEDashboard/commit/234ea3211ce7dbf84d101c4e8bfe844c3903b16a))
|
| 522 |
+
|
| 523 |
+
* Merge pull request #27 from jbloomAus/fix/resolve-duplication
|
| 524 |
+
|
| 525 |
+
Removed sources of duplicate sequences ([`525bffe`](https://github.com/jbloomAus/SAEDashboard/commit/525bffee516a630c4b4f033d3971fad8c6dd5a74))
|
| 526 |
+
|
| 527 |
+
* Updated location of wandb finish() ([`921da77`](https://github.com/jbloomAus/SAEDashboard/commit/921da77132a560505fa61decf287ca3833f96ec7))
|
| 528 |
+
|
| 529 |
+
* Added two sets of tests for duplication checks ([`3e95ffd`](https://github.com/jbloomAus/SAEDashboard/commit/3e95ffd1dafd01deb1f7817845ccb6229fb4ae09))
|
| 530 |
+
|
| 531 |
+
* Restored original random indices function as it seemed ok ([`388719b`](https://github.com/jbloomAus/SAEDashboard/commit/388719bec99b4306e81e0cdb772b9924db210774))
|
| 532 |
+
|
| 533 |
+
* Removed sources of duplicate sequences ([`853306c`](https://github.com/jbloomAus/SAEDashboard/commit/853306c4e08d9ec95674fdc5c87f807019055d0d))
|
| 534 |
+
|
| 535 |
+
## v0.5.1 (2024-08-27)

### Fix

* fix: multi-gpu-tlens

fix: handle multiple tlens devices ([`ed1e967`](https://github.com/jbloomAus/SAEDashboard/commit/ed1e967d44b887f4b99d2257934ca920d5c6a508))

* fix: handle multiple tlens devices ([`ba5368f`](https://github.com/jbloomAus/SAEDashboard/commit/ba5368f9999f08332c153816ba5836f8a1eb9ba1))
|
| 544 |
+
|
| 545 |
+
## v0.5.0 (2024-08-25)

### Feature

* feat: accelerate caching. Torch load / save faster when files are small.

Refactor/accelerate caching ([`6027d0a`](https://github.com/jbloomAus/SAEDashboard/commit/6027d0a3fc0d70908bad036a9658caa406d9f809))

### Unknown

* Updated formatting ([`c1ea288`](https://github.com/jbloomAus/SAEDashboard/commit/c1ea2882a17e0d1b7b28743a34fca9d0754bd8a7))

* Sped up caching with native torch functions ([`230840a`](https://github.com/jbloomAus/SAEDashboard/commit/230840aea50b8b7055a6aa61961d7ac50855b763))

* Increased cache loading speed ([`83fe5f4`](https://github.com/jbloomAus/SAEDashboard/commit/83fe5f4bdf1252d533f203bc3f53ea9f71880ab8))
|
| 560 |
+
|
| 561 |
+
## v0.4.0 (2024-08-22)

### Feature

* feat: Refactor json writer and trigger DFA release

JSON writer has been refactored for reusability and readability ([`664f487`](https://github.com/jbloomAus/SAEDashboard/commit/664f4874b585c5510d2d3dd639c5e893023f6332))
|
| 568 |
+
|
| 569 |
+
### Unknown
|
| 570 |
+
|
| 571 |
+
* Refactored JSON creation from the neuronpedia runner ([`d6bb24b`](https://github.com/jbloomAus/SAEDashboard/commit/d6bb24b6d773874d8e99be4d84402d559741907b))
|
| 572 |
+
|
| 573 |
+
* Merge pull request #20 from jbloomAus/feature/dfa
|
| 574 |
+
|
| 575 |
+
SAEVisRunner DFA Implementation ([`926ea87`](https://github.com/jbloomAus/SAEDashboard/commit/926ea87dd344548489201f68cc92b33662430813))
|
| 576 |
+
|
| 577 |
+
* Update ci.yaml ([`4b2807d`](https://github.com/jbloomAus/SAEDashboard/commit/4b2807dd865904120d236b355c0ccb1680c2919e))
|
| 578 |
+
|
| 579 |
+
* Fixed formatting ([`a62cc8f`](https://github.com/jbloomAus/SAEDashboard/commit/a62cc8f1bdd4c6e49b76d2d594e5a6b4b8183a8c))
|
| 580 |
+
|
| 581 |
+
* Fixed target index ([`ca2668d`](https://github.com/jbloomAus/SAEDashboard/commit/ca2668da03ea4d06cdc9f198988b80e0db844316))
|
| 582 |
+
|
| 583 |
+
* Corrected DFA indexing ([`d5028ae`](https://github.com/jbloomAus/SAEDashboard/commit/d5028aec875db4c03196726400c3b90b5d9d4d01))
|
| 584 |
+
|
| 585 |
+
* Adding temporary testing notebook ([`98e4b2f`](https://github.com/jbloomAus/SAEDashboard/commit/98e4b2f93d300ad4e94985d8d2594739a277e0c8))
|
| 586 |
+
|
| 587 |
+
* Added DFA output to neuronpedia runner ([`68eeff3`](https://github.com/jbloomAus/SAEDashboard/commit/68eeff3172b0c8637a6566c07951c28fd14a1c03))
|
| 588 |
+
|
| 589 |
+
* Fixed test typehints ([`d358e6f`](https://github.com/jbloomAus/SAEDashboard/commit/d358e6f5cc37304935eed949a0b0b985ba12b94f))
|
| 590 |
+
|
| 591 |
+
* Fixed formatting ([`5cb19e2`](https://github.com/jbloomAus/SAEDashboard/commit/5cb19e241051503730b6982813a6730556990c92))
|
| 592 |
+
|
| 593 |
+
* Corrected typehints ([`6173fbd`](https://github.com/jbloomAus/SAEDashboard/commit/6173fbd3824b7cba58e1cf0c7ee239762ee533ce))
|
| 594 |
+
|
| 595 |
+
* Removed another unused import ([`8be1572`](https://github.com/jbloomAus/SAEDashboard/commit/8be1572370b1adf341e2a650953bf17cd179808d))
|
| 596 |
+
|
| 597 |
+
* Removed unused imports ([`9071210`](https://github.com/jbloomAus/SAEDashboard/commit/90712105f74b287d77a06c045e8c32fd05f2e668))
|
| 598 |
+
|
| 599 |
+
* Added support for DFA calculations up to SAE Vis runner ([`4a08ffd`](https://github.com/jbloomAus/SAEDashboard/commit/4a08ffd13a8f29ff16808a20cd663c9d2d369e6a))
|
| 600 |
+
|
| 601 |
+
* Added activation collection flow for DFA ([`0ebb1f3`](https://github.com/jbloomAus/SAEDashboard/commit/0ebb1f3ca61603662f4f2cc8b1341470bf75b5d1))
|
| 602 |
+
|
| 603 |
+
* Merge pull request #19 from jbloomAus/fix/remove_precision_reduction
|
| 604 |
+
|
| 605 |
+
Removed precision reduction option ([`a5f8df1`](https://github.com/jbloomAus/SAEDashboard/commit/a5f8df15ef8619c4d08655e777d379a05b453346))
|
| 606 |
+
|
| 607 |
+
* Removed float16 option entirely from quantile calc ([`1b6a4a9`](https://github.com/jbloomAus/SAEDashboard/commit/1b6a4a93403ca2e9a869aa73600f37960090f03d))
|
| 608 |
+
|
| 609 |
+
* Removed precision reduction option ([`cd03ffb`](https://github.com/jbloomAus/SAEDashboard/commit/cd03ffb182e93a42480c01408b47ebae94d4c349))
|
| 610 |
+
|
| 611 |
+
## v0.3.0 (2024-08-15)

### Feature

* feat: seperate files per dashboard html ([`cd8d050`](https://github.com/jbloomAus/SAEDashboard/commit/cd8d050218ae3c6eeb7a9779072e60b78bfe0b58))

### Unknown
|
| 618 |
+
|
| 619 |
+
* Merge pull request #17 from jbloomAus/refactor/remove_enc_b
|
| 620 |
+
|
| 621 |
+
Removed all encoder B code ([`67c9c3f`](https://github.com/jbloomAus/SAEDashboard/commit/67c9c3fdc8bd220938f65c1f97214034cc7528b4))
|
| 622 |
+
|
| 623 |
+
* Removed all encoder B code ([`5174e2e`](https://github.com/jbloomAus/SAEDashboard/commit/5174e2e161030dc756c148f1740e50c52baf6a91))
|
| 624 |
+
|
| 625 |
+
* Merge pull request #18 from jbloomAus/feat-seperate-files-per-html-dashboard
|
| 626 |
+
|
| 627 |
+
feat: seperate files per dashboard html ([`8ff69ba`](https://github.com/jbloomAus/SAEDashboard/commit/8ff69ba207692d4acb8d5fc19d038090067690df))
|
| 628 |
+
|
| 629 |
+
* Merge pull request #16 from jbloomAus/performance_refactor
|
| 630 |
+
|
| 631 |
+
Create() will now reduce precision by default ([`fb07b90`](https://github.com/jbloomAus/SAEDashboard/commit/fb07b90eaac395a58f02ba927460dcc2c9e61d1a))
|
| 632 |
+
|
| 633 |
+
* Removed line ([`d795490`](https://github.com/jbloomAus/SAEDashboard/commit/d795490c1c9d8193c8cf84d0352b9d93c41947fe))
|
| 634 |
+
|
| 635 |
+
* Removed unnecessary print ([`4544f86`](https://github.com/jbloomAus/SAEDashboard/commit/4544f86472480f0df00344fa84111a7c2a52fcef))
|
| 636 |
+
|
| 637 |
+
* Precision will now be reduced by default for quantile calc ([`539d222`](https://github.com/jbloomAus/SAEDashboard/commit/539d222ded9e3a0944f5240f3a4cd84497d11a74))
|
| 638 |
+
|
| 639 |
+
* Merge pull request #15 from jbloomAus/quantile_efficiency
|
| 640 |
+
|
| 641 |
+
Quantile OOM prevention ([`4a40c37`](https://github.com/jbloomAus/SAEDashboard/commit/4a40c3704aab9363163fef3e2830d42f2fecdc6b))
|
| 642 |
+
|
| 643 |
+
* Made quantile batch optional and removed sampling code ([`2df51d3`](https://github.com/jbloomAus/SAEDashboard/commit/2df51d353f818a196916a15f2bc56f70480dd853))
|
| 644 |
+
|
| 645 |
+
* Added device check for test ([`afbb960`](https://github.com/jbloomAus/SAEDashboard/commit/afbb960d3c9376ad512607146826b7d1c1e68d48))
|
| 646 |
+
|
| 647 |
+
* Added parameter for quantile calculation batching ([`49d0a7a`](https://github.com/jbloomAus/SAEDashboard/commit/49d0a7ab37896a085f80409900e3d0b261b8c9e0))
|
| 648 |
+
|
| 649 |
+
* Added type annotation ([`c71c4aa`](https://github.com/jbloomAus/SAEDashboard/commit/c71c4aa1c8bc25d85b9a955b482823cbde445a51))
|
| 650 |
+
|
| 651 |
+
* Removed unused imports ([`ec01bfe`](https://github.com/jbloomAus/SAEDashboard/commit/ec01bfefc2f0f4d880cd5744ff6a2ea71991349b))
|
| 652 |
+
|
| 653 |
+
* Added float16 version of quantile calculation ([`2f01eb8`](https://github.com/jbloomAus/SAEDashboard/commit/2f01eb8d9f84a20918f19e81c23df86ddc9d7f0c))
|
| 654 |
+
|
| 655 |
+
* Merge pull request #13 from jbloomAus/hook_z_support
|
| 656 |
+
|
| 657 |
+
fix: restore hook_z support following regression. ([`ea87559`](https://github.com/jbloomAus/SAEDashboard/commit/ea87559359f9821e352dcab582e23b42fef1cebf))
|
| 658 |
+
|
| 659 |
+
* format ([`21e3617`](https://github.com/jbloomAus/SAEDashboard/commit/21e3617196ef57944c141563e9263101baf9c7f1))
|
| 660 |
+
|
| 661 |
+
* make sure hook_z works ([`efaeec0`](https://github.com/jbloomAus/SAEDashboard/commit/efaeec0fdf8c2c43bb13bfd652b812a38ebc0200))
|
| 662 |
+
|
| 663 |
+
* Merge pull request #12 from jbloomAus/use_sae_lens_loading
|
| 664 |
+
|
| 665 |
+
Use sae lens loading ([`89bba3e`](https://github.com/jbloomAus/SAEDashboard/commit/89bba3e7a10877782608c50f4b8dd9054f204381))
|
| 666 |
+
|
| 667 |
+
* add settings.json ([`d8f3034`](https://github.com/jbloomAus/SAEDashboard/commit/d8f3034c0ed7241c35e9761d60a9ee4072403fd0))
|
| 668 |
+
|
| 669 |
+
* add dtype ([`0d8008a`](https://github.com/jbloomAus/SAEDashboard/commit/0d8008afe93a2a2a5bfc954571c680a529ab883f))
|
| 670 |
+
|
| 671 |
+
* cli util ([`9da440e`](https://github.com/jbloomAus/SAEDashboard/commit/9da440eb3d50d48a7fdc4d3ee3d26de13a458593))
|
| 672 |
+
|
| 673 |
+
* wandb logging improvement ([`a077369`](https://github.com/jbloomAus/SAEDashboard/commit/a077369ca43009f4e50c0b1e7176cae398703856))
|
| 674 |
+
|
| 675 |
+
* add override for np set name ([`8906d10`](https://github.com/jbloomAus/SAEDashboard/commit/8906d103ab8d10bd01b791331dfc5485ac047a4f))
|
| 676 |
+
|
| 677 |
+
* auto add folder path to output dir ([`35e06ab`](https://github.com/jbloomAus/SAEDashboard/commit/35e06ab89bce257fc15ffaa4918b9598577d6df0))
|
| 678 |
+
|
| 679 |
+
* update tests ([`50163b0`](https://github.com/jbloomAus/SAEDashboard/commit/50163b04ca29b492b9fb71244aa26798655b663f))
|
| 680 |
+
|
| 681 |
+
* first step towards sae_lens remote loading ([`415a2d1`](https://github.com/jbloomAus/SAEDashboard/commit/415a2d1e484e9ea2351bf98de221f6a83a805107))
|
| 682 |
+
|
| 683 |
+
## v0.2.3 (2024-08-06)

### Fix

* fix: neuronpedia uses api_key for uploading features, and update sae_id->sae_set ([`0336a35`](https://github.com/jbloomAus/SAEDashboard/commit/0336a3587f825f0be15af79cc9a0033dda3d4a3f))

### Unknown
|
| 690 |
+
|
| 691 |
+
* Merge pull request #11 from jbloomAus/ignore_bos_option
|
| 692 |
+
|
| 693 |
+
Ignore bos option ([`ae34b70`](https://github.com/jbloomAus/SAEDashboard/commit/ae34b70b61993b4cce49a758bf85514410c67bd8))
|
| 694 |
+
|
| 695 |
+
* change threshold ([`4a0be67`](https://github.com/jbloomAus/SAEDashboard/commit/4a0be67622826f879191ced225c8c075d34bfe56))
|
| 696 |
+
|
| 697 |
+
* type fix ([`525b6a1`](https://github.com/jbloomAus/SAEDashboard/commit/525b6a10331b9fa0a464ae0c7f01af90ae97d0bb))
|
| 698 |
+
|
| 699 |
+
* default ignore bos eos pad ([`d2396a7`](https://github.com/jbloomAus/SAEDashboard/commit/d2396a714dd9ea3d59e516aa0fe30a9c9225e22f))
|
| 700 |
+
|
| 701 |
+
* ignore bos tokens ([`96cf6e9`](https://github.com/jbloomAus/SAEDashboard/commit/96cf6e9427cadf13fa13b55b7d1bc83ae81d9ec0))
|
| 702 |
+
|
| 703 |
+
* jump relu support in feature masking context ([`a1ba87a`](https://github.com/jbloomAus/SAEDashboard/commit/a1ba87a5c5e03687d7d7b5c5677bd9773fa49517))
|
| 704 |
+
|
| 705 |
+
* depend on latest sae lens ([`4988207`](https://github.com/jbloomAus/SAEDashboard/commit/4988207abaca24256f52235e474fe5fbb5028c1a))
|
| 706 |
+
|
| 707 |
+
* Merge pull request #10 from jbloomAus/auth_and_sae_set
|
| 708 |
+
|
| 709 |
+
fix: neuronpedia uses api_key for uploading features, and update sae_id -> sae_set ([`4684aca`](https://github.com/jbloomAus/SAEDashboard/commit/4684aca54b69dbc913c1122f1a322ed4d808dce0))
|
| 710 |
+
|
| 711 |
+
* Combine upload-features and upload-dead-stubs ([`faac839`](https://github.com/jbloomAus/SAEDashboard/commit/faac8398fee8582b12c2d1a29df6d4de7e542bed))
|
| 712 |
+
|
| 713 |
+
* Activation store device should be cuda when available ([`93050b1`](https://github.com/jbloomAus/SAEDashboard/commit/93050b1f5c2b87c8e889fe3449d440016c996762))
|
| 714 |
+
|
| 715 |
+
* Activation store device should be cuda when available ([`4469066`](https://github.com/jbloomAus/SAEDashboard/commit/4469066af06bb4944832f2e596e36afa09adf160))
|
| 716 |
+
|
| 717 |
+
* Better support for huggingface dataset path ([`3dc4b78`](https://github.com/jbloomAus/SAEDashboard/commit/3dc4b783a1ced7b938ab45c4d10effedd148a829))
|
| 718 |
+
|
| 719 |
+
* Docker tweak ([`a1a70cb`](https://github.com/jbloomAus/SAEDashboard/commit/a1a70cb28c726887de9439024b7b1d01082d3932))
|
| 720 |
+
|
| 721 |
+
## v0.2.2 (2024-07-12)

### Fix

* fix: don't sample too many tokens + other fixes

fix: don't sample too many tokens ([`b2554b0`](https://github.com/jbloomAus/SAEDashboard/commit/b2554b017e75d14b38b343fc6e0c1bcc32be2359))

* fix: don't sample too many tokens ([`0cbb2ed`](https://github.com/jbloomAus/SAEDashboard/commit/0cbb2edb480b83823dc1a98dd7e5978ecdda0d81))

### Unknown
|
| 732 |
+
|
| 733 |
+
* - Don't force manual overrides for dtype - default to SAE's dtype
|
| 734 |
+
- Add n_prompts_in_forward_pass to neuronpedia.py
|
| 735 |
+
- Add n_prompts_total, n_tokens_in_prompt, and dataset to neuronpedia artifact
|
| 736 |
+
- Remove NPDashboardSettings for now (just save the NPRunnerConfig later)
|
| 737 |
+
- Fix lint error
|
| 738 |
+
- Consolidate minibatch_size_features/tokens to n_feats_at_a_time and n_prompts_in_fwd_pass
|
| 739 |
+
- Update/Fix NP acceptance test ([`b6282c8`](https://github.com/jbloomAus/SAEDashboard/commit/b6282c83e1898e356e271af0926e2271fb23f707))
|
| 740 |
+
|
| 741 |
+
* Merge pull request #7 from jbloomAus/performance-improvement
|
| 742 |
+
|
| 743 |
+
feat: performance improvement ([`f98b3dc`](https://github.com/jbloomAus/SAEDashboard/commit/f98b3dcf84c42687dfc92fa38377edd1c3f6fa30))
|
| 744 |
+
|
| 745 |
+
* delete unused snapshots ([`4210b48`](https://github.com/jbloomAus/SAEDashboard/commit/4210b48608792adc9b841ea92a64050311e66cd6))
|
| 746 |
+
|
| 747 |
+
* format ([`de57a2d`](https://github.com/jbloomAus/SAEDashboard/commit/de57a2d84564fc0eb7d5e42799c00f73c7007cf8))
|
| 748 |
+
|
| 749 |
+
* linter ([`4725ffa`](https://github.com/jbloomAus/SAEDashboard/commit/4725ffa2cbe743aa0bb615213f11105b6911f10d))
|
| 750 |
+
|
| 751 |
+
* hope flaky tests start passing ([`8ac9e8e`](https://github.com/jbloomAus/SAEDashboard/commit/8ac9e8e93127d4ab811019fc62bbe050a9a00e2c))
|
| 752 |
+
|
| 753 |
+
* np.memmap caching and more explicit hyperparams ([`9a24186`](https://github.com/jbloomAus/SAEDashboard/commit/9a24186cc1c118725c6db7dc3c77feb815cf938f))
|
| 754 |
+
|
| 755 |
+
* Move docker" ([`27b1a27`](https://github.com/jbloomAus/SAEDashboard/commit/27b1a27118bcccf54576eb1891b936bd92848f3f))
|
| 756 |
+
|
| 757 |
+
* Add docker to workflow ([`a354fa4`](https://github.com/jbloomAus/SAEDashboard/commit/a354fa47cfb005dd2304b4237f9182e2408daeed))
|
| 758 |
+
|
| 759 |
+
* Dockerignore file ([`ed9fcf3`](https://github.com/jbloomAus/SAEDashboard/commit/ed9fcf3a634cd57f6517170784d56d86431e1710))
|
| 760 |
+
|
| 761 |
+
* new versions ([`f64e54d`](https://github.com/jbloomAus/SAEDashboard/commit/f64e54df5c1b643fc3acaff7f4d40d5597edf61a))
|
| 762 |
+
|
| 763 |
+
* Add tools to docker image ([`2a70f64`](https://github.com/jbloomAus/SAEDashboard/commit/2a70f64cfd4177d807a8345e64699054dd103e8d))
|
| 764 |
+
|
| 765 |
+
* Fix docker ([`3805f20`](https://github.com/jbloomAus/SAEDashboard/commit/3805f20bff622582d16fd6603bef4b77e6bada9e))
|
| 766 |
+
|
| 767 |
+
* Fix docker image ([`7f9ff2f`](https://github.com/jbloomAus/SAEDashboard/commit/7f9ff2f9b10ce08264b2153e8191eca32f9ee48a))
|
| 768 |
+
|
| 769 |
+
* Fix NP simple test, remove check for correlated neurons/features ([`355fad5`](https://github.com/jbloomAus/SAEDashboard/commit/355fad58ab2ab036a33375c02d9006db634702b9))
|
| 770 |
+
|
| 771 |
+
* Dockerfile, small batching fix ([`4df4c51`](https://github.com/jbloomAus/SAEDashboard/commit/4df4c5138341a1c233c3d0fe1a3d399846e92407))
|
| 772 |
+
|
| 773 |
+
* set sae_device, activation_store device ([`6d65b22`](https://github.com/jbloomAus/SAEDashboard/commit/6d65b22ef541326cc9558119b40baeb95cc2e47e))
|
| 774 |
+
|
| 775 |
+
* Fix NP dtype error ([`8bb4d9d`](https://github.com/jbloomAus/SAEDashboard/commit/8bb4d9de0c75ffed5daaba4d5ec563fbbee38f86))
|
| 776 |
+
|
| 777 |
+
* format ([`f667d92`](https://github.com/jbloomAus/SAEDashboard/commit/f667d92d9359e5c7976e21e821ac0dde8a081da6))
|
| 778 |
+
|
| 779 |
+
* depend on latest sae_lens ([`4a2a6a0`](https://github.com/jbloomAus/SAEDashboard/commit/4a2a6a0fd70d7b4a3f1f870a510a800b31f57264))
|
| 780 |
+
|
| 781 |
+
* use a much better method for getting subsets of feature activations ([`7101f13`](https://github.com/jbloomAus/SAEDashboard/commit/7101f13e13b4de5659623433ec359ecf2142daef))
|
| 782 |
+
|
| 783 |
+
* add to gitignore ([`20180e0`](https://github.com/jbloomAus/SAEDashboard/commit/20180e06a279ef93d6127b467511911db352bce5))
|
| 784 |
+
|
| 785 |
+
* add isort ([`3ab0fda`](https://github.com/jbloomAus/SAEDashboard/commit/3ab0fdaf75f735ec2eedc904529909111d0db0de))
|
| 786 |
+
|
| 787 |
+
## v0.2.1 (2024-07-08)

### Fix

* fix: trigger release ([`87bf0b5`](https://github.com/jbloomAus/SAEDashboard/commit/87bf0b5f21f0d1f5397e514090601ec21c718e35))

### Unknown
|
| 794 |
+
|
| 795 |
+
* Merge pull request #6 from jbloomAus/fix-bfloat16
|
| 796 |
+
|
| 797 |
+
fix bfloat 16 error ([`2f3c597`](https://github.com/jbloomAus/SAEDashboard/commit/2f3c597c1795357679e92caec3dd7e522c669fdb))
|
| 798 |
+
|
| 799 |
+
* fix bfloat 16 error ([`63c3c62`](https://github.com/jbloomAus/SAEDashboard/commit/63c3c62f0a03e5656ed78cc0e8f853bea3f0938e))
|
| 800 |
+
|
| 801 |
+
* Merge pull request #5 from jbloomAus/np-updates
|
| 802 |
+
|
| 803 |
+
Updates + fixes for Neuronpedia ([`9e6b5c4`](https://github.com/jbloomAus/SAEDashboard/commit/9e6b5c427024b8a468b0d06e4e096c2561c35d5d))
|
| 804 |
+
|
| 805 |
+
* Fix SAELens compatibility ([`139e1a2`](https://github.com/jbloomAus/SAEDashboard/commit/139e1a2f219d790c6f8faa9be34d9fbc9403dda3))
|
| 806 |
+
|
| 807 |
+
* Rename file ([`16709ad`](https://github.com/jbloomAus/SAEDashboard/commit/16709add9ee5063b3682be34eef0aea2ddf4eceb))
|
| 808 |
+
|
| 809 |
+
* Fix type ([`6b20386`](https://github.com/jbloomAus/SAEDashboard/commit/6b2038682ca41423dda3a3597bbe88120b120262))
|
| 810 |
+
|
| 811 |
+
* Make Neuronpedia outputs an object, and add a real acceptance test ([`a5db256`](https://github.com/jbloomAus/SAEDashboard/commit/a5db2560e5f90a49257124635b3fdbee117ed860))
|
| 812 |
+
|
| 813 |
+
* Np Runner: Multi-gpu defaults ([`07f7128`](https://github.com/jbloomAus/SAEDashboard/commit/07f71282681ffa801dd15f9265be349cd5745b42))
|
| 814 |
+
|
| 815 |
+
* Ensure minibatch is on correct device ([`e206546`](https://github.com/jbloomAus/SAEDashboard/commit/e2065462c445df0e0985fb6588d4c01cb39bbef5))
|
| 816 |
+
|
| 817 |
+
* NP Runner: Automatically use multi-gpu, devices ([`bf280e6`](https://github.com/jbloomAus/SAEDashboard/commit/bf280e685dc4dd2018cd41aa94a29bc853fcee18))
|
| 818 |
+
|
| 819 |
+
* Allow dtype override ([`a40077d`](https://github.com/jbloomAus/SAEDashboard/commit/a40077dac1fa2ae880fcdabe3227878ef2cfaebe))
|
| 820 |
+
|
| 821 |
+
* NP-Runner: Remove unnecessary layer of batching. ([`e2ac92b`](https://github.com/jbloomAus/SAEDashboard/commit/e2ac92b036d0192e132c8a8700a5a2f448d1983b))
|
| 822 |
+
|
| 823 |
+
* NP Runner: Allow skipping sparsity check ([`ef74d2a`](https://github.com/jbloomAus/SAEDashboard/commit/ef74d2aeea2463afe150a5e8824da5a5206cd3d0))
|
| 824 |
+
|
| 825 |
+
* Merge pull request #2 from jbloomAus/multiple-devices
|
| 826 |
+
|
| 827 |
+
feat: Multiple devices ([`535e6c9`](https://github.com/jbloomAus/SAEDashboard/commit/535e6c9689d855f82a6ddfd9f169720fe367bde3))
|
| 828 |
+
|
| 829 |
+
* format ([`7f892ad`](https://github.com/jbloomAus/SAEDashboard/commit/7f892ad0efb42025df0bcf26bdddd6fac4c2d8b1))
|
| 830 |
+
|
| 831 |
+
* NP runner takes device args seperately ([`8fc31dd`](https://github.com/jbloomAus/SAEDashboard/commit/8fc31dd6ccd59f4f35742a4e15c380673c8cb2a3))
|
| 832 |
+
|
| 833 |
+
* multi-gpu-support ([`5e24e4e`](https://github.com/jbloomAus/SAEDashboard/commit/5e24e4e6598dd7943f8d677042dcf84bc6f7a0a6))
|
| 834 |
+
|
| 835 |
+
## v0.2.0 (2024-06-10)

### Feature

* feat: experimental release 2 ([`e264f97`](https://github.com/jbloomAus/SAEDashboard/commit/e264f97d90299f6ade294db8ed03aed9cd7491ee))

## v0.1.0 (2024-06-10)

### Feature

* feat: experimental release ([`d79310a`](https://github.com/jbloomAus/SAEDashboard/commit/d79310a7b6599f7b813e214c9268d736e0cb87f0))
|
| 846 |
+
|
| 847 |
+
### Unknown
|
| 848 |
+
|
| 849 |
+
* fix pyproject.toml ([`a27c87d`](https://github.com/jbloomAus/SAEDashboard/commit/a27c87da987f043b470abce3404e305ec3f0d620))
|
| 850 |
+
|
| 851 |
+
* test deployment ([`288a2d9`](https://github.com/jbloomAus/SAEDashboard/commit/288a2d9bf797a1a2f9947b1ceac5e47edc1684ba))
|
| 852 |
+
|
| 853 |
+
* refactor np runner and add acceptance test ([`212593c`](https://github.com/jbloomAus/SAEDashboard/commit/212593c33b3aec33078a121738c0a826f705722f))
|
| 854 |
+
|
| 855 |
+
* Fix: Default context tokens length for neuronpedia runner ([`aefe95c`](https://github.com/jbloomAus/SAEDashboard/commit/aefe95cb1be4139ac45f042abdc78e0feccfb490))
|
| 856 |
+
|
| 857 |
+
* Allow custom context tokens length for Neuronpedia runner ([`d204cc8`](https://github.com/jbloomAus/SAEDashboard/commit/d204cc8fbb2ef376a1a5e00cd4f1cc5db2afb279))
|
| 858 |
+
|
| 859 |
+
* Fix: Streaming default true ([`1b91dff`](https://github.com/jbloomAus/SAEDashboard/commit/1b91dff045fdbd8c118c5f209750eca60c260f5f))
|
| 860 |
+
|
| 861 |
+
* Fix n_devices error for non-cuda ([`70b2dbd`](https://github.com/jbloomAus/SAEDashboard/commit/70b2dbdb2da51f5d78b1c2ce3210865fc259c97b))
|
| 862 |
+
|
| 863 |
+
* fix import path for ci ([`3bd4687`](https://github.com/jbloomAus/SAEDashboard/commit/3bd468727e2ab0b7d77224b7c0dad88e0727b773))
|
| 864 |
+
|
| 865 |
+
* make pyright happy, start config ([`b39ae85`](https://github.com/jbloomAus/SAEDashboard/commit/b39ae85d938a0db7c70b7dff9683f68f255dfb67))
|
| 866 |
+
|
| 867 |
+
* add black ([`236855b`](https://github.com/jbloomAus/SAEDashboard/commit/236855be1ef1464ea85b2afc6aaee963326f9257))
|
| 868 |
+
|
| 869 |
+
* fix ci ([`12818d7`](https://github.com/jbloomAus/SAEDashboard/commit/12818d7e6cd3e483258598b668805c1a9a048049))
|
| 870 |
+
|
| 871 |
+
* add pytest cov ([`aae0571`](https://github.com/jbloomAus/SAEDashboard/commit/aae057159639cd247a82fdeda9eddb98612ceec6))
|
| 872 |
+
|
| 873 |
+
* bring checks in line with sae_lens ([`7cd9679`](https://github.com/jbloomAus/SAEDashboard/commit/7cd9679cc18c64a7c8a0a07a1f12e6fc87543537))
|
| 874 |
+
|
| 875 |
+
* activation scaling factor ([`333d377`](https://github.com/jbloomAus/SAEDashboard/commit/333d3770d0d1d3c40dfeb3335dcfc46e9b7da717))
|
| 876 |
+
|
| 877 |
+
* Move Neuronpedia runner to SAEDashboard ([`4e691ea`](https://github.com/jbloomAus/SAEDashboard/commit/4e691eaad919e12b9cae6ff707eaa3cf322ea030))
|
| 878 |
+
|
| 879 |
+
* fold w_dec norm by default ([`b6c9bc7`](https://github.com/jbloomAus/SAEDashboard/commit/b6c9bc70dc419d1e32bfb5580997369215e15429))
|
| 880 |
+
|
| 881 |
+
* rename sae_vis to sae_dashboard ([`f0f5341`](https://github.com/jbloomAus/SAEDashboard/commit/f0f5341ffdf31a11884777d6ba8100cd302b9dab))
|
| 882 |
+
|
| 883 |
+
* rename feature data generator ([`e02ed0a`](https://github.com/jbloomAus/SAEDashboard/commit/e02ed0a18e92c497aea3e137cf43e9f354f8f30f))
|
| 884 |
+
|
| 885 |
+
* update demo ([`8aa9e52`](https://github.com/jbloomAus/SAEDashboard/commit/8aa9e5272f54d04b741e63aa335bfa1212a2d0f7))
|
| 886 |
+
|
| 887 |
+
* add demo ([`dd3036f`](https://github.com/jbloomAus/SAEDashboard/commit/dd3036f90e6a4ed459ec21647744d491911900ac))
|
| 888 |
+
|
| 889 |
+
* delete old demo files ([`3d86202`](https://github.com/jbloomAus/SAEDashboard/commit/3d8620204cf6acb21b5e7f9983c300341345cd88))
|
| 890 |
+
|
| 891 |
+
* remove unnecessary print statement ([`9d3d937`](https://github.com/jbloomAus/SAEDashboard/commit/9d3d937e74f5575dde68d5a21fb73ce6f826d0d4))
|
| 892 |
+
|
| 893 |
+
* set sae lens version ([`87a7691`](https://github.com/jbloomAus/SAEDashboard/commit/87a76911ff0f0d46ab421d9b5107aef27216e88b))
|
| 894 |
+
|
| 895 |
+
* update older readme ([`c5c98e5`](https://github.com/jbloomAus/SAEDashboard/commit/c5c98e53531874efab5bc16235d9c72816fa61d5))
|
| 896 |
+
|
| 897 |
+
* test ([`923da42`](https://github.com/jbloomAus/SAEDashboard/commit/923da427b56178acd99b988d6d6b51368b5d2359))
|
| 898 |
+
|
| 899 |
+
* remove sae lens dep ([`2c26d5f`](https://github.com/jbloomAus/SAEDashboard/commit/2c26d5f4c40c41f750971601968577f316e15598))
|
| 900 |
+
|
| 901 |
+
* Merge branch 'refactor_b' ([`3154d63`](https://github.com/jbloomAus/SAEDashboard/commit/3154d636e1a9f8a30b54c17e62a842bed3f8b2a1))
|
| 902 |
+
|
| 903 |
+
* pass linting ([`0c079a1`](https://github.com/jbloomAus/SAEDashboard/commit/0c079a105b1b98e0edf2ff1a15593567c81bb103))
|
| 904 |
+
|
| 905 |
+
* format ([`6f37e2e`](https://github.com/jbloomAus/SAEDashboard/commit/6f37e2eb050a3207a2d3b9defd5d416645215c7c))
|
| 906 |
+
|
| 907 |
+
* run ci on all branches ([`faa0cc4`](https://github.com/jbloomAus/SAEDashboard/commit/faa0cc4eed4ff35f1e04656a968214c4fefbd573))
|
| 908 |
+
|
| 909 |
+
* don't use feature ablations ([`dc6e6dc`](https://github.com/jbloomAus/SAEDashboard/commit/dc6e6dc2d2affce331894d8bb61942e103182652))
|
| 910 |
+
|
| 911 |
+
* mock information in sequences to make normal sequence generation pass ([`c87b82f`](https://github.com/jbloomAus/SAEDashboard/commit/c87b82fdcc5e849d970cdc8bd1e841ec3e3e48ce))
|
| 912 |
+
|
| 913 |
+
* Remove resid ([`ff83737`](https://github.com/jbloomAus/SAEDashboard/commit/ff837373b65e60d8a9ba7c6e61f78bddc4d170f2))
|
| 914 |
+
|
| 915 |
+
* adding a test for direct_effect_feature_ablation_experiment ([`a9f3d1b`](https://github.com/jbloomAus/SAEDashboard/commit/a9f3d1b8021d8eeb60cf465934037d07583fa0b2))
|
| 916 |
+
|
| 917 |
+
* shortcut direct_effect_feature_ablation_experiment if everything is zero ([`2c68ff0`](https://github.com/jbloomAus/SAEDashboard/commit/2c68ff0c8496c58cc0732f3c51905c9c9f405393))
|
| 918 |
+
|
| 919 |
+
* fixing CI and replacing manual snapshots with syrupy snapshots ([`3b97640`](https://github.com/jbloomAus/SAEDashboard/commit/3b97640803cab3e3915202ac80c43b855c69c1cb))
|
| 920 |
+
|
| 921 |
+
* more refactor, WIP ([`81657c8`](https://github.com/jbloomAus/SAEDashboard/commit/81657c8c897a81102c0df7b29c49d526e639bb44))
|
| 922 |
+
|
| 923 |
+
* continue refactor, make data generator ([`eb1ae0f`](https://github.com/jbloomAus/SAEDashboard/commit/eb1ae0fc621407b33481c50b78b041079b08393d))
|
| 924 |
+
|
| 925 |
+
* add use of safetensors cache for repeated calculations ([`a241c32`](https://github.com/jbloomAus/SAEDashboard/commit/a241c322334340a84c2a252bc0b4a40ed2f19bc9))
|
| 926 |
+
|
| 927 |
+
* more refactor / benchmarking ([`d65ee87`](https://github.com/jbloomAus/SAEDashboard/commit/d65ee87cd191b2ed279f9f6efabb9e98bb700855))
|
| 928 |
+
|
| 929 |
+
* only run unit tests ([`5f11ddd`](https://github.com/jbloomAus/SAEDashboard/commit/5f11ddd9bc25f9c9bb7cbeba11224ba12b260ea8))
|
| 930 |
+
|
| 931 |
+
* fix lint issue ([`24daf17`](https://github.com/jbloomAus/SAEDashboard/commit/24daf17cb92534901681affbcebea314e2cf6580))
|
| 932 |
+
|
| 933 |
+
* format ([`83e89ed`](https://github.com/jbloomAus/SAEDashboard/commit/83e89ed4860d886ccf19be591bf72d0e029e7344))
|
| 934 |
+
|
| 935 |
+
* organise tests, make sure only unit tests run on CI ([`21f5fb1`](https://github.com/jbloomAus/SAEDashboard/commit/21f5fb155665329531b16d10673ddd988e7034ea))
|
| 936 |
+
|
| 937 |
+
* see if we can do some caching ([`c1dca6f`](https://github.com/jbloomAus/SAEDashboard/commit/c1dca6faa61de0849453acc83ae23baab6cf48be))
|
| 938 |
+
|
| 939 |
+
* more refactoring ([`b3f0f41`](https://github.com/jbloomAus/SAEDashboard/commit/b3f0f41f36f0eee57a08142880d4b6654309e62c))
|
| 940 |
+
|
| 941 |
+
* further refactor, possible significant speed up ([`ddd3496`](https://github.com/jbloomAus/SAEDashboard/commit/ddd3496206c0f3e751b596ca51e3544c77ddaf94))
|
| 942 |
+
|
| 943 |
+
* more refactor ([`a5f6deb`](https://github.com/jbloomAus/SAEDashboard/commit/a5f6deb4263c58803e7af23d767fc5cb17dfd2b2))
|
| 944 |
+
|
| 945 |
+
* refactoring in progress ([`d210b60`](https://github.com/jbloomAus/SAEDashboard/commit/d210b6056aa5316d2fd917e24ca8a819331a8114))
|
| 946 |
+
|
| 947 |
+
* use named arguments ([`4a81053`](https://github.com/jbloomAus/SAEDashboard/commit/4a8105355d3b86e460f32cd5c736dde0dbeaa2e3))
|
| 948 |
+
|
| 949 |
+
* remove create method ([`43b2018`](https://github.com/jbloomAus/SAEDashboard/commit/43b20184ed5ed0c2f08cfd13423f2271fd871274))
|
| 950 |
+
|
| 951 |
+
* move chunk ([`0f26aa8`](https://github.com/jbloomAus/SAEDashboard/commit/0f26aa85bc9fbe358f4c5f90971d51b86159f095))
|
| 952 |
+
|
| 953 |
+
* use fixtures ([`7c11dd9`](https://github.com/jbloomAus/SAEDashboard/commit/7c11dd914d467957e5c00b914f302c291924e411))
|
| 954 |
+
|
| 955 |
+
* refactor to create runner ([`9202c19`](https://github.com/jbloomAus/SAEDashboard/commit/9202c19f4ad6134eb6b68f857c9e4bfd0b911cf8))
|
| 956 |
+
|
| 957 |
+
* format ([`abd8747`](https://github.com/jbloomAus/SAEDashboard/commit/abd87472b76cfc151abfe2a6e312ea43b29c2250))
|
| 958 |
+
|
| 959 |
+
* target ci at this branch ([`ea3b2a3`](https://github.com/jbloomAus/SAEDashboard/commit/ea3b2a3181f2eb1ff52d83b2040b586d6fdfef4a))
|
| 960 |
+
|
| 961 |
+
* comment out release process for now ([`7084b5b`](https://github.com/jbloomAus/SAEDashboard/commit/7084b5ba3325bb559a8377d379bd2f3ba6d68348))
|
| 962 |
+
|
| 963 |
+
* test generated output ([`7b8b2ab`](https://github.com/jbloomAus/SAEDashboard/commit/7b8b2abd94213d67c378b7746107f6a7c811d93c))
|
| 964 |
+
|
| 965 |
+
* commit current demo html ([`00a03a0`](https://github.com/jbloomAus/SAEDashboard/commit/00a03a02fbf181caa55704defac25578b4444452))
|
| 966 |
+
|
| 967 |
+
## v0.0.1 (2024-04-25)

### Chore

* chore: setting up pytest ([`2079d00`](https://github.com/jbloomAus/SAEDashboard/commit/2079d00911d1a00ee19cde478b5cab61ca9c0495))

* chore: setting up semantic-release ([`09075af`](https://github.com/jbloomAus/SAEDashboard/commit/09075afbec279fb89d157f73e9a0ed47ba66d3c8))

### Fix

* fix: remove circular dep with sae lens ([`1dd9f6c`](https://github.com/jbloomAus/SAEDashboard/commit/1dd9f6cd22f879e8d6904ba72f3e52b4344433cd))
|
| 978 |
+
|
| 979 |
+
### Unknown
|
| 980 |
+
|
| 981 |
+
* Merge pull request #44 from chanind/pytest-setup
|
| 982 |
+
|
| 983 |
+
chore: setting up pytest ([`034eefa`](https://github.com/jbloomAus/SAEDashboard/commit/034eefa5a4163e9a560b574e2e255cd06f8f49a1))
|
| 984 |
+
|
| 985 |
+
* Merge pull request #43 from callummcdougall/move_saelens_dep
|
| 986 |
+
|
| 987 |
+
Remove dependency on saelens from pyproject, add to demo.ipynb ([`147d87e`](https://github.com/jbloomAus/SAEDashboard/commit/147d87ee9534d30e764851cbe73aadb5783d2515))
|
| 988 |
+
|
| 989 |
+
* Add missing matplotlib ([`572a3cc`](https://github.com/jbloomAus/SAEDashboard/commit/572a3cc79709a14117bbeafb871a33f0107600d8))
|
| 990 |
+
|
| 991 |
+
* Remove dependency on saelens from pyproject, add to demo.ipynb ([`1e6f3cf`](https://github.com/jbloomAus/SAEDashboard/commit/1e6f3cf9b2bcfb381a73d9333581c430faa531fd))
|
| 992 |
+
|
| 993 |
+
* Merge branch 'main' of https://github.com/callummcdougall/sae_vis ([`4e7a24c`](https://github.com/jbloomAus/SAEDashboard/commit/4e7a24c37444f11d718035eede68ac728d949a20))
|
| 994 |
+
|
| 995 |
+
* Merge pull request #41 from callummcdougall/allow_disable_buffer
|
| 996 |
+
|
| 997 |
+
oops I forgot to switch back to main before pushing ([`1312cd0`](https://github.com/jbloomAus/SAEDashboard/commit/1312cd09d6e274b1163e79d2ac01f2df54c65157))
|
| 998 |
+
|
| 999 |
+
* Merge branch 'main' into allow_disable_buffer ([`e7edf5a`](https://github.com/jbloomAus/SAEDashboard/commit/e7edf5a9bae4714bf4983ce6a19a0fe6fdf1f118))
|
| 1000 |
+
|
| 1001 |
+
* Merge pull request #40 from chanind/semantic-release-autodeploy
|
| 1002 |
+
|
| 1003 |
+
chore: setting up semantic-release for auto-deploy ([`a4d44d1`](https://github.com/jbloomAus/SAEDashboard/commit/a4d44d1a0e86055fb82ef41f51f0adbb7868df3c))
|
| 1004 |
+
|
| 1005 |
+
* Merge pull request #38 from chanind/type-checking
|
| 1006 |
+
|
| 1007 |
+
Enabling type checking with Pyright ([`f1fd792`](https://github.com/jbloomAus/SAEDashboard/commit/f1fd7926f46f00dca46024377f53aa8f2db98773))
|
| 1008 |
+
|
| 1009 |
+
* enabling type checking with Pyright ([`05d14ea`](https://github.com/jbloomAus/SAEDashboard/commit/05d14eafea707d3db81e78b4be87199087cb8e37))
|
| 1010 |
+
|
| 1011 |
+
* Merge pull request #39 from callummcdougall/fix_loading_saelens_sae
|
| 1012 |
+
|
| 1013 |
+
FIX: SAELens new format has "scaling_factor" key, which causes assert to fail ([`983aee5`](https://github.com/jbloomAus/SAEDashboard/commit/983aee562aea31e90657caf8c6ab6e450e952120))
|
| 1014 |
+
|
| 1015 |
+
* Fix Formatting ([`13b8106`](https://github.com/jbloomAus/SAEDashboard/commit/13b81062485f5dce2568e7832bfb2aae218dd4e9))
|
| 1016 |
+
|
| 1017 |
+
* Merge branch 'main' into fix_loading_saelens_sae ([`21b0086`](https://github.com/jbloomAus/SAEDashboard/commit/21b0086b8af3603441795e925a15e7cded122acb))
|
| 1018 |
+
|
| 1019 |
+
* format ([`8f1506b`](https://github.com/jbloomAus/SAEDashboard/commit/8f1506b6eb7dc0a2d4437d2aa23a0898c46a156d))
|
| 1020 |
+
|
| 1021 |
+
* Allow SAELens autoencoder keys to be superset of required keys, instead of exact match ([`6852170`](https://github.com/jbloomAus/SAEDashboard/commit/6852170d55e7d3cf22632c5807cfab219516da98))
|
| 1022 |
+
|
| 1023 |
+
* v0.2.17 ([`2bb14da`](https://github.com/jbloomAus/SAEDashboard/commit/2bb14daa88a0af601e13f4e51b50a2b00cd75b48))
|
| 1024 |
+
|
| 1025 |
+
* Use main branch of SAELens ([`2b34505`](https://github.com/jbloomAus/SAEDashboard/commit/2b345052bdc92ee9c1255cab0978916307a0a9dc))
|
| 1026 |
+
|
| 1027 |
+
* Update version 0.2.16 ([`bf90293`](https://github.com/jbloomAus/SAEDashboard/commit/bf902930844db9b0f8db4fbe8b3610557352660b))
|
| 1028 |
+
|
| 1029 |
+
* Merge pull request #36 from callummcdougall/allow_disable_buffer
|
| 1030 |
+
|
| 1031 |
+
FEATURE: Allow setting buffer to None, which gives the whole activation sequence ([`f5f9594`](https://github.com/jbloomAus/SAEDashboard/commit/f5f9594fcaf5edb6036a85446e092278004ea200))
|
| 1032 |
+
|
| 1033 |
+
* 16 ([`64e7018`](https://github.com/jbloomAus/SAEDashboard/commit/64e701849570d9e172dc065812c9a3e7149a9176))
|
| 1034 |
+
|
| 1035 |
+
* version 0.2.16 ([`afca0be`](https://github.com/jbloomAus/SAEDashboard/commit/afca0be8826e0c007b5730fa9fa18454699d16a3))
|
| 1036 |
+
|
| 1037 |
+
* Fix version ([`5a43916`](https://github.com/jbloomAus/SAEDashboard/commit/5a43916cbd9836396f051f7a258fdca8664e05e9))
|
| 1038 |
+
|
| 1039 |
+
* fix all indices view ([`5f87d52`](https://github.com/jbloomAus/SAEDashboard/commit/5f87d52154d6a8e8c8984836bbe8f85ee25f279d))
|
| 1040 |
+
|
| 1041 |
+
* Merge branch 'fix_gpt2_demo' into allow_disable_buffer ([`ea57bfc`](https://github.com/jbloomAus/SAEDashboard/commit/ea57bfc2ee1e23666810982abf32e6e9cbb74193))
|
| 1042 |
+
|
| 1043 |
+
* Allow disabling the buffer ([`c1be9f8`](https://github.com/jbloomAus/SAEDashboard/commit/c1be9f8e4b8ee6d8f18c4a1a0445840304440c1d))
|
| 1044 |
+
|
| 1045 |
+
* fix conflicts ([`ea3d624`](https://github.com/jbloomAus/SAEDashboard/commit/ea3d624013b9aa7cbd2d6eaa7212a1f7c4ee8e28))
|
| 1046 |
+
|
| 1047 |
+
* Merge pull request #35 from callummcdougall/fix_gpt2_demo
|
| 1048 |
+
|
| 1049 |
+
Fix usage of SAELens and demo notebook ([`88b5933`](https://github.com/jbloomAus/SAEDashboard/commit/88b59338d3cadbd5c70f0c1117dff00f01a54e6a))
|
| 1050 |
+
|
| 1051 |
+
* Import updated SAELens, use correct tokens, fix missing file cfg.json file error. ([`14ba9b0`](https://github.com/jbloomAus/SAEDashboard/commit/14ba9b03d4ce791ba8f4cac553fb82a93c47dfb8))
|
| 1052 |
+
|
| 1053 |
+
* Merge pull request #34 from ArthurConmy/patch-1
|
| 1054 |
+
|
| 1055 |
+
Update README.md ([`3faac82`](https://github.com/jbloomAus/SAEDashboard/commit/3faac82686f546800492d8aeb5e1d5919cbf1517))
|
| 1056 |
+
|
| 1057 |
+
* Update README.md ([`416eca8`](https://github.com/jbloomAus/SAEDashboard/commit/416eca8073c6cb2b120c759330ec47f52ab32d1e))
|
| 1058 |
+
|
| 1059 |
+
* Merge pull request #33 from chanind/setup-poetry-and-ruff
|
| 1060 |
+
|
| 1061 |
+
Setting up poetry / ruff / github actions ([`287f30f`](https://github.com/jbloomAus/SAEDashboard/commit/287f30f1d8fc39ab583f202c9277e07e5eeeaf62))
|
| 1062 |
+
|
| 1063 |
+
* setting up poetry and ruff for linting/formatting ([`0e0eba9`](https://github.com/jbloomAus/SAEDashboard/commit/0e0eba9e4d54c746cddc835ef4f6ddf2bab96844))
|
| 1064 |
+
|
| 1065 |
+
* fix feature vis demo gpt ([`821781e`](https://github.com/jbloomAus/SAEDashboard/commit/821781e96b732a5909d8735714482c965891b2ea))
|
| 1066 |
+
|
| 1067 |
+
* add scatter plot support ([`6eab28b`](https://github.com/jbloomAus/SAEDashboard/commit/6eab28bef9ef5cd9360fef73e02763301fa1a028))
|
| 1068 |
+
|
| 1069 |
+
* update setup ([`8d2ca53`](https://github.com/jbloomAus/SAEDashboard/commit/8d2ca53e8a6bba860fe71368741d06a718adaa27))
|
| 1070 |
+
|
| 1071 |
+
* fix setup ([`9cae8f4`](https://github.com/jbloomAus/SAEDashboard/commit/9cae8f461bd780e23eb2d994f56b495ede16201a))
|
| 1072 |
+
|
| 1073 |
+
* Merge branch 'main' of https://github.com/callummcdougall/sae_vis ([`ed8f8cb`](https://github.com/jbloomAus/SAEDashboard/commit/ed8f8cb7ad1fba2383dcdd471c33ce4a1b9f32e3))
|
| 1074 |
+
|
| 1075 |
+
* Merge pull request #27 from wllgrnt/will-add-eindex-dependency
|
| 1076 |
+
|
| 1077 |
+
Update setup.py with eindex dependency ([`8d7ed12`](https://github.com/jbloomAus/SAEDashboard/commit/8d7ed123505ac7ecf93dd310f57888547aead1d7))
|
| 1078 |
+
|
| 1079 |
+
* two more deps ([`7f231a8`](https://github.com/jbloomAus/SAEDashboard/commit/7f231a83acfef2494c1866249f57e10c21a1a443))
|
| 1080 |
+
|
| 1081 |
+
* Update setup.py with eindex
|
| 1082 |
+
|
| 1083 |
+
Without this, 'pip install sae-vis' will cause errors when e.g. you do 'from sae_vis.data_fetching_fns import get_feature_data' ([`a9d7de9`](https://github.com/jbloomAus/SAEDashboard/commit/a9d7de90b492f7305758e15303ba890fb9b503d0))
|
| 1084 |
+
|
| 1085 |
+
* fix sae bug ([`247d14b`](https://github.com/jbloomAus/SAEDashboard/commit/247d14b55f209ed9ccf50e5ce091ed66ffbf19d2))
|
| 1086 |
+
|
| 1087 |
+
* Merge pull request #32 from hijohnnylin/pin_older_sae_training
|
| 1088 |
+
|
| 1089 |
+
Demo notebook errors under "Multi-layer models" vis ([`9ac1dac`](https://github.com/jbloomAus/SAEDashboard/commit/9ac1dac51af32909666977cb5b3794965c70f62f))
|
| 1090 |
+
|
| 1091 |
+
* Pin older commit of mats_sae_training ([`8ca7ac1`](https://github.com/jbloomAus/SAEDashboard/commit/8ca7ac14b919fedb91240630ac7072cac40a6d6a))
|
| 1092 |
+
|
| 1093 |
+
* update version number ([`72e584b`](https://github.com/jbloomAus/SAEDashboard/commit/72e584b6492ed1ef3989968f6588a17fca758650))
|
| 1094 |
+
|
| 1095 |
+
* add gifs to readme ([`1393740`](https://github.com/jbloomAus/SAEDashboard/commit/13937405da31cca70cd1027aaca6c9cc84797ff1))
|
| 1096 |
+
|
| 1097 |
+
* test gif ([`4fbafa6`](https://github.com/jbloomAus/SAEDashboard/commit/4fbafa69343dc58dc18d0f78e393b5fcc9e24c0c))
|
| 1098 |
+
|
| 1099 |
+
* fix height issue ([`3f272f6`](https://github.com/jbloomAus/SAEDashboard/commit/3f272f61a954effef7bd648cc8117346da3bb971))
|
| 1100 |
+
|
| 1101 |
+
* fix pypi ([`7151164`](https://github.com/jbloomAus/SAEDashboard/commit/7151164cc0df8af278617147f07cbfbe3977cfeb))
|
| 1102 |
+
|
| 1103 |
+
* update setup ([`8c43478`](https://github.com/jbloomAus/SAEDashboard/commit/8c43478ad2eba8d3d4106fe4239c1229b8720fe6))
|
| 1104 |
+
|
| 1105 |
+
* Merge pull request #26 from hijohnnylin/update_html_anomalies
|
| 1106 |
+
|
| 1107 |
+
Update and add some HTML_ANOMALIES ([`1874a47`](https://github.com/jbloomAus/SAEDashboard/commit/1874a47a099ce32795bdbb5f98b9167dcca85ff2))
|
| 1108 |
+
|
| 1109 |
+
* Update and add some HTML_ANOMALIES ([`c541b7f`](https://github.com/jbloomAus/SAEDashboard/commit/c541b7f06108046ad1e2eb82c89f30f061f4411e))
|
| 1110 |
+
|
| 1111 |
+
* 0.2.9 ([`a5c8a6d`](https://github.com/jbloomAus/SAEDashboard/commit/a5c8a6d2008b818db90566cba50211845c753444))
|
| 1112 |
+
|
| 1113 |
+
* fix readme ([`5a8a7e3`](https://github.com/jbloomAus/SAEDashboard/commit/5a8a7e3173fc50fdb5ff0e56d7fa83e475af38a3))
|
| 1114 |
+
|
| 1115 |
+
* include feature tables ([`7c4c263`](https://github.com/jbloomAus/SAEDashboard/commit/7c4c263a2e069482d341b6265015664792bde817))
|
| 1116 |
+
|
| 1117 |
+
* add license ([`fa02a3d`](https://github.com/jbloomAus/SAEDashboard/commit/fa02a3dc93b721322b3902e2ac416ed156bf9d80))
|
| 1118 |
+
|
| 1119 |
+
* Merge branch 'main' of https://github.com/callummcdougall/sae_vis ([`ca5efcd`](https://github.com/jbloomAus/SAEDashboard/commit/ca5efcdc81074d3c3002bd997b35e326a44a4a25))
|
| 1120 |
+
|
| 1121 |
+
* Merge pull request #24 from chanind/fix-pypi-repo-link
|
| 1122 |
+
|
| 1123 |
+
fixing repo URL in setup.py ([`14a0be5`](https://github.com/jbloomAus/SAEDashboard/commit/14a0be54a57b1bc73ac4741611f9c8d1bd229e6f))
|
| 1124 |
+
|
| 1125 |
+
* fixing repo URL in setup.py ([`4faeca5`](https://github.com/jbloomAus/SAEDashboard/commit/4faeca5da06c0bb4384e202a91d895a217365d30))
|
| 1126 |
+
|
| 1127 |
+
* re-fix html anomalies ([`2fbae4c`](https://github.com/jbloomAus/SAEDashboard/commit/2fbae4c9a7dd663737bae25e73e978d40c59064a))
|
| 1128 |
+
|
| 1129 |
+
* fix hook point bug ([`9b573b2`](https://github.com/jbloomAus/SAEDashboard/commit/9b573b27590db1cbd6c8ef08fca7ff8c9d26b340))
|
| 1130 |
+
|
| 1131 |
+
* Merge pull request #20 from chanind/fix-final-resid-layer
|
| 1132 |
+
|
| 1133 |
+
fixing bug if hook_point == hook_point_resid_final ([`d6882e3`](https://github.com/jbloomAus/SAEDashboard/commit/d6882e3f813ef0d399e07548871f61b1f6a98ac6))
|
| 1134 |
+
|
| 1135 |
+
* fixing bug using hook_point_resid_final ([`cfe9b30`](https://github.com/jbloomAus/SAEDashboard/commit/cfe9b3042cfe127d5f7958064ffe817c25a19b56))
|
| 1136 |
+
|
| 1137 |
+
* fix indexing speed ([`865ff64`](https://github.com/jbloomAus/SAEDashboard/commit/865ff64329538641cd863dc7668dfc77907fb384))
|
| 1138 |
+
|
| 1139 |
+
* enable JSON saving ([`feea47a`](https://github.com/jbloomAus/SAEDashboard/commit/feea47a342d52296b72784ed18ea628848d4c7d4))
|
| 1140 |
+
|
| 1141 |
+
* Merge pull request #19 from chanind/support-mlp-and-attn-out
|
| 1142 |
+
|
| 1143 |
+
supporting mlp and attn out hooks ([`1c5463b`](https://github.com/jbloomAus/SAEDashboard/commit/1c5463b12f85cd0598b4e2fba5c556b1e9c0fbbe))
|
| 1144 |
+
|
| 1145 |
+
* supporting mlp and attn out hooks ([`a100e58`](https://github.com/jbloomAus/SAEDashboard/commit/a100e586498e8cae14df475bc7924cdecaed71ea))
|
| 1146 |
+
|
| 1147 |
+
* Merge branch 'main' of https://github.com/callummcdougall/sae_vis ([`083aeba`](https://github.com/jbloomAus/SAEDashboard/commit/083aeba0e4048d9976ec5cbee8df7dc8fd4db4e9))
|
| 1148 |
+
|
| 1149 |
+
* Merge pull request #18 from chanind/remove-build-artifacts
|
| 1150 |
+
|
| 1151 |
+
removing Python build artifacts and adding to .gitignore ([`b0e0594`](https://github.com/jbloomAus/SAEDashboard/commit/b0e0594590b4472b34052c6eb3ebceb6c9f58a11))
|
| 1152 |
+
|
| 1153 |
+
* removing Python build artifacts and adding to .gitignore ([`b6486f5`](https://github.com/jbloomAus/SAEDashboard/commit/b6486f56bea9d4bb7544c36afe70e6f891101b63))
|
| 1154 |
+
|
| 1155 |
+
* fix variable naming ([`2507918`](https://github.com/jbloomAus/SAEDashboard/commit/25079186b3f31d2271b1ecdb11f26904af7146d2))
|
| 1156 |
+
|
| 1157 |
+
* update readme ([`0ee3608`](https://github.com/jbloomAus/SAEDashboard/commit/0ee3608af396a1a6586dfb809f2f6480bb4f6390))
|
| 1158 |
+
|
| 1159 |
+
* update readme ([`f8351f8`](https://github.com/jbloomAus/SAEDashboard/commit/f8351f88e8432ccd4b2206e859daea316304d6c6))
|
| 1160 |
+
|
| 1161 |
+
* update version number ([`1e74408`](https://github.com/jbloomAus/SAEDashboard/commit/1e7440883f44a92705299430215f802fea4e1915))
|
| 1162 |
+
|
| 1163 |
+
* fix formatting and docstrings ([`b9fe2bb`](https://github.com/jbloomAus/SAEDashboard/commit/b9fe2bbb15a48e4b0415f6f4240d895990d54c9a))
|
| 1164 |
+
|
| 1165 |
+
* Merge pull request #17 from jordansauce/sae-agnostic-functions-new
|
| 1166 |
+
|
| 1167 |
+
Added SAE class agnostic functions ([`0039c6f`](https://github.com/jbloomAus/SAEDashboard/commit/0039c6f8f99d6e8a1b2ff56aa85f60a3eba3afb0))
|
| 1168 |
+
|
| 1169 |
+
* Added sae class agnostic functions
|
| 1170 |
+
|
| 1171 |
+
Added parse_feature_data() and parse_prompt_data() ([`e2709d0`](https://github.com/jbloomAus/SAEDashboard/commit/e2709d0b4c55d73d6026f3b9ce534f59ce61f344))
|
| 1172 |
+
|
| 1173 |
+
* add to pypi ([`02a5b9a`](https://github.com/jbloomAus/SAEDashboard/commit/02a5b9acd15433cc59d438271b9bd5e12d62b662))
|
| 1174 |
+
|
| 1175 |
+
* update notebook images ([`b87ad4d`](https://github.com/jbloomAus/SAEDashboard/commit/b87ad4d256f12c23605b0e7db307ee56913c93ef))
|
| 1176 |
+
|
| 1177 |
+
* fix layer parse and custom device ([`14c7ae9`](https://github.com/jbloomAus/SAEDashboard/commit/14c7ae9d0c8b7dad21b953cfc93fe7f34c74e149))
|
| 1178 |
+
|
| 1179 |
+
* update dropdown styling ([`83be219`](https://github.com/jbloomAus/SAEDashboard/commit/83be219bfe31b985a26762e06345c574aa0e6fe1))
|
| 1180 |
+
|
| 1181 |
+
* add custom prompt vis ([`cabdc5c`](https://github.com/jbloomAus/SAEDashboard/commit/cabdc5cb31f881cddf236490c41332c525d2ee74))
|
| 1182 |
+
|
| 1183 |
+
* d3 & multifeature refactor ([`f79a919`](https://github.com/jbloomAus/SAEDashboard/commit/f79a919691862f60a9e30fe0f79fd8e771bc932a))
|
| 1184 |
+
|
| 1185 |
+
* remove readme links ([`4bcef48`](https://github.com/jbloomAus/SAEDashboard/commit/4bcef489b644dd3357b1975f3245d534f6f0d2e0))
|
| 1186 |
+
|
| 1187 |
+
* add demo html ([`629c713`](https://github.com/jbloomAus/SAEDashboard/commit/629c713345407562dc4ccd9875bf3cfab5480bdd))
|
| 1188 |
+
|
| 1189 |
+
* remove demos ([`beedea9`](https://github.com/jbloomAus/SAEDashboard/commit/beedea9667761534a5293015aff9cc17638666a5))
|
| 1190 |
+
|
| 1191 |
+
* fix quantile error ([`3a23cfd`](https://github.com/jbloomAus/SAEDashboard/commit/3a23cfd56f21fe0775a1a9957db340d15f75f51a))
|
| 1192 |
+
|
| 1193 |
+
* width 425 ([`f25c776`](https://github.com/jbloomAus/SAEDashboard/commit/f25c776d5cb746916d3f2fdf368cbd5448742949))
|
| 1194 |
+
|
| 1195 |
+
* fix device bug ([`85dfa49`](https://github.com/jbloomAus/SAEDashboard/commit/85dfa497bc804945911e80607ac31cf3afbdc759))
|
| 1196 |
+
|
| 1197 |
+
* dont return vocab dict ([`b4c7138`](https://github.com/jbloomAus/SAEDashboard/commit/b4c713873870acb4035986cc5bff3a4ce1e466c9))
|
| 1198 |
+
|
| 1199 |
+
* save as JSON, fix device ([`eba2cff`](https://github.com/jbloomAus/SAEDashboard/commit/eba2cff3eb6215558577a6b4d4f8cc716766b927))
|
| 1200 |
+
|
| 1201 |
+
* simple fixed and issues ([`b28a0f7`](https://github.com/jbloomAus/SAEDashboard/commit/b28a0f7c7e936f4bea05528d952dfcd438533cce))
|
| 1202 |
+
|
| 1203 |
+
* Merge pull request #8 from lucyfarnik/topk-empty-mask
|
| 1204 |
+
|
| 1205 |
+
Topk error handling for empty masks ([`2740c00`](https://github.com/jbloomAus/SAEDashboard/commit/2740c0047e78df7e56d7bcf707c909ac18e71c1f))
|
| 1206 |
+
|
| 1207 |
+
* Topk error handling for empty masks ([`1c2627e`](https://github.com/jbloomAus/SAEDashboard/commit/1c2627e237f8f67725fc44e60a190bc141d36fc8))
|
| 1208 |
+
|
| 1209 |
+
* viz to vis ([`216d02b`](https://github.com/jbloomAus/SAEDashboard/commit/216d02b550d6fbcb9b37d39c1b272a7dda91aadc))
|
| 1210 |
+
|
| 1211 |
+
* update readme links ([`f9b3f95`](https://github.com/jbloomAus/SAEDashboard/commit/f9b3f95e31e7150024be27ec62246f43bf9bcbb8))
|
| 1212 |
+
|
| 1213 |
+
* update for TL ([`1941db1`](https://github.com/jbloomAus/SAEDashboard/commit/1941db1e22093d6fc88fb3fcd6f4c7d535d8b3b4))
|
| 1214 |
+
|
| 1215 |
+
* Merge pull request #5 from lucyfarnik/transformer-lens-models
|
| 1216 |
+
|
| 1217 |
+
Compatibility with TransformerLens models ([`8d59c6c`](https://github.com/jbloomAus/SAEDashboard/commit/8d59c6c5a5f2b98c486e5c74130371ad9254d1c9))
|
| 1218 |
+
|
| 1219 |
+
* Merge branch 'main' into transformer-lens-models ([`73057d7`](https://github.com/jbloomAus/SAEDashboard/commit/73057d7e2a3e4e9669fc0556e64190811ac8b52d))
|
| 1220 |
+
|
| 1221 |
+
* Merge pull request #4 from lucyfarnik/resid-saes-support
|
| 1222 |
+
|
| 1223 |
+
Added support for residual-adjacent SAEs ([`b02e98b`](https://github.com/jbloomAus/SAEDashboard/commit/b02e98b3b852c0613a890f8949d04b5560fb6fd6))
|
| 1224 |
+
|
| 1225 |
+
* Added support for residual-adjacent SAEs ([`89aacf1`](https://github.com/jbloomAus/SAEDashboard/commit/89aacf1b22aa81b393b10eca8611c9dbf406c638))
|
| 1226 |
+
|
| 1227 |
+
* Merge pull request #7 from lucyfarnik/fix-histogram-div-zero
|
| 1228 |
+
|
| 1229 |
+
Fixed division by zero in histogram calculation ([`3aee20e`](https://github.com/jbloomAus/SAEDashboard/commit/3aee20ea7f99cc07e6c5085fddb70cadd8327f4d))
|
| 1230 |
+
|
| 1231 |
+
* Fixed division by zero in histogram calculation ([`e986e90`](https://github.com/jbloomAus/SAEDashboard/commit/e986e907cc42790efc93ce75ebf7b28a0278aaa2))
|
| 1232 |
+
|
| 1233 |
+
* Merge pull request #6 from lucyfarnik/handling-dead-features
|
| 1234 |
+
|
| 1235 |
+
Edge case handling for dead features ([`9e43c30`](https://github.com/jbloomAus/SAEDashboard/commit/9e43c308e58769828234e1505f1c1102ba651dfd))
|
| 1236 |
+
|
| 1237 |
+
* Edge case handling for dead features ([`5197aee`](https://github.com/jbloomAus/SAEDashboard/commit/5197aee2c9f92bce7c5fd6d22201152a68c2e6ca))
|
| 1238 |
+
|
| 1239 |
+
* add features argument ([`f24ef7e`](https://github.com/jbloomAus/SAEDashboard/commit/f24ef7ebebb3d4fd92e299858dbd5b968b78c69e))
|
| 1240 |
+
|
| 1241 |
+
* fix image link ([`22c8734`](https://github.com/jbloomAus/SAEDashboard/commit/22c873434dfa84e3aed5ee0aab0fd25b288428a6))
|
| 1242 |
+
|
| 1243 |
+
* Merge pull request #1 from lucyfarnik/read-me-links-fix
|
| 1244 |
+
|
| 1245 |
+
Fixed readme links pointing to the old colab ([`86f8e20`](https://github.com/jbloomAus/SAEDashboard/commit/86f8e2012e376b6c498e5e708324f812af6fbc98))
|
| 1246 |
+
|
| 1247 |
+
* Fixed readme links pointing to the old colab ([`28ef1cb`](https://github.com/jbloomAus/SAEDashboard/commit/28ef1cbd1b91f6c09c842f48e1f997d189ca04e7))
|
| 1248 |
+
|
| 1249 |
+
* Added readme section about models ([`7523e7f`](https://github.com/jbloomAus/SAEDashboard/commit/7523e7f6363e030196496b3c6a3dc70b234c2d9a))
|
| 1250 |
+
|
| 1251 |
+
* Compatibility with TransformerLens models ([`ba708e9`](https://github.com/jbloomAus/SAEDashboard/commit/ba708e987be6cc7a09d34ea8fb83de009312684d))
|
| 1252 |
+
|
| 1253 |
+
* Added support for MPS ([`196c0a2`](https://github.com/jbloomAus/SAEDashboard/commit/196c0a24d0e8277b327eb2d57662075f9106990b))
|
| 1254 |
+
|
| 1255 |
+
* black font ([`d81e74d`](https://github.com/jbloomAus/SAEDashboard/commit/d81e74d575326ef786881fb9182a768f9de2cb70))
|
| 1256 |
+
|
| 1257 |
+
* fix html bug ([`265dedd`](https://github.com/jbloomAus/SAEDashboard/commit/265dedd376991230e2041fd37d5b6a0eda048545))
|
| 1258 |
+
|
| 1259 |
+
* add jax and dataset deps ([`f1caeaf`](https://github.com/jbloomAus/SAEDashboard/commit/f1caeafc9613e27c7663447cf862301ac11d842d))
|
| 1260 |
+
|
| 1261 |
+
* remove TL dependency ([`155991f`](https://github.com/jbloomAus/SAEDashboard/commit/155991fe61d0199d081d344ac44996edce35d118))
|
| 1262 |
+
|
| 1263 |
+
* first commit ([`7782eb6`](https://github.com/jbloomAus/SAEDashboard/commit/7782eb6d5058372630c5bbb8693eb540a7bceaf4))
|
SAEDashboard/Dockerfile
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# docker build --target development -t decoderesearch/saedashboard-cuda --file Dockerfile .
|
| 2 |
+
# docker run --entrypoint /bin/bash -it decoderesearch/saedashboard-cuda
|
| 3 |
+
|
| 4 |
+
ARG APP_NAME=sae_dashboard
|
| 5 |
+
ARG APP_PATH=/opt/$APP_NAME
|
| 6 |
+
ARG PYTHON_VERSION=3.12.2
|
| 7 |
+
ARG POETRY_VERSION=1.8.3
|
| 8 |
+
|
| 9 |
+
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel AS staging
|
| 10 |
+
ARG APP_NAME
|
| 11 |
+
ARG APP_PATH
|
| 12 |
+
ARG POETRY_VERSION
|
| 13 |
+
|
| 14 |
+
ENV \
|
| 15 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 16 |
+
PYTHONUNBUFFERED=1 \
|
| 17 |
+
PYTHONFAULTHANDLER=1
|
| 18 |
+
ENV \
|
| 19 |
+
POETRY_VERSION=$POETRY_VERSION \
|
| 20 |
+
POETRY_HOME="/opt/poetry" \
|
| 21 |
+
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
| 22 |
+
POETRY_NO_INTERACTION=1
|
| 23 |
+
|
| 24 |
+
RUN apt-get update && apt-get install --no-install-recommends -y curl git-lfs vim && rm -rf /var/lib/apt/lists/*
|
| 25 |
+
|
| 26 |
+
RUN curl -sSL https://install.python-poetry.org | python
|
| 27 |
+
ENV PATH="$POETRY_HOME/bin:$PATH"
|
| 28 |
+
|
| 29 |
+
WORKDIR $APP_PATH
|
| 30 |
+
COPY ./pyproject.toml ./README.md ./
|
| 31 |
+
COPY ./$APP_NAME ./$APP_NAME
|
| 32 |
+
|
| 33 |
+
FROM staging AS development
|
| 34 |
+
ARG APP_NAME
|
| 35 |
+
ARG APP_PATH
|
| 36 |
+
|
| 37 |
+
WORKDIR $APP_PATH
|
| 38 |
+
RUN poetry lock
|
| 39 |
+
RUN poetry install --no-dev --no-cache
|
| 40 |
+
|
| 41 |
+
RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash
|
| 42 |
+
RUN git lfs install
|
| 43 |
+
|
| 44 |
+
ENTRYPOINT ["/bin/bash"]
|
| 45 |
+
CMD ["poetry", "shell"]
|
SAEDashboard/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Decode Research
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
SAEDashboard/Makefile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
format:
|
| 2 |
+
poetry run black .
|
| 3 |
+
poetry run isort .
|
| 4 |
+
|
| 5 |
+
lint:
|
| 6 |
+
poetry run flake8 .
|
| 7 |
+
poetry run black --check .
|
| 8 |
+
poetry run isort --check-only --diff .
|
| 9 |
+
|
| 10 |
+
check-type:
|
| 11 |
+
poetry run pyright .
|
| 12 |
+
|
| 13 |
+
test:
|
| 14 |
+
poetry run pytest --cov=sae_dashboard --cov-report=term-missing tests/unit
|
| 15 |
+
|
| 16 |
+
check-ci:
|
| 17 |
+
make format
|
| 18 |
+
make lint
|
| 19 |
+
make check-type
|
| 20 |
+
make test
|
| 21 |
+
|
| 22 |
+
profile-memory-unit:
|
| 23 |
+
poetry run pytest --memray tests/unit
|
| 24 |
+
|
| 25 |
+
profile-speed-unit:
|
| 26 |
+
poetry run py.test tests/unit --profile-svg -k "test_SaeVisData_create_results_look_reasonable[Default]"
|
| 27 |
+
open prof/combined.svg
|
SAEDashboard/README.md
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAEDashboard
|
| 2 |
+
|
| 3 |
+
SAEDashboard is a tool for visualizing and analyzing Sparse Autoencoders (SAEs) in neural networks. This repository is an adaptation and extension of Callum McDougal's [SAEVis](https://github.com/callummcdougall/sae_vis/tree/main), providing enhanced functionality for feature visualization and analysis as well as feature dashboard creation at scale.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
This codebase was originally designed to replicate Anthropic's sparse autoencoder visualizations, which you can see [here](https://transformer-circuits.pub/2023/monosemantic-features/vis/a1.html). SAEDashboard primarily provides visualizations of features, including their activations, logits, and correlations--similar to what is shown in the Anthropic link.
|
| 8 |
+
|
| 9 |
+
<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/feature-vis-video.gif" width="800">
|
| 10 |
+
|
| 11 |
+
## Features
|
| 12 |
+
|
| 13 |
+
- Customizable dashboards with various plots and data representations for SAE features
|
| 14 |
+
- Support for any SAE in the SAELens library
|
| 15 |
+
- Neuronpedia integration for hosting and comprehensive neuron analysis (note: this requires a Neuronpedia account and is currently only used internally)
|
| 16 |
+
- Ability to handle large datasets and models efficiently
|
| 17 |
+
|
| 18 |
+
## Installation
|
| 19 |
+
|
| 20 |
+
Install SAEDashboard using pip:
|
| 21 |
+
```bash
|
| 22 |
+
pip install sae-dashboard
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
## Quick Start
|
| 27 |
+
|
| 28 |
+
Here's a basic example of how to use SAEDashboard with SaeVisRunner:
|
| 29 |
+
|
| 30 |
+
```python
|
| 31 |
+
from sae_lens import SAE
|
| 32 |
+
from transformer_lens import HookedTransformer
|
| 33 |
+
from sae_dashboard.sae_vis_data import SaeVisConfig
|
| 34 |
+
from sae_dashboard.sae_vis_runner import SaeVisRunner
|
| 35 |
+
|
| 36 |
+
# Load model and SAE
|
| 37 |
+
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda", dtype="bfloat16")
|
| 38 |
+
sae, _, _ = SAE.from_pretrained(
|
| 39 |
+
release="gpt2-small-res-jb",
|
| 40 |
+
sae_id="blocks.6.hook_resid_pre",
|
| 41 |
+
device="cuda"
|
| 42 |
+
)
|
| 43 |
+
sae.fold_W_dec_norm()
|
| 44 |
+
|
| 45 |
+
# Configure visualization
|
| 46 |
+
config = SaeVisConfig(
|
| 47 |
+
hook_point=sae.cfg.hook_name,
|
| 48 |
+
features=list(range(256)),
|
| 49 |
+
minibatch_size_features=64,
|
| 50 |
+
minibatch_size_tokens=256,
|
| 51 |
+
device="cuda",
|
| 52 |
+
dtype="bfloat16"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Generate data
|
| 56 |
+
data = SaeVisRunner(config).run(encoder=sae, model=model, tokens=your_token_dataset)
|
| 57 |
+
|
| 58 |
+
# Save feature-centric visualization
|
| 59 |
+
from sae_dashboard.data_writing_fns import save_feature_centric_vis
|
| 60 |
+
save_feature_centric_vis(sae_vis_data=data, filename="feature_dashboard.html")
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
For a more detailed tutorial, check out our [demo notebook](https://colab.research.google.com/drive/1oqDS35zibmL1IUQrk_OSTxdhcGrSS6yO?usp=drive_link).
|
| 64 |
+
|
| 65 |
+
## Advanced Usage: Neuronpedia Runner
|
| 66 |
+
|
| 67 |
+
For internal use or advanced analysis, SAEDashboard provides a Neuronpedia runner that generates data compatible with Neuronpedia. Here's a basic example:
|
| 68 |
+
|
| 69 |
+
```python
|
| 70 |
+
from sae_dashboard.neuronpedia.neuronpedia_runner_config import NeuronpediaRunnerConfig
|
| 71 |
+
from sae_dashboard.neuronpedia.neuronpedia_runner import NeuronpediaRunner
|
| 72 |
+
|
| 73 |
+
config = NeuronpediaRunnerConfig(
|
| 74 |
+
sae_set="your_sae_set",
|
| 75 |
+
sae_path="path/to/sae",
|
| 76 |
+
np_set_name="your_neuronpedia_set_name",
|
| 77 |
+
huggingface_dataset_path="dataset/path",
|
| 78 |
+
n_prompts_total=1000,
|
| 79 |
+
n_features_at_a_time=64
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
runner = NeuronpediaRunner(config)
|
| 83 |
+
runner.run()
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
For more options and detailed configuration, refer to the `NeuronpediaRunnerConfig` class in the code.
|
| 87 |
+
|
| 88 |
+
## Cross-Layer Transcoder (CLT) Support
|
| 89 |
+
|
| 90 |
+
SAEDashboard now supports visualization of Cross-Layer Transcoders (CLTs), which are a variant of SAEs that process activations across transformer layers. To use CLT visualization:
|
| 91 |
+
|
| 92 |
+
### Required Files
|
| 93 |
+
|
| 94 |
+
When using a CLT model, you'll need these files in your CLT model directory:
|
| 95 |
+
|
| 96 |
+
1. **Model weights**: A `.safetensors` or `.pt` file containing the CLT weights
|
| 97 |
+
2. **Configuration**: A `cfg.json` file with the CLT configuration, including:
|
| 98 |
+
- `num_features`: Number of features in the CLT
|
| 99 |
+
- `num_layers`: Number of transformer layers
|
| 100 |
+
- `d_model`: Model dimension
|
| 101 |
+
- `activation_fn`: Activation function (e.g., "jumprelu", "relu")
|
| 102 |
+
- `normalization_method`: How inputs are normalized (e.g., "mean_std", "none")
|
| 103 |
+
- `tl_input_template`: TransformerLens hook template (e.g., "blocks.{}.ln2.hook_normalized"). Note that this will usually differ from the hook name in the model's cfg.json, which is based on NNsight/transformers. You will need to find the corresponding TransformerLens hook name.
|
| 104 |
+
3. **Normalization statistics** (if `normalization_method` is "mean_std"): A `norm_stats.json` file containing the mean and standard deviation for each layer's inputs, generated from the dataset when activations were generated (or afterwards). The file should have this structure:
|
| 105 |
+
```json
|
| 106 |
+
{
|
| 107 |
+
"0": {
|
| 108 |
+
"inputs": {
|
| 109 |
+
"mean": [0.1, -0.2, ...], // Array of d_model values
|
| 110 |
+
"std": [1.0, 0.9, ...] // Array of d_model values
|
| 111 |
+
}
|
| 112 |
+
},
|
| 113 |
+
"1": {
|
| 114 |
+
"inputs": {
|
| 115 |
+
"mean": [...],
|
| 116 |
+
"std": [...]
|
| 117 |
+
}
|
| 118 |
+
},
|
| 119 |
+
// ... entries for each layer
|
| 120 |
+
}
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### Example Usage
|
| 124 |
+
|
| 125 |
+
```python
|
| 126 |
+
from sae_dashboard.neuronpedia.neuronpedia_runner_config import NeuronpediaRunnerConfig
|
| 127 |
+
from sae_dashboard.neuronpedia.neuronpedia_runner import NeuronpediaRunner
|
| 128 |
+
|
| 129 |
+
config = NeuronpediaRunnerConfig(
|
| 130 |
+
sae_set="your_clt_set",
|
| 131 |
+
sae_path="/path/to/clt/model/directory", # Directory containing the files above
|
| 132 |
+
model_id="gpt2", # Base model the CLT was trained on
|
| 133 |
+
outputs_dir="clt_outputs",
|
| 134 |
+
huggingface_dataset_path="your/dataset",
|
| 135 |
+
use_clt=True, # Enable CLT mode
|
| 136 |
+
clt_layer_idx=5, # Which layer to visualize (0-indexed)
|
| 137 |
+
clt_weights_filename="model.safetensors", # Optional: specify exact weights file
|
| 138 |
+
n_prompts_total=1000,
|
| 139 |
+
n_features_at_a_time=64
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
runner = NeuronpediaRunner(config)
|
| 143 |
+
runner.run()
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
### Notes on CLT Support
|
| 147 |
+
|
| 148 |
+
- CLTs must be loaded from local files (HuggingFace Hub loading not yet supported)
|
| 149 |
+
- The `--use-clt` flag is mutually exclusive with `--use-transcoder` and `--use-skip-transcoder`
|
| 150 |
+
- JumpReLU activation functions with learned thresholds are supported
|
| 151 |
+
- The visualization will show features for the specified layer only
|
| 152 |
+
|
| 153 |
+
## Configuration Options
|
| 154 |
+
|
| 155 |
+
SAEDashboard offers a wide range of configuration options for both SaeVisRunner and NeuronpediaRunner. Key options include:
|
| 156 |
+
|
| 157 |
+
- `hook_point`: The layer to analyze in the model
|
| 158 |
+
- `features`: List of feature indices to visualize
|
| 159 |
+
- `minibatch_size_features`: Number of features to process in each batch
|
| 160 |
+
- `minibatch_size_tokens`: Number of tokens to process in each forward pass
|
| 161 |
+
- `device`: Computation device (e.g., "cuda", "cpu")
|
| 162 |
+
- `dtype`: Data type for computations
|
| 163 |
+
- `sparsity_threshold`: Threshold for feature sparsity (Neuronpedia runner)
|
| 164 |
+
- `n_prompts_total`: Total number of prompts to analyze
|
| 165 |
+
- `use_wandb`: Enable logging with Weights & Biases
|
| 166 |
+
|
| 167 |
+
Refer to `SaeVisConfig` and `NeuronpediaRunnerConfig` for full lists of options.
|
| 168 |
+
|
| 169 |
+
## Contributing
|
| 170 |
+
|
| 171 |
+
This project uses [Poetry](https://python-poetry.org/) for dependency management. After cloning the repo, install dependencies with `poetry lock && poetry install`.
|
| 172 |
+
|
| 173 |
+
We welcome contributions to SAEDashboard! Please follow these steps:
|
| 174 |
+
|
| 175 |
+
1. Fork the repository
|
| 176 |
+
2. Create a new branch for your feature
|
| 177 |
+
3. Implement your changes
|
| 178 |
+
4. Run tests and checks:
|
| 179 |
+
- Use `make format` to format your code
|
| 180 |
+
- Use `make check-ci` to run all checks and tests
|
| 181 |
+
5. Submit a pull request
|
| 182 |
+
|
| 183 |
+
Ensure your code passes all checks, including:
|
| 184 |
+
- Black and Flake8 for formatting and linting
|
| 185 |
+
- Pyright for type-checking
|
| 186 |
+
- Pytest for tests
|
| 187 |
+
|
| 188 |
+
## Citing This Work
|
| 189 |
+
|
| 190 |
+
To cite SAEDashboard in your research, please use the following BibTeX entry:
|
| 191 |
+
|
| 192 |
+
```bibtex
|
| 193 |
+
@misc{sae_dashboard,
|
| 194 |
+
title = {{SAE Dashboard}},
|
| 195 |
+
author = {Decode Research},
|
| 196 |
+
howpublished = {\url{https://github.com/jbloomAus/sae-dashboard}},
|
| 197 |
+
year = {2024}
|
| 198 |
+
}
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
## License
|
| 202 |
+
|
| 203 |
+
SAE Dashboard is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
|
| 204 |
+
|
| 205 |
+
## Acknowledgment and Citation
|
| 206 |
+
|
| 207 |
+
This project is based on the work by Callum McDougall. If you use SAEDashboard in your research, please cite the original SAEVis project as well:
|
| 208 |
+
|
| 209 |
+
```bibtex
|
| 210 |
+
@misc{sae_vis,
|
| 211 |
+
title = {{SAE Visualizer}},
|
| 212 |
+
author = {Callum McDougall},
|
| 213 |
+
howpublished = {\url{https://github.com/callummcdougall/sae_vis}},
|
| 214 |
+
year = {2024}
|
| 215 |
+
}
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
## Contact
|
| 219 |
+
|
| 220 |
+
For questions or support, please [open an issue](https://github.com/your-username/sae-dashboard/issues) on our GitHub repository.
|
| 221 |
+
|
SAEDashboard/docker/docker-entrypoint.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
|
| 3 |
+
set -e
|
| 4 |
+
|
| 5 |
+
# activate our virtual environment here
|
| 6 |
+
. /opt/pysetup/.venv/bin/activate
|
| 7 |
+
|
| 8 |
+
# You can put other setup logic here
|
| 9 |
+
|
| 10 |
+
# Evaluating passed command:
|
| 11 |
+
exec "$@"
|
SAEDashboard/docker/docker-hub.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Build and Push Docker Image to Docker Hub
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [ "main" ]
|
| 6 |
+
pull_request:
|
| 7 |
+
branches: [ "main" ]
|
| 8 |
+
|
| 9 |
+
env:
|
| 10 |
+
REGISTRY: docker.io
|
| 11 |
+
IMAGE_NAME: decoderesearch/saedashboard-cuda
|
| 12 |
+
|
| 13 |
+
jobs:
|
| 14 |
+
|
| 15 |
+
build:
|
| 16 |
+
|
| 17 |
+
runs-on: ubuntu-latest
|
| 18 |
+
|
| 19 |
+
steps:
|
| 20 |
+
- uses: actions/checkout@v3
|
| 21 |
+
- name: Build the Docker image
|
| 22 |
+
run: docker build --target development -t ${{ env.IMAGE_NAME }} --file Dockerfile .
|
| 23 |
+
# test:
|
| 24 |
+
# runs-on: ubuntu-latest
|
| 25 |
+
# steps:
|
| 26 |
+
# - uses: actions/checkout@v2
|
| 27 |
+
# - name: Test the Docker image
|
| 28 |
+
# run: docker-compose up -d
|
| 29 |
+
push_to_registry:
|
| 30 |
+
name: Push Docker image to Docker Hub
|
| 31 |
+
runs-on: ubuntu-latest
|
| 32 |
+
steps:
|
| 33 |
+
- name: Check out the repo
|
| 34 |
+
uses: actions/checkout@v3
|
| 35 |
+
|
| 36 |
+
- name: Set up Docker Buildx
|
| 37 |
+
uses: docker/setup-buildx-action@v2
|
| 38 |
+
|
| 39 |
+
- name: Log in to Docker Hub
|
| 40 |
+
uses: docker/login-action@v3
|
| 41 |
+
with:
|
| 42 |
+
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
| 43 |
+
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
| 44 |
+
|
| 45 |
+
- name: Extract metadata (tags, labels) for Docker
|
| 46 |
+
id: meta
|
| 47 |
+
uses: docker/metadata-action@v5
|
| 48 |
+
with:
|
| 49 |
+
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
| 50 |
+
|
| 51 |
+
- name: Build and push Docker image
|
| 52 |
+
uses: docker/build-push-action@v2
|
| 53 |
+
with:
|
| 54 |
+
context: "{{defaultContext}}"
|
| 55 |
+
push: true
|
| 56 |
+
tags: ${{ steps.meta.outputs.tags }}
|
| 57 |
+
labels: ${{ steps.meta.outputs.labels }}
|
SAEDashboard/neuronpedia_vector_pipeline_demo.ipynb
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stdout",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"Number of vectors: 1\n",
|
| 13 |
+
"Vector dimension: 768\n",
|
| 14 |
+
"Vector names: ['sentiment_vector']\n"
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
+
"source": [
|
| 19 |
+
"# Example usage\n",
|
| 20 |
+
"import json\n",
|
| 21 |
+
"import torch\n",
|
| 22 |
+
"from pathlib import Path\n",
|
| 23 |
+
"from sae_dashboard.neuronpedia.vector_set import VectorSet\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"# Load vector from file. Note that the vectors should be stored in this format, as a list of lists of floats:\n",
|
| 27 |
+
"# {\n",
|
| 28 |
+
"# \"vectors\": [\n",
|
| 29 |
+
"# [vector_1],\n",
|
| 30 |
+
"# [vector_2],\n",
|
| 31 |
+
"# ...\n",
|
| 32 |
+
"# ]\n",
|
| 33 |
+
"# }\n",
|
| 34 |
+
"json_path = Path(\"test_vectors/logistic_direction.json\")\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"# Load the vector into a VectorSet\n",
|
| 37 |
+
"vector_set = VectorSet.from_json(\n",
|
| 38 |
+
" json_path=json_path,\n",
|
| 39 |
+
" d_model=768, # Example dimension for GPT-2 Small\n",
|
| 40 |
+
" hook_point=\"blocks.7.hook_resid_pre\",\n",
|
| 41 |
+
" hook_layer=7,\n",
|
| 42 |
+
" model_name=\"gpt2\",\n",
|
| 43 |
+
" names=[\"sentiment_vector\"], # Optional custom name\n",
|
| 44 |
+
")\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"# Now you can use the vector set\n",
|
| 47 |
+
"print(f\"Number of vectors: {vector_set.vectors.shape[0]}\")\n",
|
| 48 |
+
"print(f\"Vector dimension: {vector_set.vectors.shape[1]}\")\n",
|
| 49 |
+
"print(f\"Vector names: {vector_set.names}\")"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": 3,
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"outputs": [],
|
| 57 |
+
"source": [
|
| 58 |
+
"# You can also save and load the vector set as a VectorSet object as opposed to a simple list of lists of floats\n",
|
| 59 |
+
"vector_set.save(Path(\"test_vectors/logistic_direction_vector_set.json\"))\n",
|
| 60 |
+
"vector_set = VectorSet.load(Path(\"test_vectors/logistic_direction_vector_set.json\"))"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"cell_type": "code",
|
| 65 |
+
"execution_count": 4,
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"outputs": [],
|
| 68 |
+
"source": [
|
| 69 |
+
"from sae_dashboard.neuronpedia.neuronpedia_vector_runner import (\n",
|
| 70 |
+
" NeuronpediaVectorRunner,\n",
|
| 71 |
+
" NeuronpediaVectorRunnerConfig,\n",
|
| 72 |
+
")\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"cfg = NeuronpediaVectorRunnerConfig(\n",
|
| 75 |
+
" outputs_dir=\"test_outputs/\",\n",
|
| 76 |
+
" huggingface_dataset_path=\"monology/pile-uncopyrighted\",\n",
|
| 77 |
+
" vector_dtype=\"float32\",\n",
|
| 78 |
+
" model_dtype=\"float32\",\n",
|
| 79 |
+
" # Small test settings\n",
|
| 80 |
+
" n_prompts_total=16384,\n",
|
| 81 |
+
" n_tokens_in_prompt=128, # Shorter sequences\n",
|
| 82 |
+
" n_prompts_in_forward_pass=256,\n",
|
| 83 |
+
" n_vectors_at_a_time=1,\n",
|
| 84 |
+
" use_wandb=False, # Disable wandb for testing\n",
|
| 85 |
+
")"
|
| 86 |
+
]
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "code",
|
| 90 |
+
"execution_count": 5,
|
| 91 |
+
"metadata": {},
|
| 92 |
+
"outputs": [
|
| 93 |
+
{
|
| 94 |
+
"name": "stdout",
|
| 95 |
+
"output_type": "stream",
|
| 96 |
+
"text": [
|
| 97 |
+
"Device Count: 1\n",
|
| 98 |
+
"Using specified vector dtype: float32\n",
|
| 99 |
+
"SAE Device: mps\n",
|
| 100 |
+
"Model Device: mps\n",
|
| 101 |
+
"Model Num Devices: 1\n",
|
| 102 |
+
"Activation Store Device: mps\n",
|
| 103 |
+
"Dataset Path: monology/pile-uncopyrighted\n",
|
| 104 |
+
"Forward Pass size: 128\n",
|
| 105 |
+
"Total number of tokens: 2097152\n",
|
| 106 |
+
"Total number of contexts (prompts): 16384\n",
|
| 107 |
+
"Vector DType: float32\n",
|
| 108 |
+
"Model DType: float32\n"
|
| 109 |
+
]
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"name": "stderr",
|
| 113 |
+
"output_type": "stream",
|
| 114 |
+
"text": [
|
| 115 |
+
"/Users/curttigges/miniconda3/envs/sae-d/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
|
| 116 |
+
" warnings.warn(\n"
|
| 117 |
+
]
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"name": "stdout",
|
| 121 |
+
"output_type": "stream",
|
| 122 |
+
"text": [
|
| 123 |
+
"Loaded pretrained model gpt2 into HookedTransformer\n"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"data": {
|
| 128 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 129 |
+
"model_id": "f1a49eee02cd482e9de6deaa88e4afde",
|
| 130 |
+
"version_major": 2,
|
| 131 |
+
"version_minor": 0
|
| 132 |
+
},
|
| 133 |
+
"text/plain": [
|
| 134 |
+
"Resolving data files: 0%| | 0/30 [00:00<?, ?it/s]"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
"metadata": {},
|
| 138 |
+
"output_type": "display_data"
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"name": "stdout",
|
| 142 |
+
"output_type": "stream",
|
| 143 |
+
"text": [
|
| 144 |
+
"Warning: Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info.\n",
|
| 145 |
+
"Tokens don't exist, making them.\n"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"name": "stderr",
|
| 150 |
+
"output_type": "stream",
|
| 151 |
+
"text": [
|
| 152 |
+
" 0%| | 0/2048 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (3180 > 1024). Running this sequence through the model will result in indexing errors\n",
|
| 153 |
+
"100%|██████████| 2048/2048 [00:18<00:00, 108.67it/s]\n",
|
| 154 |
+
"0it [00:00, ?it/s]"
|
| 155 |
+
]
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"name": "stdout",
|
| 159 |
+
"output_type": "stream",
|
| 160 |
+
"text": [
|
| 161 |
+
"========== Running Batch #0 ==========\n"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"data": {
|
| 166 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 167 |
+
"model_id": "418cf3ccf15d4ae597e06d24e4c89b11",
|
| 168 |
+
"version_major": 2,
|
| 169 |
+
"version_minor": 0
|
| 170 |
+
},
|
| 171 |
+
"text/plain": [
|
| 172 |
+
"Forward passes to cache data for vis: 0%| | 0/60 [00:00<?, ?it/s]"
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
"metadata": {},
|
| 176 |
+
"output_type": "display_data"
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"data": {
|
| 180 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 181 |
+
"model_id": "2a0ddbda94d0407598edf564b4487407",
|
| 182 |
+
"version_major": 2,
|
| 183 |
+
"version_minor": 0
|
| 184 |
+
},
|
| 185 |
+
"text/plain": [
|
| 186 |
+
"Extracting vis data from cached data: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 187 |
+
]
|
| 188 |
+
},
|
| 189 |
+
"metadata": {},
|
| 190 |
+
"output_type": "display_data"
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"name": "stderr",
|
| 194 |
+
"output_type": "stream",
|
| 195 |
+
"text": [
|
| 196 |
+
"/Users/curttigges/Projects/SAEDashboard/sae_dashboard/vector_data_generator.py:205: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
| 197 |
+
" return torch.load(\n"
|
| 198 |
+
]
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"name": "stdout",
|
| 202 |
+
"output_type": "stream",
|
| 203 |
+
"text": [
|
| 204 |
+
"feature_indices: [0]\n"
|
| 205 |
+
]
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"data": {
|
| 209 |
+
"text/html": [
|
| 210 |
+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━┳━━━━━━┳━━━━━━━┓\n",
|
| 211 |
+
"┃<span style=\"font-weight: bold\"> Task </span>┃<span style=\"font-weight: bold\"> Time </span>┃<span style=\"font-weight: bold\"> Pct % </span>┃\n",
|
| 212 |
+
"┡━━━━━━╇━━━━━━╇━━━━━━━┩\n",
|
| 213 |
+
"└──────┴──────┴───────┘\n",
|
| 214 |
+
"</pre>\n"
|
| 215 |
+
],
|
| 216 |
+
"text/plain": [
|
| 217 |
+
"┏━━━━━━┳━━━━━━┳━━━━━━━┓\n",
|
| 218 |
+
"┃\u001b[1m \u001b[0m\u001b[1mTask\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mTime\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mPct %\u001b[0m\u001b[1m \u001b[0m┃\n",
|
| 219 |
+
"┡━━━━━━╇━━━━━━╇━━━━━━━┩\n",
|
| 220 |
+
"└──────┴──────┴───────┘\n"
|
| 221 |
+
]
|
| 222 |
+
},
|
| 223 |
+
"metadata": {},
|
| 224 |
+
"output_type": "display_data"
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"name": "stderr",
|
| 228 |
+
"output_type": "stream",
|
| 229 |
+
"text": [
|
| 230 |
+
"1it [00:02, 2.65s/it]"
|
| 231 |
+
]
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"name": "stdout",
|
| 235 |
+
"output_type": "stream",
|
| 236 |
+
"text": [
|
| 237 |
+
"Output written to test_outputs/gpt2_blocks.7.hook_resid_pre/batch-0.json\n"
|
| 238 |
+
]
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"name": "stderr",
|
| 242 |
+
"output_type": "stream",
|
| 243 |
+
"text": [
|
| 244 |
+
"\n"
|
| 245 |
+
]
|
| 246 |
+
}
|
| 247 |
+
],
|
| 248 |
+
"source": [
|
| 249 |
+
"runner = NeuronpediaVectorRunner(vector_set, cfg)\n",
|
| 250 |
+
"runner.run()"
|
| 251 |
+
]
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"cell_type": "code",
|
| 255 |
+
"execution_count": null,
|
| 256 |
+
"metadata": {},
|
| 257 |
+
"outputs": [],
|
| 258 |
+
"source": []
|
| 259 |
+
}
|
| 260 |
+
],
|
| 261 |
+
"metadata": {
|
| 262 |
+
"kernelspec": {
|
| 263 |
+
"display_name": "sae-d",
|
| 264 |
+
"language": "python",
|
| 265 |
+
"name": "python3"
|
| 266 |
+
},
|
| 267 |
+
"language_info": {
|
| 268 |
+
"codemirror_mode": {
|
| 269 |
+
"name": "ipython",
|
| 270 |
+
"version": 3
|
| 271 |
+
},
|
| 272 |
+
"file_extension": ".py",
|
| 273 |
+
"mimetype": "text/x-python",
|
| 274 |
+
"name": "python",
|
| 275 |
+
"nbconvert_exporter": "python",
|
| 276 |
+
"pygments_lexer": "ipython3",
|
| 277 |
+
"version": "3.12.4"
|
| 278 |
+
}
|
| 279 |
+
},
|
| 280 |
+
"nbformat": 4,
|
| 281 |
+
"nbformat_minor": 2
|
| 282 |
+
}
|
SAEDashboard/notebooks/experiment_gemma_2_9b_dashboard_generation_np.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# I'm running this in an A100 with 90GB of GPU Ram.
|
| 2 |
+
# I'm using TransformerLens 2.2 which I manually installed from source.
|
| 3 |
+
# I'm a few edits to fix bfloat16 errors (but I've since made PR's so latest SAE Lens / SAE dashboard should be fine here).
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from sae_dashboard.neuronpedia.neuronpedia_runner import (
|
| 7 |
+
NeuronpediaRunner,
|
| 8 |
+
NeuronpediaRunnerConfig,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
# GET WEIGHTS FROM WANDB
|
| 12 |
+
# import wandb
|
| 13 |
+
# run = wandb.init()
|
| 14 |
+
# artifact = run.use_artifact('jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7', type='model')
|
| 15 |
+
# artifact_dir = artifact.download()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Get Sparsity from Wandb (and manually move it accross)
|
| 19 |
+
# import wandb
|
| 20 |
+
# run = wandb.init()
|
| 21 |
+
# artifact = run.use_artifact('jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688_log_feature_sparsity:v7', type='log_feature_sparsity')
|
| 22 |
+
# artifact_dir = artifact.download()
|
| 23 |
+
|
| 24 |
+
NP_OUTPUT_FOLDER = "neuronpedia_outputs/gemma-2-9b-test"
|
| 25 |
+
SAE_SET = "res-jb-test"
|
| 26 |
+
SAE_PATH = "artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7"
|
| 27 |
+
print(SAE_PATH)
|
| 28 |
+
|
| 29 |
+
# delete output files if present
|
| 30 |
+
os.system(f"rm -rf {NP_OUTPUT_FOLDER}")
|
| 31 |
+
cfg = NeuronpediaRunnerConfig(
|
| 32 |
+
sae_set=SAE_SET,
|
| 33 |
+
sae_path=SAE_PATH,
|
| 34 |
+
outputs_dir=NP_OUTPUT_FOLDER,
|
| 35 |
+
sparsity_threshold=-6,
|
| 36 |
+
n_prompts_total=4096,
|
| 37 |
+
huggingface_dataset_path="monology/pile-uncopyrighted",
|
| 38 |
+
n_features_at_a_time=1024,
|
| 39 |
+
n_tokens_in_prompt=128,
|
| 40 |
+
start_batch=0,
|
| 41 |
+
end_batch=8,
|
| 42 |
+
use_wandb=True,
|
| 43 |
+
sae_device="cuda",
|
| 44 |
+
model_device="cuda",
|
| 45 |
+
model_n_devices=1,
|
| 46 |
+
activation_store_device="cuda",
|
| 47 |
+
model_dtype="bfloat16",
|
| 48 |
+
sae_dtype="float32",
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
runner = NeuronpediaRunner(cfg)
|
| 52 |
+
runner.run()
|
SAEDashboard/notebooks/sae_dashboard_demo_gemma_2_9b.ipynb
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Demo Notebook"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "markdown",
|
| 12 |
+
"metadata": {},
|
| 13 |
+
"source": [
|
| 14 |
+
"Steps:\n",
|
| 15 |
+
"1. Download SAE with SAE Lens.\n",
|
| 16 |
+
"2. Create a dataset consistent with that SAE. \n",
|
| 17 |
+
"3. Fold the SAE decoder norm weights so that feature activations are \"correct\".\n",
|
| 18 |
+
"4. Estimate the activation normalization constant if needed, and fold it into the SAE weights.\n",
|
| 19 |
+
"5. Run the SAE generator for the features you want."
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "markdown",
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"source": [
|
| 26 |
+
"# Set Up"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"execution_count": null,
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"outputs": [],
|
| 34 |
+
"source": [
|
| 35 |
+
"# Download Gemma-2-9b weights\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"import wandb\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"run = wandb.init()\n",
|
| 40 |
+
"artifact = run.use_artifact(\n",
|
| 41 |
+
" \"jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7\",\n",
|
| 42 |
+
" type=\"model\",\n",
|
| 43 |
+
")\n",
|
| 44 |
+
"artifact_dir = artifact.download()"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": null,
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"outputs": [],
|
| 52 |
+
"source": [
|
| 53 |
+
"import wandb\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"run = wandb.init()\n",
|
| 56 |
+
"artifact = run.use_artifact(\n",
|
| 57 |
+
" \"jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688_log_feature_sparsity:v7\",\n",
|
| 58 |
+
" type=\"log_feature_sparsity\",\n",
|
| 59 |
+
")\n",
|
| 60 |
+
"artifact_dir = artifact.download()"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"cell_type": "code",
|
| 65 |
+
"execution_count": null,
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"outputs": [],
|
| 68 |
+
"source": [
|
| 69 |
+
"import torch\n",
|
| 70 |
+
"import matplotlib.pyplot as plt\n",
|
| 71 |
+
"from safetensors.torch import load_file\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"# Assume we have a PyTorch tensor\n",
|
| 74 |
+
"feature_sparsity = load_file(\n",
|
| 75 |
+
" \"artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7/sparsity.safetensors\"\n",
|
| 76 |
+
")[\"sparsity\"]\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"# Convert the tensor to a numpy array\n",
|
| 79 |
+
"data = feature_sparsity.numpy()\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"# Create the histogram\n",
|
| 82 |
+
"plt.hist(data, bins=30, edgecolor=\"black\")\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"# Add labels and title\n",
|
| 85 |
+
"plt.xlabel(\"Value\")\n",
|
| 86 |
+
"plt.ylabel(\"Frequency\")\n",
|
| 87 |
+
"plt.title(\"Histogram of PyTorch Tensor\")\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"# Show the plot\n",
|
| 90 |
+
"plt.show()"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"outputs": [],
|
| 98 |
+
"source": [
|
| 99 |
+
"import torch\n",
|
| 100 |
+
"from transformer_lens import HookedTransformer\n",
|
| 101 |
+
"from sae_lens import ActivationsStore, SAE\n",
|
| 102 |
+
"from importlib import reload\n",
|
| 103 |
+
"import sae_dashboard\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"torch.set_grad_enabled(False)\n",
|
| 106 |
+
"\n",
|
| 107 |
+
"reload(sae_dashboard)"
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"cell_type": "code",
|
| 112 |
+
"execution_count": null,
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"outputs": [],
|
| 115 |
+
"source": [
|
| 116 |
+
"MODEL = \"gemma-2-9b\"\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"if torch.backends.mps.is_available():\n",
|
| 119 |
+
" device = \"mps\"\n",
|
| 120 |
+
"else:\n",
|
| 121 |
+
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 122 |
+
"\n",
|
| 123 |
+
"print(f\"Device: {device}\")\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"model = HookedTransformer.from_pretrained(MODEL, device=device, dtype=\"bfloat16\")"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": null,
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"outputs": [],
|
| 133 |
+
"source": [
|
| 134 |
+
"sae = SAE.load_from_pretrained(\n",
|
| 135 |
+
" \"artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7\"\n",
|
| 136 |
+
")\n",
|
| 137 |
+
"sae.fold_W_dec_norm()"
|
| 138 |
+
]
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"cell_type": "code",
|
| 142 |
+
"execution_count": null,
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"outputs": [],
|
| 145 |
+
"source": [
|
| 146 |
+
"# _, cache = model.run_with_cache(\"Wasssssup\", names_filter = sae.cfg.hook_name)\n",
|
| 147 |
+
"# sae_in = cache[sae.cfg.hook_name]\n",
|
| 148 |
+
"# print(sae_in.shape)\n",
|
| 149 |
+
"sae_in = torch.rand((1, 4, 3584)).to(sae.device)\n",
|
| 150 |
+
"sae_out = sae(sae_in)"
|
| 151 |
+
]
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"cell_type": "code",
|
| 155 |
+
"execution_count": null,
|
| 156 |
+
"metadata": {},
|
| 157 |
+
"outputs": [],
|
| 158 |
+
"source": [
|
| 159 |
+
"# # the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n",
|
| 160 |
+
"# # Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n",
|
| 161 |
+
"# # We also return the feature sparsities which are stored in HF for convenience.\n",
|
| 162 |
+
"# sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
|
| 163 |
+
"# release = \"mistral-7b-res-wg\", # see other options in sae_lens/pretrained_saes.yaml\n",
|
| 164 |
+
"# sae_id = \"blocks.8.hook_resid_pre\", # won't always be a hook point\n",
|
| 165 |
+
"# device = \"cuda:3\",\n",
|
| 166 |
+
"# )\n",
|
| 167 |
+
"# # fold w_dec norm so feature activations are accurate\n",
|
| 168 |
+
"#\n",
|
| 169 |
+
"activations_store = ActivationsStore.from_sae(\n",
|
| 170 |
+
" model=model,\n",
|
| 171 |
+
" sae=sae,\n",
|
| 172 |
+
" streaming=True,\n",
|
| 173 |
+
" store_batch_size_prompts=8,\n",
|
| 174 |
+
" n_batches_in_buffer=8,\n",
|
| 175 |
+
" device=\"cpu\",\n",
|
| 176 |
+
")"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"cell_type": "code",
|
| 181 |
+
"execution_count": null,
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"outputs": [],
|
| 184 |
+
"source": [
|
| 185 |
+
"sae.encode_fn"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"cell_type": "code",
|
| 190 |
+
"execution_count": null,
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"outputs": [],
|
| 193 |
+
"source": [
|
| 194 |
+
"from sae_lens import run_evals\n",
|
| 195 |
+
"\n",
|
| 196 |
+
"eval_metrics = run_evals(\n",
|
| 197 |
+
" sae=sae,\n",
|
| 198 |
+
" activation_store=activations_store,\n",
|
| 199 |
+
" model=model,\n",
|
| 200 |
+
" n_eval_batches=3,\n",
|
| 201 |
+
" eval_batch_size_prompts=8,\n",
|
| 202 |
+
")\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"# CE Loss score should be high for residual stream SAEs\n",
|
| 205 |
+
"print(eval_metrics[\"metrics/CE_loss_score\"])\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"# ce loss without SAE should be fairly low < 3.5 suggesting the Model is being run correctly\n",
|
| 208 |
+
"print(eval_metrics[\"metrics/ce_loss_without_sae\"])\n",
|
| 209 |
+
"\n",
|
| 210 |
+
"# ce loss with SAE shouldn't be massively higher\n",
|
| 211 |
+
"print(eval_metrics[\"metrics/ce_loss_with_sae\"])"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"execution_count": null,
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"outputs": [],
|
| 219 |
+
"source": [
|
| 220 |
+
"from tqdm import tqdm\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"from sae_dashboard.utils_fns import get_tokens\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"# 1000 prompts is plenty for a demo.\n",
|
| 226 |
+
"token_dataset = get_tokens(activations_store, 4096)"
|
| 227 |
+
]
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"cell_type": "code",
|
| 231 |
+
"execution_count": null,
|
| 232 |
+
"metadata": {},
|
| 233 |
+
"outputs": [],
|
| 234 |
+
"source": [
|
| 235 |
+
"# torch.save(token_dataset, \"to\")"
|
| 236 |
+
]
|
| 237 |
+
},
|
| 238 |
+
{
|
| 239 |
+
"cell_type": "code",
|
| 240 |
+
"execution_count": null,
|
| 241 |
+
"metadata": {},
|
| 242 |
+
"outputs": [],
|
| 243 |
+
"source": [
|
| 244 |
+
"# torch.save(token_dataset, \"token_dataset.pt\")\n",
|
| 245 |
+
"token_dataset = torch.load(\"token_dataset.pt\")"
|
| 246 |
+
]
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"cell_type": "code",
|
| 250 |
+
"execution_count": null,
|
| 251 |
+
"metadata": {},
|
| 252 |
+
"outputs": [],
|
| 253 |
+
"source": [
|
| 254 |
+
"import os\n",
|
| 255 |
+
"\n",
|
| 256 |
+
"os.rmdir(\"demo_activations_cache\")"
|
| 257 |
+
]
|
| 258 |
+
},
|
| 259 |
+
{
|
| 260 |
+
"cell_type": "code",
|
| 261 |
+
"execution_count": null,
|
| 262 |
+
"metadata": {},
|
| 263 |
+
"outputs": [],
|
| 264 |
+
"source": [
|
| 265 |
+
"import torch\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"\n",
|
| 268 |
+
"def select_indices_in_range(tensor, min_val, max_val, num_samples=None):\n",
|
| 269 |
+
" \"\"\"\n",
|
| 270 |
+
" Select indices of a tensor where values fall within a specified range.\n",
|
| 271 |
+
"\n",
|
| 272 |
+
" Args:\n",
|
| 273 |
+
" tensor (torch.Tensor): Input tensor with values between -10 and 0.\n",
|
| 274 |
+
" min_val (float): Minimum value of the range (inclusive).\n",
|
| 275 |
+
" max_val (float): Maximum value of the range (inclusive).\n",
|
| 276 |
+
" num_samples (int, optional): Number of indices to randomly select. If None, return all indices.\n",
|
| 277 |
+
"\n",
|
| 278 |
+
" Returns:\n",
|
| 279 |
+
" torch.Tensor: Tensor of selected indices.\n",
|
| 280 |
+
" \"\"\"\n",
|
| 281 |
+
" # Ensure the input range is valid\n",
|
| 282 |
+
" if not (-10 <= min_val <= max_val <= 0):\n",
|
| 283 |
+
" raise ValueError(\n",
|
| 284 |
+
" \"Range must be within -10 to 0, and min_val must be <= max_val\"\n",
|
| 285 |
+
" )\n",
|
| 286 |
+
"\n",
|
| 287 |
+
" # Find indices where values are within the specified range\n",
|
| 288 |
+
" mask = (tensor >= min_val) & (tensor <= max_val)\n",
|
| 289 |
+
" indices = mask.nonzero().squeeze()\n",
|
| 290 |
+
"\n",
|
| 291 |
+
" # If num_samples is specified and less than the total number of valid indices,\n",
|
| 292 |
+
" # randomly select that many indices\n",
|
| 293 |
+
" if num_samples is not None and num_samples < indices.numel():\n",
|
| 294 |
+
" perm = torch.randperm(indices.numel())\n",
|
| 295 |
+
" indices = indices[perm[:num_samples]]\n",
|
| 296 |
+
"\n",
|
| 297 |
+
" return indices\n",
|
| 298 |
+
"\n",
|
| 299 |
+
"\n",
|
| 300 |
+
"n_features = 4096\n",
|
| 301 |
+
"feature_idxs = select_indices_in_range(feature_sparsity, -4, -2, 4096)\n",
|
| 302 |
+
"feature_sparsity[feature_idxs.tolist()]"
|
| 303 |
+
]
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
"cell_type": "code",
|
| 307 |
+
"execution_count": null,
|
| 308 |
+
"metadata": {},
|
| 309 |
+
"outputs": [],
|
| 310 |
+
"source": []
|
| 311 |
+
},
|
| 312 |
+
{
|
| 313 |
+
"cell_type": "code",
|
| 314 |
+
"execution_count": null,
|
| 315 |
+
"metadata": {},
|
| 316 |
+
"outputs": [],
|
| 317 |
+
"source": [
|
| 318 |
+
"from importlib import reload\n",
|
| 319 |
+
"import sys\n",
|
| 320 |
+
"\n",
|
| 321 |
+
"\n",
|
| 322 |
+
"def reload_user_modules(module_names):\n",
|
| 323 |
+
" \"\"\"Reload specified user modules.\"\"\"\n",
|
| 324 |
+
" for name in module_names:\n",
|
| 325 |
+
" if name in sys.modules:\n",
|
| 326 |
+
" reload(sys.modules[name])\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"\n",
|
| 329 |
+
"# List of your module names\n",
|
| 330 |
+
"user_modules = [\n",
|
| 331 |
+
" \"sae_dashboard\",\n",
|
| 332 |
+
" \"sae_dashboard.sae_vis_runner\",\n",
|
| 333 |
+
" \"sae_dashboard.data_parsing_fns\",\n",
|
| 334 |
+
" \"sae_dashboard.feature_data_generator\",\n",
|
| 335 |
+
"]\n",
|
| 336 |
+
"\n",
|
| 337 |
+
"# Reload modules\n",
|
| 338 |
+
"reload_user_modules(user_modules)\n",
|
| 339 |
+
"\n",
|
| 340 |
+
"# Re-import after reload\n",
|
| 341 |
+
"from sae_dashboard.feature_data_generator import FeatureDataGenerator"
|
| 342 |
+
]
|
| 343 |
+
},
|
| 344 |
+
{
|
| 345 |
+
"cell_type": "code",
|
| 346 |
+
"execution_count": null,
|
| 347 |
+
"metadata": {},
|
| 348 |
+
"outputs": [],
|
| 349 |
+
"source": [
|
| 350 |
+
"from pathlib import Path\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"test_feature_idx_gpt = feature_idxs.tolist()\n",
|
| 353 |
+
"\n",
|
| 354 |
+
"feature_vis_config_gpt = sae_vis_runner.SaeVisConfig(\n",
|
| 355 |
+
" hook_point=sae.cfg.hook_name,\n",
|
| 356 |
+
" features=test_feature_idx_gpt,\n",
|
| 357 |
+
" minibatch_size_features=16,\n",
|
| 358 |
+
" minibatch_size_tokens=4096, # this is really prompt with the number of tokens determined by the sequence length\n",
|
| 359 |
+
" verbose=True,\n",
|
| 360 |
+
" device=\"cuda\",\n",
|
| 361 |
+
" cache_dir=Path(\n",
|
| 362 |
+
" \"demo_activations_cache\"\n",
|
| 363 |
+
" ), # this will enable us to skip running the model for subsequent features.\n",
|
| 364 |
+
" dtype=\"bfloat16\",\n",
|
| 365 |
+
")\n",
|
| 366 |
+
"\n",
|
| 367 |
+
"runner = sae_vis_runner.SaeVisRunner(feature_vis_config_gpt)\n",
|
| 368 |
+
"\n",
|
| 369 |
+
"data = runner.run(\n",
|
| 370 |
+
" encoder=sae, # type: ignore\n",
|
| 371 |
+
" model=model,\n",
|
| 372 |
+
" tokens=token_dataset[:1024],\n",
|
| 373 |
+
")"
|
| 374 |
+
]
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"cell_type": "code",
|
| 378 |
+
"execution_count": null,
|
| 379 |
+
"metadata": {},
|
| 380 |
+
"outputs": [],
|
| 381 |
+
"source": []
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"cell_type": "code",
|
| 385 |
+
"execution_count": null,
|
| 386 |
+
"metadata": {},
|
| 387 |
+
"outputs": [],
|
| 388 |
+
"source": [
|
| 389 |
+
"from sae_dashboard.data_writing_fns import save_feature_centric_vis\n",
|
| 390 |
+
"\n",
|
| 391 |
+
"filename = f\"demo_feature_dashboards.html\"\n",
|
| 392 |
+
"save_feature_centric_vis(sae_vis_data=data, filename=filename)"
|
| 393 |
+
]
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"cell_type": "code",
|
| 397 |
+
"execution_count": null,
|
| 398 |
+
"metadata": {},
|
| 399 |
+
"outputs": [],
|
| 400 |
+
"source": []
|
| 401 |
+
},
|
| 402 |
+
{
|
| 403 |
+
"cell_type": "markdown",
|
| 404 |
+
"metadata": {},
|
| 405 |
+
"source": [
|
| 406 |
+
"# Quick Profiling experiment"
|
| 407 |
+
]
|
| 408 |
+
},
|
| 409 |
+
{
|
| 410 |
+
"cell_type": "code",
|
| 411 |
+
"execution_count": null,
|
| 412 |
+
"metadata": {},
|
| 413 |
+
"outputs": [],
|
| 414 |
+
"source": [
|
| 415 |
+
"def mock_feature_acts_subset_for_now(sae: SAE):\n",
|
| 416 |
+
"\n",
|
| 417 |
+
" @torch.no_grad()\n",
|
| 418 |
+
" def sae_lens_get_feature_acts_subset(x: torch.Tensor, feature_idx): # type: ignore\n",
|
| 419 |
+
" \"\"\"\n",
|
| 420 |
+
" Get a subset of the feature activations for a dataset.\n",
|
| 421 |
+
" \"\"\"\n",
|
| 422 |
+
" original_device = x.device\n",
|
| 423 |
+
" feature_activations = sae.encode_fn(x.to(device=sae.device, dtype=sae.dtype))\n",
|
| 424 |
+
" return feature_activations[..., feature_idx].to(original_device)\n",
|
| 425 |
+
"\n",
|
| 426 |
+
" sae.get_feature_acts_subset = sae_lens_get_feature_acts_subset # type: ignore\n",
|
| 427 |
+
"\n",
|
| 428 |
+
" return sae\n",
|
| 429 |
+
"\n",
|
| 430 |
+
"\n",
|
| 431 |
+
"sae = mock_feature_acts_subset_for_now(sae)\n",
|
| 432 |
+
"feature_idxs = list(range(128))\n",
|
| 433 |
+
"sae_in = torch.rand((1, 4, 3584)).to(sae.device)\n",
|
| 434 |
+
"sae.get_feature_acts_subset(sae_in, feature_idxs)"
|
| 435 |
+
]
|
| 436 |
+
},
|
| 437 |
+
{
|
| 438 |
+
"cell_type": "code",
|
| 439 |
+
"execution_count": null,
|
| 440 |
+
"metadata": {},
|
| 441 |
+
"outputs": [],
|
| 442 |
+
"source": [
|
| 443 |
+
"for k, v in sae.named_parameters():\n",
|
| 444 |
+
" print(k, v.shape)"
|
| 445 |
+
]
|
| 446 |
+
},
|
| 447 |
+
{
|
| 448 |
+
"cell_type": "code",
|
| 449 |
+
"execution_count": null,
|
| 450 |
+
"metadata": {},
|
| 451 |
+
"outputs": [],
|
| 452 |
+
"source": [
|
| 453 |
+
"from torch import nn\n",
|
| 454 |
+
"from typing import List\n",
|
| 455 |
+
"\n",
|
| 456 |
+
"\n",
|
| 457 |
+
"class FeatureMaskingContext:\n",
|
| 458 |
+
" def __init__(self, sae: SAE, feature_idxs: List):\n",
|
| 459 |
+
" self.sae = sae\n",
|
| 460 |
+
" self.feature_idxs = feature_idxs\n",
|
| 461 |
+
" self.original_weight = {}\n",
|
| 462 |
+
"\n",
|
| 463 |
+
" def __enter__(self):\n",
|
| 464 |
+
"\n",
|
| 465 |
+
" ## W_dec\n",
|
| 466 |
+
" self.original_weight[\"W_dec\"] = getattr(self.sae, \"W_dec\").data.clone()\n",
|
| 467 |
+
" # mask the weight\n",
|
| 468 |
+
" masked_weight = sae.W_dec[self.feature_idxs]\n",
|
| 469 |
+
" # set the weight\n",
|
| 470 |
+
" setattr(self.sae, \"W_dec\", nn.Parameter(masked_weight))\n",
|
| 471 |
+
"\n",
|
| 472 |
+
" ## W_enc\n",
|
| 473 |
+
" # clone the weight.\n",
|
| 474 |
+
" self.original_weight[\"W_enc\"] = getattr(self.sae, \"W_enc\").data.clone()\n",
|
| 475 |
+
" # mask the weight\n",
|
| 476 |
+
" masked_weight = sae.W_enc[:, self.feature_idxs]\n",
|
| 477 |
+
" # set the weight\n",
|
| 478 |
+
" setattr(self.sae, \"W_enc\", nn.Parameter(masked_weight))\n",
|
| 479 |
+
"\n",
|
| 480 |
+
" if self.sae.cfg.architecture == \"standard\":\n",
|
| 481 |
+
"\n",
|
| 482 |
+
" ## b_enc\n",
|
| 483 |
+
" self.original_weight[\"b_enc\"] = getattr(self.sae, \"b_enc\").data.clone()\n",
|
| 484 |
+
" # mask the weight\n",
|
| 485 |
+
" masked_weight = sae.b_enc[self.feature_idxs]\n",
|
| 486 |
+
" # set the weight\n",
|
| 487 |
+
" setattr(self.sae, \"b_enc\", nn.Parameter(masked_weight))\n",
|
| 488 |
+
"\n",
|
| 489 |
+
" elif self.sae.cfg.architecture == \"gated\":\n",
|
| 490 |
+
"\n",
|
| 491 |
+
" ## b_gate\n",
|
| 492 |
+
" self.original_weight[\"b_gate\"] = getattr(self.sae, \"b_gate\").data.clone()\n",
|
| 493 |
+
" # mask the weight\n",
|
| 494 |
+
" masked_weight = sae.b_gate[self.feature_idxs]\n",
|
| 495 |
+
" # set the weight\n",
|
| 496 |
+
" setattr(self.sae, \"b_gate\", nn.Parameter(masked_weight))\n",
|
| 497 |
+
"\n",
|
| 498 |
+
" ## r_mag\n",
|
| 499 |
+
" self.original_weight[\"r_mag\"] = getattr(self.sae, \"r_mag\").data.clone()\n",
|
| 500 |
+
" # mask the weight\n",
|
| 501 |
+
" masked_weight = sae.r_mag[self.feature_idxs]\n",
|
| 502 |
+
" # set the weight\n",
|
| 503 |
+
" setattr(self.sae, \"r_mag\", nn.Parameter(masked_weight))\n",
|
| 504 |
+
"\n",
|
| 505 |
+
" ## b_mag\n",
|
| 506 |
+
" self.original_weight[\"b_mag\"] = getattr(self.sae, \"b_mag\").data.clone()\n",
|
| 507 |
+
" # mask the weight\n",
|
| 508 |
+
" masked_weight = sae.b_mag[self.feature_idxs]\n",
|
| 509 |
+
" # set the weight\n",
|
| 510 |
+
" setattr(self.sae, \"b_mag\", nn.Parameter(masked_weight))\n",
|
| 511 |
+
" else:\n",
|
| 512 |
+
" raise (ValueError(\"Invalid architecture\"))\n",
|
| 513 |
+
"\n",
|
| 514 |
+
" return self\n",
|
| 515 |
+
"\n",
|
| 516 |
+
" def __exit__(self, exc_type, exc_value, traceback):\n",
|
| 517 |
+
"\n",
|
| 518 |
+
" # set everything back to normal\n",
|
| 519 |
+
" for key, value in self.original_weight.items():\n",
|
| 520 |
+
" setattr(self.sae, key, nn.Parameter(value))"
|
| 521 |
+
]
|
| 522 |
+
},
|
| 523 |
+
{
|
| 524 |
+
"cell_type": "code",
|
| 525 |
+
"execution_count": null,
|
| 526 |
+
"metadata": {},
|
| 527 |
+
"outputs": [],
|
| 528 |
+
"source": [
|
| 529 |
+
"import gc\n",
|
| 530 |
+
"import torch\n",
|
| 531 |
+
"\n",
|
| 532 |
+
"gc.collect()\n",
|
| 533 |
+
"torch.cuda.empty_cache()\n",
|
| 534 |
+
"torch.set_grad_enabled(False)\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"\n",
|
| 537 |
+
"def my_function(sae_in):\n",
|
| 538 |
+
" # Your PyTorch code here\n",
|
| 539 |
+
" feature_idxs = list(range(2048))\n",
|
| 540 |
+
" with FeatureMaskingContext(sae, feature_idxs):\n",
|
| 541 |
+
" features = sae(sae_in)\n",
|
| 542 |
+
" print(features.mean())\n",
|
| 543 |
+
"\n",
|
| 544 |
+
"\n",
|
| 545 |
+
"tokens = token_dataset[:64]\n",
|
| 546 |
+
"_, cache = model.run_with_cache(\n",
|
| 547 |
+
" tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=sae.cfg.hook_name\n",
|
| 548 |
+
")\n",
|
| 549 |
+
"sae_in = cache[sae.cfg.hook_name]"
|
| 550 |
+
]
|
| 551 |
+
},
|
| 552 |
+
{
|
| 553 |
+
"cell_type": "code",
|
| 554 |
+
"execution_count": null,
|
| 555 |
+
"metadata": {},
|
| 556 |
+
"outputs": [],
|
| 557 |
+
"source": [
|
| 558 |
+
"tokens.shape"
|
| 559 |
+
]
|
| 560 |
+
},
|
| 561 |
+
{
|
| 562 |
+
"cell_type": "code",
|
| 563 |
+
"execution_count": null,
|
| 564 |
+
"metadata": {},
|
| 565 |
+
"outputs": [],
|
| 566 |
+
"source": [
|
| 567 |
+
"sae.W_dec.shape"
|
| 568 |
+
]
|
| 569 |
+
},
|
| 570 |
+
{
|
| 571 |
+
"cell_type": "code",
|
| 572 |
+
"execution_count": null,
|
| 573 |
+
"metadata": {},
|
| 574 |
+
"outputs": [],
|
| 575 |
+
"source": [
|
| 576 |
+
"%load_ext memray"
|
| 577 |
+
]
|
| 578 |
+
},
|
| 579 |
+
{
|
| 580 |
+
"cell_type": "code",
|
| 581 |
+
"execution_count": null,
|
| 582 |
+
"metadata": {},
|
| 583 |
+
"outputs": [],
|
| 584 |
+
"source": [
|
| 585 |
+
"%%memray_flamegraph --trace-python-allocators --leaks\n",
|
| 586 |
+
"my_function(sae_in)"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
{
|
| 590 |
+
"cell_type": "code",
|
| 591 |
+
"execution_count": null,
|
| 592 |
+
"metadata": {},
|
| 593 |
+
"outputs": [],
|
| 594 |
+
"source": []
|
| 595 |
+
}
|
| 596 |
+
],
|
| 597 |
+
"metadata": {
|
| 598 |
+
"kernelspec": {
|
| 599 |
+
"display_name": ".venv",
|
| 600 |
+
"language": "python",
|
| 601 |
+
"name": "python3"
|
| 602 |
+
},
|
| 603 |
+
"language_info": {
|
| 604 |
+
"codemirror_mode": {
|
| 605 |
+
"name": "ipython",
|
| 606 |
+
"version": 3
|
| 607 |
+
},
|
| 608 |
+
"file_extension": ".py",
|
| 609 |
+
"mimetype": "text/x-python",
|
| 610 |
+
"name": "python",
|
| 611 |
+
"nbconvert_exporter": "python",
|
| 612 |
+
"pygments_lexer": "ipython3",
|
| 613 |
+
"version": "3.11.7"
|
| 614 |
+
}
|
| 615 |
+
},
|
| 616 |
+
"nbformat": 4,
|
| 617 |
+
"nbformat_minor": 2
|
| 618 |
+
}
|
SAEDashboard/pyproject.toml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.poetry]
|
| 2 |
+
name = "sae-dashboard"
|
| 3 |
+
version = "0.7.3"
|
| 4 |
+
description = "Open-source SAE visualizer, based on Anthropic's published visualizer. Forked / Detached from sae_vis."
|
| 5 |
+
authors = ["Callum McDougall <cal.s.mcdougall@gmail.com>", "Joseph Bloom, <jbloomaus@gmail.com>"]
|
| 6 |
+
readme = "README.md"
|
| 7 |
+
license = "MIT"
|
| 8 |
+
|
| 9 |
+
[tool.poetry.dependencies]
|
| 10 |
+
python = "^3.10"
|
| 11 |
+
torch = "^2.0.0"
|
| 12 |
+
einops = ">=0.7.0"
|
| 13 |
+
datasets = "^3.0.0"
|
| 14 |
+
dataclasses-json = "^0.6.4"
|
| 15 |
+
jaxtyping = "^0.2.28"
|
| 16 |
+
transformer-lens = "^2.2.0"
|
| 17 |
+
transformers = "<4.57.0"
|
| 18 |
+
eindex-callum = "^0.1.0"
|
| 19 |
+
rich = "^13.7.1"
|
| 20 |
+
matplotlib = "^3.8.4"
|
| 21 |
+
safetensors = "^0.4.3"
|
| 22 |
+
typer = "^0.12.3"
|
| 23 |
+
sae-lens = "^6.8.0"
|
| 24 |
+
decode-clt = "^0.0.1"
|
| 25 |
+
hf-transfer = "^0.1.9"
|
| 26 |
+
|
| 27 |
+
[tool.poetry.group.dev.dependencies]
|
| 28 |
+
isort = "^5.13.2"
|
| 29 |
+
ruff = "^0.3.7"
|
| 30 |
+
pytest = "^8.1.1"
|
| 31 |
+
ipykernel = "^6.29.4"
|
| 32 |
+
pyright = "^1.1.359"
|
| 33 |
+
pytest-profiling = "^1.7.0"
|
| 34 |
+
memray = "^1.12.0"
|
| 35 |
+
syrupy = "^4.6.1"
|
| 36 |
+
flake8 = "^7.0.0"
|
| 37 |
+
pytest-cov = "^5.0.0"
|
| 38 |
+
black = "^24.4.2"
|
| 39 |
+
pytest-memray = "^1.7.0"
|
| 40 |
+
|
| 41 |
+
[tool.poetry.scripts]
|
| 42 |
+
neuronpedia-runner = "sae_dashboard.neuronpedia.neuronpedia_runner:main"
|
| 43 |
+
|
| 44 |
+
[tool.isort]
|
| 45 |
+
profile = "black"
|
| 46 |
+
src_paths = ["sae_dashboard", "tests"]
|
| 47 |
+
|
| 48 |
+
[tool.pyright]
|
| 49 |
+
typeCheckingMode = "strict"
|
| 50 |
+
reportMissingTypeStubs = "none"
|
| 51 |
+
reportUnknownMemberType = "none"
|
| 52 |
+
reportUnknownArgumentType = "none"
|
| 53 |
+
reportUnknownVariableType = "none"
|
| 54 |
+
reportUntypedFunctionDecorator = "none"
|
| 55 |
+
reportUnnecessaryIsInstance = "none"
|
| 56 |
+
reportUnnecessaryComparison = "none"
|
| 57 |
+
reportConstantRedefinition = "none"
|
| 58 |
+
reportUnknownLambdaType = "none"
|
| 59 |
+
reportPrivateUsage = "none"
|
| 60 |
+
reportPrivateImportUsage = "none"
|
| 61 |
+
|
| 62 |
+
[build-system]
|
| 63 |
+
requires = ["poetry-core"]
|
| 64 |
+
build-backend = "poetry.core.masonry.api"
|
| 65 |
+
|
| 66 |
+
[tool.semantic_release]
|
| 67 |
+
version_variables = ["sae_dashboard/__init__.py:__version__"]
|
| 68 |
+
version_toml = ["pyproject.toml:tool.poetry.version"]
|
| 69 |
+
build_command = "pip install poetry && poetry build"
|
| 70 |
+
branches = { main = { match = "main" } }
|
SAEDashboard/sae_dashboard/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.7.3"
|
| 2 |
+
|
| 3 |
+
# from .data_fetching_fns import *
|
| 4 |
+
# from .data_storing_fns import *
|
| 5 |
+
# from .html_fns import *
|
| 6 |
+
# from .transformer_lens_wrapper import *
|
| 7 |
+
# from .utils_fns import *
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# from autoencoder import AutoEncoder, AutoEncoderConfig
|
SAEDashboard/sae_dashboard/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (184 Bytes). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/components.cpython-313.pyc
ADDED
|
Binary file (33.3 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/components_config.cpython-313.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/data_parsing_fns.cpython-313.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/data_writing_fns.cpython-313.pyc
ADDED
|
Binary file (8 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/dfa_calculator.cpython-313.pyc
ADDED
|
Binary file (6.87 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/feature_data.cpython-313.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/feature_data_generator.cpython-313.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/html_fns.cpython-313.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/layout.cpython-313.pyc
ADDED
|
Binary file (8.84 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/sae_vis_data.cpython-313.pyc
ADDED
|
Binary file (9.15 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/sae_vis_runner.cpython-313.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/sequence_data_generator.cpython-313.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/transformer_lens_wrapper.cpython-313.pyc
ADDED
|
Binary file (8.23 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/utils_fns.cpython-313.pyc
ADDED
|
Binary file (49.6 kB). View file
|
|
|
SAEDashboard/sae_dashboard/__pycache__/vector_vis_data.cpython-313.pyc
ADDED
|
Binary file (9.39 kB). View file
|
|
|
SAEDashboard/sae_dashboard/clt_layer_wrapper.py
ADDED
|
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
# Added dataclass, field, asdict
|
| 4 |
+
from dataclasses import asdict, dataclass, field
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
# import torch.nn as nn # Unused
|
| 8 |
+
# from torch.distributed import ProcessGroup # Unused
|
| 9 |
+
# from types import SimpleNamespace # Unused import
|
| 10 |
+
from typing import ( # Added Optional, Union and List
|
| 11 |
+
TYPE_CHECKING,
|
| 12 |
+
Any,
|
| 13 |
+
List,
|
| 14 |
+
Optional,
|
| 15 |
+
Union,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from clt.models.activations import BatchTopK # type: ignore
|
| 21 |
+
|
| 22 |
+
if TYPE_CHECKING:
|
| 23 |
+
import torch.distributed # Import for ProcessGroup type hint
|
| 24 |
+
from clt.models.clt import CrossLayerTranscoder # type: ignore
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Placeholder for dist if torch.distributed is not available or initialized
|
| 28 |
+
class MockDist:
|
| 29 |
+
def is_initialized(self) -> bool:
|
| 30 |
+
return False
|
| 31 |
+
|
| 32 |
+
def get_world_size(
|
| 33 |
+
self, group: "Optional[torch.distributed.ProcessGroup]" = None
|
| 34 |
+
) -> int:
|
| 35 |
+
return 1
|
| 36 |
+
|
| 37 |
+
def all_gather_into_tensor(
|
| 38 |
+
self,
|
| 39 |
+
output_tensor: torch.Tensor,
|
| 40 |
+
input_tensor: torch.Tensor,
|
| 41 |
+
group: "Optional[torch.distributed.ProcessGroup]" = None,
|
| 42 |
+
) -> None:
|
| 43 |
+
# In non-distributed setting, just copy input to output (assuming output is sized correctly)
|
| 44 |
+
if output_tensor.shape[0] == 1 * input_tensor.shape[0]:
|
| 45 |
+
output_tensor.copy_(input_tensor)
|
| 46 |
+
else:
|
| 47 |
+
# This case shouldn't happen if called correctly, but handle defensively
|
| 48 |
+
raise ValueError(
|
| 49 |
+
"Output tensor size doesn't match input tensor size in mock all_gather"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def all_gather(
|
| 53 |
+
self,
|
| 54 |
+
tensor_list: List[torch.Tensor],
|
| 55 |
+
input_tensor: torch.Tensor,
|
| 56 |
+
group: "Optional[torch.distributed.ProcessGroup]" = None,
|
| 57 |
+
) -> None:
|
| 58 |
+
"""Mock all_gather for a list of tensors."""
|
| 59 |
+
if self.get_world_size(group) == 1:
|
| 60 |
+
if len(tensor_list) == 1:
|
| 61 |
+
tensor_list[0].copy_(input_tensor)
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(
|
| 64 |
+
"tensor_list size must be 1 in mock all_gather when world_size is 1"
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
# This mock doesn't support actual gathering for world_size > 1.
|
| 68 |
+
# It's primarily for the dist.all_gather call in _gather_weight,
|
| 69 |
+
# which should ideally not proceed if world_size > 1 and dist is MockDist.
|
| 70 |
+
# However, _gather_weight checks dist.is_initialized() and dist.get_world_size() first.
|
| 71 |
+
raise NotImplementedError(
|
| 72 |
+
"MockDist.all_gather not implemented for world_size > 1"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
import torch.distributed as dist
|
| 78 |
+
|
| 79 |
+
if not dist.is_available():
|
| 80 |
+
dist = MockDist() # type: ignore
|
| 81 |
+
except ImportError:
|
| 82 |
+
dist = MockDist() # type: ignore
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclass
|
| 86 |
+
class CLTMetadata:
|
| 87 |
+
"""Simple metadata class for CLT wrapper compatibility."""
|
| 88 |
+
|
| 89 |
+
hook_name: str
|
| 90 |
+
hook_layer: int
|
| 91 |
+
model_name: Optional[str] = None
|
| 92 |
+
context_size: Optional[int] = None
|
| 93 |
+
prepend_bos: bool = True
|
| 94 |
+
hook_head_index: Optional[int] = None
|
| 95 |
+
seqpos_slice: Optional[slice] = None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass
|
| 99 |
+
class CLTWrapperConfig:
|
| 100 |
+
"""Configuration dataclass for the CLTLayerWrapper."""
|
| 101 |
+
|
| 102 |
+
# Fields without defaults first
|
| 103 |
+
d_sae: int
|
| 104 |
+
d_in: int
|
| 105 |
+
hook_name: str
|
| 106 |
+
hook_layer: int
|
| 107 |
+
dtype: str
|
| 108 |
+
device: str
|
| 109 |
+
# Fields with defaults last
|
| 110 |
+
architecture: str = "jumprelu"
|
| 111 |
+
hook_head_index: Optional[int] = None
|
| 112 |
+
model_name: Optional[str] = None
|
| 113 |
+
dataset_path: Optional[str] = None
|
| 114 |
+
context_size: Optional[int] = None
|
| 115 |
+
prepend_bos: bool = True
|
| 116 |
+
normalize_activations: bool = False
|
| 117 |
+
dataset_trust_remote_code: bool = False
|
| 118 |
+
seqpos_slice: Optional[slice] = None
|
| 119 |
+
model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
|
| 120 |
+
metadata: Optional[CLTMetadata] = None
|
| 121 |
+
|
| 122 |
+
def to_dict(self) -> dict[str, Any]:
|
| 123 |
+
"""Convert config to dictionary for compatibility with SAE interface."""
|
| 124 |
+
return asdict(self)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class CLTLayerWrapper:
|
| 128 |
+
"""Wraps a single layer of a CrossLayerTranscoder to mimic the SAE interface.
|
| 129 |
+
|
| 130 |
+
This allows reusing existing dashboard components that expect an SAE object.
|
| 131 |
+
It specifically provides access to the encoder and the *same-layer* decoder weights
|
| 132 |
+
for the specified layer index.
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
cfg: CLTWrapperConfig # Add type hint for the config attribute
|
| 136 |
+
threshold: Optional[torch.Tensor] = (
|
| 137 |
+
None # For JumpReLU, set by FeatureMaskingContext
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def __init__(
    self,
    clt: "CrossLayerTranscoder",
    layer_idx: int,
    clt_model_dir_path: Optional[str] = None,
):
    """Build an SAE-like façade over one layer of a CrossLayerTranscoder.

    Args:
        clt: The full CrossLayerTranscoder whose weights are wrapped.
        layer_idx: Which layer's encoder (and same-layer decoder) to expose.
        clt_model_dir_path: Optional directory expected to contain
            ``norm_stats.json`` with per-layer input mean/std. When present
            and the CLT config requests normalization, the wrapper applies
            the stats itself in ``encode``.

    Raises:
        ValueError: If ``layer_idx`` is outside ``clt.config.num_layers``.
        KeyError: If the same-layer decoder is missing from the CLT.
    """
    self.clt = clt
    self.layer_idx = layer_idx
    self.device = clt.device
    self.dtype = clt.dtype
    self.hook_z_reshaping_mode = False  # Added to satisfy SAE interface

    # Validate layer index
    if not (0 <= layer_idx < clt.config.num_layers):
        raise ValueError(
            f"Invalid layer_idx {layer_idx} for CLT with {clt.config.num_layers} layers."
        )

    # --- Create the Wrapper Config ---
    # Try to get model_name from the underlying clt config if it exists
    clt_model_name = getattr(clt.config, "model_name", None)
    clt_dataset_path = getattr(clt.config, "dataset_path", None)
    clt_context_size = getattr(
        clt.config, "context_size", 128
    )  # Default to 128 if not set
    clt_prepend_bos = getattr(clt.config, "prepend_bos", True)
    # Use the activation_fn from CLT config for the wrapper's architecture and encode method
    self.activation_fn = getattr(clt.config, "activation_fn", "jumprelu")
    clt_model_from_pretrained_kwargs = getattr(
        clt.config, "model_from_pretrained_kwargs", {}
    )

    # --- Load CLT-specific normalization stats if applicable ---
    # clt_norm_mean/std stay None unless norm_stats.json provides valid
    # per-layer input statistics; encode() checks them before normalizing.
    self.clt_norm_mean: Optional[torch.Tensor] = None
    self.clt_norm_std: Optional[torch.Tensor] = None
    wrapper_will_normalize_specifically = False
    clt_norm_method = getattr(clt.config, "normalization_method", "none")

    if clt_norm_method in ["auto", "estimated_mean_std", "mean_std"]:
        if clt_model_dir_path:
            norm_stats_file = Path(clt_model_dir_path) / "norm_stats.json"
            if norm_stats_file.exists():
                try:
                    with open(norm_stats_file, "r") as f:
                        stats_data = json.load(f)

                    # norm_stats.json layout assumed: {layer_idx: {"inputs":
                    # {"mean": [...], "std": [...]}}} — TODO confirm with writer.
                    layer_stats = stats_data.get(str(self.layer_idx), {}).get(
                        "inputs", {}
                    )
                    mean_vals = layer_stats.get("mean")
                    std_vals = layer_stats.get("std")

                    if mean_vals is not None and std_vals is not None:
                        self.clt_norm_mean = torch.tensor(
                            mean_vals, device=self.device, dtype=torch.float32
                        ).unsqueeze(0)
                        # Epsilon guards against division by zero in encode().
                        self.clt_norm_std = (
                            torch.tensor(
                                std_vals, device=self.device, dtype=torch.float32
                            )
                            + 1e-6
                        ).unsqueeze(0)
                        if torch.any(self.clt_norm_std <= 0):
                            print(
                                f"Warning: Loaded std for layer {self.layer_idx} contains non-positive values after adding epsilon. Disabling specific normalization."
                            )
                            self.clt_norm_mean = None
                            self.clt_norm_std = None
                        else:
                            wrapper_will_normalize_specifically = True
                            print(
                                f"CLTLayerWrapper: Loaded norm_stats.json for layer {self.layer_idx}. Wrapper will apply specific normalization."
                            )
                    else:
                        print(
                            f"Warning: norm_stats.json found, but missing 'mean' or 'std' for layer {self.layer_idx} inputs. Wrapper will not normalize specifically."
                        )
                except Exception as e:
                    # Best-effort: any parse/IO error just disables wrapper-side
                    # normalization rather than aborting construction.
                    print(
                        f"Warning: Error loading or parsing norm_stats.json from {norm_stats_file}: {e}. Wrapper will not normalize specifically."
                    )
            else:
                print(
                    f"Warning: normalization_method is '{clt_norm_method}' but norm_stats.json not found at {norm_stats_file}. Wrapper will not normalize specifically."
                )
        else:
            print(
                f"Warning: normalization_method is '{clt_norm_method}' but clt_model_dir_path not provided. Wrapper cannot load norm_stats.json and will not normalize specifically."
            )

    # Determine normalize_activations flag for ActivationsStore based on CLT config and wrapper's capability
    # This flag in self.cfg controls ActivationsStore. ActivationsStore should only normalize if the wrapper *isn't* doing specific normalization AND the CLT expected some form of normalization.
    clt_config_indicated_normalization = clt_norm_method != "none"
    normalize_activations_for_store = clt_config_indicated_normalization and (
        not wrapper_will_normalize_specifically
    )
    if normalize_activations_for_store:
        print(
            f"CLTLayerWrapper: Setting normalize_activations=True for ActivationsStore (CLT method: {clt_norm_method}, wrapper specific norm: False)."
        )
    elif clt_config_indicated_normalization and wrapper_will_normalize_specifically:
        print(
            f"CLTLayerWrapper: Setting normalize_activations=False for ActivationsStore (CLT method: {clt_norm_method}, wrapper specific norm: True)."
        )
    else:  # not clt_config_indicated_normalization
        print(
            f"CLTLayerWrapper: Setting normalize_activations=False for ActivationsStore (CLT method: {clt_norm_method})."
        )

    # Initialize self.threshold if activation is jumprelu
    # This must happen AFTER self.activation_fn, self.device, self.dtype, self.layer_idx, and self.clt are set.
    if self.activation_fn == "jumprelu":
        if (
            hasattr(self.clt, "log_threshold")
            and self.clt.log_threshold is not None
        ):
            if 0 <= self.layer_idx < self.clt.log_threshold.shape[0]:
                # The log_threshold from CLT is [num_layers, num_features]
                # We need the threshold for the current layer_idx
                layer_thresholds = torch.exp(
                    self.clt.log_threshold[self.layer_idx].clone().detach()
                )
                self.threshold = layer_thresholds.to(
                    device=self.device, dtype=self.dtype
                )
                print(
                    f"CLTLayerWrapper: Initialized self.threshold for layer {self.layer_idx} from clt.log_threshold."
                )
            else:
                print(
                    f"Warning: CLTLayerWrapper layer_idx {self.layer_idx} is out of bounds for clt.log_threshold "
                    f"(shape {self.clt.log_threshold.shape}). self.threshold will be None."
                )
                self.threshold = None
        else:
            print(
                f"Warning: Underlying CLT model for layer {self.layer_idx} does not have 'log_threshold' or it's None, "
                f"but activation_fn is 'jumprelu'. self.threshold will be None."
            )
            self.threshold = None
    # else: self.threshold remains its default None, which is fine for other activation functions.

    # Get the hook name using prioritized templates
    # Priority: tl_input_template > mlp_input_template > hardcoded fallback.
    hook_name_template = getattr(clt.config, "tl_input_template", None)
    if hook_name_template:
        hook_name = hook_name_template.format(layer_idx)
        print(f"Using TL hook name template: {hook_name_template} -> {hook_name}")
    else:
        hook_name_template = getattr(clt.config, "mlp_input_template", None)
        if hook_name_template:
            hook_name = hook_name_template.format(layer_idx)
            print(
                f"Warning: tl_input_template not found. Using mlp_input_template: {hook_name_template} -> {hook_name}"
            )
        else:
            # Fallback for older configs without any template
            hook_name = f"blocks.{layer_idx}.hook_mlp_in"
            print(
                f"Warning: Neither tl_input_template nor mlp_input_template found. Falling back to hardcoded: {hook_name}"
            )

    self.cfg = CLTWrapperConfig(
        d_sae=clt.config.num_features,  # This is the d_sae of the *entire* CLT layer, not a sub-batch
        d_in=clt.config.d_model,
        hook_name=hook_name,
        hook_layer=layer_idx,
        hook_head_index=None,
        dtype=str(self.dtype).replace("torch.", ""),
        device=str(self.device),
        architecture=self.activation_fn,  # Use the determined activation_fn
        model_name=clt_model_name,
        dataset_path=clt_dataset_path,
        context_size=clt_context_size,
        prepend_bos=clt_prepend_bos,
        normalize_activations=normalize_activations_for_store,
        dataset_trust_remote_code=False,
        seqpos_slice=None,
        model_from_pretrained_kwargs=clt_model_from_pretrained_kwargs,
        metadata=CLTMetadata(
            hook_name=hook_name,
            hook_layer=layer_idx,
            model_name=clt_model_name,
            context_size=clt_context_size,
            prepend_bos=clt_prepend_bos,
            hook_head_index=None,
            seqpos_slice=None,
        ),
    )
    # --- End Config Creation ---

    # Extract and potentially gather weights
    # Ensure weights are detached and cloned to avoid modifying the original CLT
    # Original W_enc from CLT encoder module is [d_sae_layer, d_model]
    # We transpose to match sae-lens W_enc convention: [d_model, d_sae_layer]
    self.W_enc = (
        self._gather_encoder_weight(clt.encoder_module.encoders[layer_idx].weight)  # type: ignore
        .t()
        .contiguous()
    )
    # For W_dec, use the decoder from the same layer to itself
    decoder_key = f"{layer_idx}->{layer_idx}"
    if decoder_key not in clt.decoder_module.decoders:  # type: ignore
        raise KeyError(f"Decoder key {decoder_key} not found in CLT decoders.")
    # Original W_dec from CLT decoder module is [d_model, d_sae_layer]
    # We transpose to match sae-lens W_dec convention: [d_sae_layer, d_model]
    self.W_dec = (
        self._gather_decoder_weight(clt.decoder_module.decoders[decoder_key].weight)  # type: ignore
        .t()
        .contiguous()
    )

    self.b_enc = self._gather_encoder_bias(
        clt.encoder_module.encoders[layer_idx].bias_param  # type: ignore
    )
    # For b_dec, use the bias from the same-layer decoder
    self.b_dec = self._gather_decoder_bias(
        clt.decoder_module.decoders[decoder_key].bias_param  # type: ignore
    )

    # Cache for folded weights if needed
    self._W_dec_folded = False
    # Thresholds for JumpReLU will be handled by FeatureMaskingContext if architecture is 'jumprelu'
    # by setting self.threshold directly on the wrapper instance.
|
| 363 |
+
|
| 364 |
+
# --- Façade methods mimicking SAE --- #
|
| 365 |
+
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encodes input using the CLTLayerWrapper's own W_enc and b_enc,
    respecting masks applied by FeatureMaskingContext.
    Applies the activation function specified in self.activation_fn.

    Args:
        x: Activations of shape [..., d_model].

    Returns:
        Feature activations of shape [..., N_FEATURES_IN_BATCH].

    Raises:
        AttributeError: If activation is 'jumprelu' but no threshold is set.
        ValueError: If self.activation_fn is not one of relu/jumprelu/batchtopk.
    """
    # x is [..., d_model]
    # self.W_enc after masking (by FeatureMaskingContext) should be [d_model, N_FEATURES_IN_BATCH]
    # self.b_enc after masking (by FeatureMaskingContext) should be [N_FEATURES_IN_BATCH]

    original_shape = x.shape
    if x.ndim > 2:  # Ensure x is [N, d_model] for F.linear
        # self.cfg.d_in should be d_model
        x_reshaped = x.reshape(-1, self.cfg.d_in)
    else:
        x_reshaped = x

    x_to_process = x_reshaped
    # Apply CLT-specific normalization if stats were loaded
    if self.clt_norm_mean is not None and self.clt_norm_std is not None:
        # Ensure calculation is done in float32 for precision, then cast back
        x_float32 = x_to_process.to(torch.float32)
        normalized_x = (x_float32 - self.clt_norm_mean) / self.clt_norm_std
        x_to_process = normalized_x.to(x.dtype)

    # F.linear(input, weight, bias) expects weight to be [out_features, in_features]
    # self.W_enc is [d_model, N_FEATURES_IN_BATCH], so its transpose is [N_FEATURES_IN_BATCH, d_model]
    hidden_pre = F.linear(
        x_to_process, self.W_enc.T, self.b_enc
    )  # Output: [N, N_FEATURES_IN_BATCH]

    # Apply activation function
    if self.activation_fn == "relu":
        encoded_acts = torch.relu(hidden_pre)
    elif self.activation_fn == "jumprelu":
        if not hasattr(self, "threshold") or self.threshold is None:
            raise AttributeError(
                "JumpReLU activation selected, but 'self.threshold' is not available on CLTLayerWrapper. "
                "FeatureMaskingContext should set this if architecture is 'jumprelu'."
            )
        # JumpReLU: pass pre-activations above the per-feature threshold, zero elsewhere.
        encoded_acts = torch.where(
            hidden_pre > self.threshold, hidden_pre, torch.zeros_like(hidden_pre)
        )
    elif self.activation_fn == "batchtopk":
        k_val: float
        batchtopk_k_abs = getattr(self.clt.config, "batchtopk_k", None)
        batchtopk_k_frac = getattr(self.clt.config, "batchtopk_frac", None)

        if batchtopk_k_abs is not None:
            # This k is global. For the current batch of features, we use a per-layer approximation.
            k_val = float(batchtopk_k_abs) / self.clt.config.num_layers
            k_val = max(
                1.0, k_val
            )  # Ensure at least 1 feature is kept if k/num_layers is small
        elif batchtopk_k_frac is not None:
            k_val = float(
                batchtopk_k_frac
            )  # Fraction applies directly to current N_FEATURES_IN_BATCH
        else:
            # Fallback: if neither k nor frac is specified, keep all features currently being processed.
            # This matches the fallback in CrossLayerTranscoder.encode for its per-layer batchtopk.
            print(
                f"Warning: CLTLayerWrapper using batchtopk, but neither 'batchtopk_k' nor 'batchtopk_frac' defined in CLTConfig. Defaulting to keeping all {hidden_pre.size(-1)} features in the current batch."
            )
            k_val = float(hidden_pre.size(-1))

        straight_through_flag = getattr(
            self.clt.config, "batchtopk_straight_through", False
        )
        encoded_acts = BatchTopK.apply(hidden_pre, k_val, straight_through_flag)
    else:
        raise ValueError(
            f"Unsupported activation function in CLTLayerWrapper: {self.activation_fn}"
        )

    if x.ndim > 2:
        # Reshape back to original batch/sequence dimensions, with the last dim being N_FEATURES_IN_BATCH
        encoded_acts = encoded_acts.reshape(*original_shape[:-1], -1)  # type: ignore

    return encoded_acts  # type: ignore
|
| 445 |
+
|
| 446 |
+
def turn_off_forward_pass_hook_z_reshaping(self):
    """Interface-parity stub: CLTLayerWrapper never enables hook_z reshaping.

    Present only so callers written against the sae_lens SAE API can invoke
    it safely; intentionally a no-op here.
    """
    return None
|
| 450 |
+
|
| 451 |
+
# Note: CLTLayerWrapper does not have a separate `decode` method façade
|
| 452 |
+
# because the dashboard primarily uses W_dec directly for analysis (e.g., logits).
|
| 453 |
+
# The CLT's actual decode logic (summing across layers) isn't needed here.
|
| 454 |
+
|
| 455 |
+
def fold_W_dec_norm(self):
    """Folds the L2 norm of W_dec into W_enc and b_enc.

    Mirrors the logic in sae_lens.SAE.fold_W_dec_norm.
    Important for ensuring that W_enc activations directly correspond
    to the output norm when using the wrapped W_dec.

    Idempotent guard: a second call is a no-op (warns and returns).
    The norms used are kept in ``self._w_dec_norms_backup`` so
    :meth:`unfold_W_dec_norm` can reverse the operation.
    """
    if self._W_dec_folded:
        print("Warning: W_dec norm already folded.")
        return

    if self.W_dec is None or self.W_enc is None:
        print("Warning: Cannot fold W_dec norm, weights not available.")
        return

    # Detach W_dec before calculating norm to avoid gradient issues
    # W_dec is [N_FEATURES_IN_BATCH, d_model] (after masking context and init)
    # Norm should be taken over d_model dim (dim=1)

    # Use W_dec with its original dtype for norm calculation
    w_dec_for_norm = self.W_dec.detach()
    w_dec_norms = torch.norm(
        w_dec_for_norm, dim=1, keepdim=True
    )  # [N_FEATURES_IN_BATCH, 1]

    # Zero-norm rows are replaced by 1 so the scaling below is a no-op for them.
    w_dec_norms = torch.where(
        w_dec_norms == 0, torch.ones_like(w_dec_norms), w_dec_norms
    )

    # self.W_enc is [d_model, N_FEATURES_IN_BATCH]
    # We want to scale each column of W_enc (each feature's encoder vector)
    # by the corresponding feature's w_dec_norm.
    # Ensure dtypes match for multiplication, then cast W_enc back if necessary
    original_w_enc_dtype = self.W_enc.dtype
    self.W_enc.data = (self.W_enc.data.to(w_dec_norms.dtype) * w_dec_norms.t()).to(
        original_w_enc_dtype
    )

    if self.b_enc is not None:
        # self.b_enc is [N_FEATURES_IN_BATCH]
        # w_dec_norms.squeeze() is [N_FEATURES_IN_BATCH]
        original_b_enc_dtype = self.b_enc.dtype
        self.b_enc.data = (
            self.b_enc.data.to(w_dec_norms.dtype) * w_dec_norms.squeeze()
        ).to(original_b_enc_dtype)

    # Store the norms for potential unfolding or reference
    self._w_dec_norms_backup = w_dec_norms
    self._W_dec_folded = True
    print("Folded W_dec norm into W_enc and b_enc.")
|
| 505 |
+
|
| 506 |
+
def unfold_W_dec_norm(self):
    """Unfolds the L2 norm of W_dec from W_enc and b_enc.

    Exact inverse of :meth:`fold_W_dec_norm` (divides by the norms that
    were multiplied in). Requires the backup norms stored during folding;
    warns and returns if folding never happened.
    """
    if not self._W_dec_folded or not hasattr(self, "_w_dec_norms_backup"):
        print("Warning: W_dec norm not folded or backup norms not found.")
        return

    if self.W_enc is None:
        print("Warning: Cannot unfold W_dec norm, W_enc not available.")
        return

    # Retrieve the norms used for folding
    w_dec_norms = self._w_dec_norms_backup
    # Avoid division by zero (should have been handled in fold, but double check)
    w_dec_norms = torch.where(
        w_dec_norms == 0, torch.ones_like(w_dec_norms), w_dec_norms
    )

    # Divide in the norms' dtype for precision, then cast back.
    original_w_enc_dtype = self.W_enc.dtype
    self.W_enc.data = (self.W_enc.data.to(w_dec_norms.dtype) / w_dec_norms.t()).to(
        original_w_enc_dtype
    )

    if self.b_enc is not None:
        original_b_enc_dtype = self.b_enc.dtype
        self.b_enc.data = (
            self.b_enc.data.to(w_dec_norms.dtype) / w_dec_norms.squeeze()
        ).to(original_b_enc_dtype)

    del self._w_dec_norms_backup
    self._W_dec_folded = False
    print("Unfolded W_dec norm from W_enc and b_enc.")
|
| 537 |
+
|
| 538 |
+
def to(self, device: Union[str, torch.device]):
    """Moves the wrapper and underlying components to the specified device.

    Best-effort: if moving the underlying CLT fails, a warning is printed
    and the wrapper's own tensors are still moved. Returns ``self`` for
    chaining, mirroring ``torch.nn.Module.to``.
    """
    target_device = torch.device(device)

    # Move the underlying CLT model
    try:
        self.clt.to(target_device)
    except Exception as e:
        print(
            f"Warning: Failed to move underlying CLT model to {target_device}: {e}"
        )
        # Continue trying to move wrapper components

    # Move the wrapper's stored tensors
    if self.W_enc is not None:
        self.W_enc = self.W_enc.to(target_device)
    if self.W_dec is not None:
        self.W_dec = self.W_dec.to(target_device)
    if self.b_enc is not None:
        self.b_enc = self.b_enc.to(target_device)
    if self.b_dec is not None:
        self.b_dec = self.b_dec.to(target_device)
    if (
        hasattr(self, "_w_dec_norms_backup")
        and self._w_dec_norms_backup is not None
    ):
        self._w_dec_norms_backup = self._w_dec_norms_backup.to(target_device)

    # Update device attributes
    self.device = target_device
    self.cfg.device = str(target_device)

    # Update activation_fn related thresholds if they exist (e.g. for JumpReLU)
    if hasattr(self, "threshold") and self.threshold is not None:
        self.threshold = self.threshold.to(target_device)

    if self.clt_norm_mean is not None:  # Added to move norm stats
        self.clt_norm_mean = self.clt_norm_mean.to(target_device)
    if self.clt_norm_std is not None:  # Added to move norm stats
        self.clt_norm_std = self.clt_norm_std.to(target_device)

    print(f"Moved CLTLayerWrapper to {target_device}")
    return self
|
| 581 |
+
|
| 582 |
+
# --- Helper methods for Tensor Parallelism --- #
|
| 583 |
+
|
| 584 |
+
def _gather_weight(
|
| 585 |
+
self,
|
| 586 |
+
weight_shard: torch.Tensor,
|
| 587 |
+
gather_dim: int = 0,
|
| 588 |
+
target_full_dim_size: Optional[int] = None,
|
| 589 |
+
) -> torch.Tensor:
|
| 590 |
+
"""Gather a weight tensor shard across TP ranks."""
|
| 591 |
+
if not dist.is_initialized() or dist.get_world_size() == 1:
|
| 592 |
+
return weight_shard.clone().detach()
|
| 593 |
+
|
| 594 |
+
world_size = dist.get_world_size()
|
| 595 |
+
# Create a list to hold all gathered tensors
|
| 596 |
+
tensor_list = [torch.empty_like(weight_shard) for _ in range(world_size)]
|
| 597 |
+
dist.all_gather(tensor_list, weight_shard)
|
| 598 |
+
|
| 599 |
+
# Concatenate along the specified dimension
|
| 600 |
+
full_weight = torch.cat(tensor_list, dim=gather_dim)
|
| 601 |
+
|
| 602 |
+
# Trim padding if necessary
|
| 603 |
+
if target_full_dim_size is not None:
|
| 604 |
+
if gather_dim == 0:
|
| 605 |
+
if full_weight.shape[0] > target_full_dim_size:
|
| 606 |
+
full_weight = full_weight[:target_full_dim_size, :]
|
| 607 |
+
elif gather_dim == 1:
|
| 608 |
+
if full_weight.shape[1] > target_full_dim_size:
|
| 609 |
+
full_weight = full_weight[:, :target_full_dim_size]
|
| 610 |
+
# Add other gather_dim cases if needed
|
| 611 |
+
|
| 612 |
+
return full_weight.detach()
|
| 613 |
+
|
| 614 |
+
def _gather_encoder_weight(self, weight_shard: torch.Tensor) -> torch.Tensor:
    """Gather a ColumnParallelLinear encoder weight shard.

    The local shard is [d_sae_local, d_model]; concatenating shards along
    dim 0 (and trimming to the CLT's feature count) yields the full
    per-layer encoder weight.
    """
    full_feature_count = self.clt.config.num_features
    return self._gather_weight(
        weight_shard, gather_dim=0, target_full_dim_size=full_feature_count
    )
|
| 623 |
+
|
| 624 |
+
def _gather_decoder_weight(self, weight_shard: torch.Tensor) -> torch.Tensor:
    """Gather a RowParallelLinear decoder weight shard.

    The local shard is [d_model, d_sae_local]; concatenating shards along
    dim 1 (and trimming to the CLT's feature count) yields the full
    per-layer decoder weight.
    """
    full_feature_count = self.clt.config.num_features
    return self._gather_weight(
        weight_shard, gather_dim=1, target_full_dim_size=full_feature_count
    )
|
| 633 |
+
|
| 634 |
+
def _gather_bias(
|
| 635 |
+
self,
|
| 636 |
+
bias_shard: Optional[torch.Tensor],
|
| 637 |
+
gather_dim: int = 0,
|
| 638 |
+
target_full_dim_size: Optional[int] = None,
|
| 639 |
+
) -> Optional[torch.Tensor]:
|
| 640 |
+
"""Gather a bias tensor shard across TP ranks."""
|
| 641 |
+
if bias_shard is None:
|
| 642 |
+
return None
|
| 643 |
+
# Biases are typically sharded along the same dimension as the weight's corresponding output dim
|
| 644 |
+
return self._gather_weight(
|
| 645 |
+
bias_shard, gather_dim=gather_dim, target_full_dim_size=target_full_dim_size
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
def _gather_encoder_bias(
|
| 649 |
+
self, bias_shard_candidate: Optional[torch.Tensor]
|
| 650 |
+
) -> Optional[torch.Tensor]:
|
| 651 |
+
"""Gather ColumnParallelLinear bias (sharded along output/feature dim).
|
| 652 |
+
|
| 653 |
+
Defensively checks if the provided candidate is actually a Tensor.
|
| 654 |
+
"""
|
| 655 |
+
# Check if the provided object is a Tensor
|
| 656 |
+
if isinstance(bias_shard_candidate, torch.Tensor):
|
| 657 |
+
# Encoder bias shape [d_sae_local], gather along dim 0
|
| 658 |
+
return self._gather_bias(
|
| 659 |
+
bias_shard_candidate,
|
| 660 |
+
gather_dim=0,
|
| 661 |
+
target_full_dim_size=self.clt.config.num_features,
|
| 662 |
+
)
|
| 663 |
+
else:
|
| 664 |
+
# If it's None, bool, or anything else, treat as no bias
|
| 665 |
+
return None
|
| 666 |
+
|
| 667 |
+
def _gather_decoder_bias(
|
| 668 |
+
self, bias_shard_candidate: Optional[torch.Tensor]
|
| 669 |
+
) -> Optional[torch.Tensor]:
|
| 670 |
+
"""Gather RowParallelLinear bias (NOT sharded, but might need broadcast/check).
|
| 671 |
+
|
| 672 |
+
Defensively checks if the provided candidate is actually a Tensor.
|
| 673 |
+
"""
|
| 674 |
+
# Check if the provided object is a Tensor
|
| 675 |
+
if isinstance(bias_shard_candidate, torch.Tensor):
|
| 676 |
+
# RowParallelLinear bias is typically not sharded (added after all-reduce)
|
| 677 |
+
# However, let's check world size and return a clone if TP=1, or verify replication if TP>1
|
| 678 |
+
if not dist.is_initialized() or dist.get_world_size() == 1:
|
| 679 |
+
return bias_shard_candidate.clone().detach()
|
| 680 |
+
|
| 681 |
+
# In TP > 1, the bias should be identical across ranks. Verify this.
|
| 682 |
+
world_size = dist.get_world_size()
|
| 683 |
+
tensor_list = [
|
| 684 |
+
torch.empty_like(bias_shard_candidate) for _ in range(world_size)
|
| 685 |
+
]
|
| 686 |
+
dist.all_gather(tensor_list, bias_shard_candidate)
|
| 687 |
+
# Check if all gathered biases are the same
|
| 688 |
+
for i in range(1, world_size):
|
| 689 |
+
if not torch.equal(tensor_list[0], tensor_list[i]):
|
| 690 |
+
raise RuntimeError(
|
| 691 |
+
"RowParallelLinear bias shards are not identical across TP ranks, which is unexpected."
|
| 692 |
+
)
|
| 693 |
+
# Return the bias from rank 0 (or any rank, as they are identical)
|
| 694 |
+
return tensor_list[0].clone().detach()
|
| 695 |
+
else:
|
| 696 |
+
# If it's None, bool, or anything else, treat as no bias
|
| 697 |
+
return None
|
SAEDashboard/sae_dashboard/components.py
ADDED
|
@@ -0,0 +1,774 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Callable, List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from dataclasses_json import dataclass_json
|
| 8 |
+
|
| 9 |
+
from sae_dashboard.components_config import (
|
| 10 |
+
ActsHistogramConfig,
|
| 11 |
+
FeatureTablesConfig,
|
| 12 |
+
LogitsHistogramConfig,
|
| 13 |
+
LogitsTableConfig,
|
| 14 |
+
PromptConfig,
|
| 15 |
+
SequencesConfig,
|
| 16 |
+
)
|
| 17 |
+
from sae_dashboard.html_fns import HTML, bgColorMap, uColorMap
|
| 18 |
+
from sae_dashboard.utils_fns import (
|
| 19 |
+
HistogramData,
|
| 20 |
+
max_or_1,
|
| 21 |
+
to_str_tokens,
|
| 22 |
+
unprocess_str_tok,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
PRECISION = 4
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class DecoderWeightsDistribution:
    """How a feature's decoder weights are distributed across attention heads."""

    # Number of attention heads in the layer.
    n_heads: int
    # Per-head share of the decoder weight mass; one float per head.
    # NOTE(review): presumably these sum to ~1.0 — confirm against the producer (DFA calculator).
    allocation_by_head: List[float]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass_json
@dataclass
class FeatureTablesData:
    """
    This contains all the data necessary to make the left-hand tables in prompt-centric visualization. See diagram
    in readme:

    https://github.com/callummcdougall/sae_vis#data_storing_fnspy

    Inputs:
        neuron_alignment...
            The data for the neuron alignment table (each of its 3 columns). In other words, the data containing which
            neurons in the transformer the encoder feature is most aligned with.

        correlated_neurons...
            The data for the correlated neurons table (each of its 3 columns). In other words, the data containing which
            neurons in the transformer are most correlated with the encoder feature.

        correlated_features...
            The data for the correlated features table (each of its 3 columns). In other words, the data containing
            which features in this encoder are most correlated with each other.

        correlated_b_features...
            The data for the correlated features table (each of its 3 columns). In other words, the data containing
            which features in encoder-B are most correlated with those in the original encoder. Note, this one might be
            absent if we're not using a B-encoder.
    """

    neuron_alignment_indices: list[int] = field(default_factory=list)
    neuron_alignment_values: list[float] = field(default_factory=list)
    neuron_alignment_l1: list[float] = field(default_factory=list)
    correlated_neurons_indices: list[int] = field(default_factory=list)
    correlated_neurons_pearson: list[float] = field(default_factory=list)
    correlated_neurons_cossim: list[float] = field(default_factory=list)
    correlated_features_indices: list[int] = field(default_factory=list)
    correlated_features_pearson: list[float] = field(default_factory=list)
    correlated_features_cossim: list[float] = field(default_factory=list)
    correlated_b_features_indices: list[int] = field(default_factory=list)
    correlated_b_features_pearson: list[float] = field(default_factory=list)
    correlated_b_features_cossim: list[float] = field(default_factory=list)

    def _get_html_data(
        self,
        cfg: FeatureTablesConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """
        Returns the HTML for the left-hand tables, wrapped in a 'grid-column' div.

        Note, we only ever use this obj in the context of the left-hand column of the feature-centric vis, and it's
        always the same width & height, which is why there's no customization available for this function.
        """
        # Read HTML from file, and replace placeholders with real ID values
        html_str = (
            Path(__file__).parent / "html" / "feature_tables_template.html"
        ).read_text()
        html_str = html_str.replace("FEATURE_TABLES_ID", f"feature-tables-{id_suffix}")

        # Dictionary mapping JS table names -> list of row dicts
        data: dict[str, list[dict[str, str | float]]] = {}

        # Store the neuron alignment data, if it exists
        if len(self.neuron_alignment_indices) > 0:
            assert len(self.neuron_alignment_indices) >= cfg.n_rows, "Not enough rows!"
            # BUGFIX: crop each list to `cfg.n_rows` before zipping. The assert above
            # guarantees we have at least that many rows; previously the full lists were
            # zipped, silently ignoring the configured row count (the other three tables
            # below already crop, so this also restores consistency).
            data["neuronAlignment"] = [
                {
                    "index": index,
                    "value": f"{value:+.3f}",
                    "percentageL1": f"{percent_l1:.1%}",
                }
                for index, value, percent_l1 in zip(
                    self.neuron_alignment_indices[: cfg.n_rows],
                    self.neuron_alignment_values[: cfg.n_rows],
                    self.neuron_alignment_l1[: cfg.n_rows],
                )
            ]

        # Store the other 3, if they exist (they're all in the same format, so we can do it in a for loop)
        for name, js_name in zip(
            ["correlated_neurons", "correlated_features", "correlated_b_features"],
            ["correlatedNeurons", "correlatedFeatures", "correlatedFeaturesB"],
        ):
            if len(getattr(self, f"{name}_indices")) > 0:
                # NOTE: the "percentageL1" key is what the JS template expects for the
                # third column, even though the value stored here is the cosine similarity.
                data[js_name] = [
                    {
                        "index": index,
                        "value": f"{value:+.3f}",
                        "percentageL1": f"{percent_L1:+.3f}",
                    }
                    for index, value, percent_L1 in zip(
                        getattr(self, f"{name}_indices")[: cfg.n_rows],
                        getattr(self, f"{name}_pearson")[: cfg.n_rows],
                        getattr(self, f"{name}_cossim")[: cfg.n_rows],
                    )
                ]

        return HTML(
            html_data={column: html_str},
            js_data={"featureTablesData": {id_suffix: data}},
        )
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@dataclass_json
@dataclass
class ActsHistogramData(HistogramData):
    def _get_html_data(
        self,
        cfg: ActsHistogramConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """
        Converts data -> HTML object, for the feature activations histogram (i.e. the histogram over all sampled
        tokens, showing the distribution of activations for this feature).
        """
        # Read HTML from file, and replace placeholders with real ID values
        html_str = (
            Path(__file__).parent / "html" / "acts_histogram_template.html"
        ).read_text()
        html_str = html_str.replace("HISTOGRAM_ACTS_ID", f"histogram-acts-{id_suffix}")

        # Process colors for frequency histogram; it's darker at higher values.
        # Hoist the loop-invariant max out of the comprehension; `default=0.0` also makes
        # an empty `bar_values` safe (previously `max([])` raised ValueError despite the
        # divide-by-zero guard).
        max_bar_value = max(self.bar_values, default=0.0)
        denom = max(max_bar_value, 1e-6)  # avoid divide by zero
        bar_values_normed = [
            (0.4 * max_bar_value + 0.6 * v) / denom for v in self.bar_values
        ]
        bar_colors = [bgColorMap(v) for v in bar_values_normed]

        # Next we create the data dict
        data: dict[str, Any] = {
            "y": self.bar_heights,
            "x": self.bar_values,
            "ticks": self.tick_vals,
            "colors": bar_colors,
        }
        if self.title is not None:
            data["title"] = self.title

        return HTML(
            html_data={column: html_str},
            js_data={"actsHistogramData": {id_suffix: data}},
        )
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@dataclass_json
@dataclass
class LogitsHistogramData(HistogramData):
    def _get_html_data(
        self,
        cfg: LogitsHistogramConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """
        Converts data -> HTML object, for the logits histogram (i.e. the histogram over all tokens in the vocab,
        showing the distribution of direct logit effect on that token).
        """
        # Load the template and substitute the real element ID for the placeholder
        template_path = Path(__file__).parent / "html" / "logits_histogram_template.html"
        html_str = template_path.read_text().replace(
            "HISTOGRAM_LOGITS_ID", f"histogram-logits-{id_suffix}"
        )

        # Assemble the JS payload: bar heights / positions / axis ticks, plus optional title
        data: dict[str, Any] = {
            "y": self.bar_heights,
            "x": self.bar_values,
            "ticks": self.tick_vals,
        }
        if self.title is not None:
            data["title"] = self.title

        return HTML(
            html_data={column: html_str},
            js_data={"logitsHistogramData": {id_suffix: data}},
        )
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@dataclass_json
@dataclass
class LogitsTableData:
    """Top & bottom direct-logit-effect tokens for a feature (data for the logits table)."""

    bottom_token_ids: list[int] = field(default_factory=list)
    bottom_logits: list[float] = field(default_factory=list)
    top_token_ids: list[int] = field(default_factory=list)
    top_logits: list[float] = field(default_factory=list)

    def _get_html_data(
        self,
        cfg: LogitsTableConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """
        Converts data -> HTML object, for the logits table (i.e. the top and bottom affected tokens by this feature).
        """
        # Crop the lists to `cfg.n_rows` (first checking the config doesn't ask for more rows than we have)
        assert cfg.n_rows <= len(self.bottom_logits)
        bottom_token_ids = self.bottom_token_ids[: cfg.n_rows]
        bottom_logits = self.bottom_logits[: cfg.n_rows]
        top_token_ids = self.top_token_ids[: cfg.n_rows]
        top_logits = self.top_logits[: cfg.n_rows]

        # Get the negative and positive background values (darkest when equals max abs).
        # The lists are already cropped, so no need to re-slice; guard the denominator so
        # an all-zero logits table doesn't divide by zero.
        max_value = max(max(top_logits), -min(bottom_logits))
        denom = max(max_value, 1e-6)
        neg_bg_values = np.absolute(bottom_logits) / denom
        pos_bg_values = np.absolute(top_logits) / denom

        # Get the string tokens, using the decode function
        neg_str = to_str_tokens(decode_fn, bottom_token_ids)
        pos_str = to_str_tokens(decode_fn, top_token_ids)

        # Read HTML from file, and replace placeholders with real ID values
        html_str = (
            Path(__file__).parent / "html" / "logits_table_template.html"
        ).read_text()
        html_str = html_str.replace("LOGITS_TABLE_ID", f"logits-table-{id_suffix}")

        # Create object for storing JS data
        data: dict[str, list[dict[str, str | float]]] = {
            "negLogits": [],
            "posLogits": [],
        }

        # Get data for the tables of pos/neg logits (red background for negative, blue for positive)
        for i in range(len(neg_str)):
            neg_rgb = int(255 * (1 - neg_bg_values[i]))
            pos_rgb = int(255 * (1 - pos_bg_values[i]))
            data["negLogits"].append(
                {
                    "symbol": unprocess_str_tok(neg_str[i]),
                    "value": round(bottom_logits[i], 2),
                    "color": f"rgba(255,{neg_rgb},{neg_rgb},0.5)",
                }
            )
            data["posLogits"].append(
                {
                    "symbol": unprocess_str_tok(pos_str[i]),
                    "value": round(top_logits[i], 2),
                    "color": f"rgba({pos_rgb},{pos_rgb},255,0.5)",
                }
            )

        return HTML(
            html_data={column: html_str},
            js_data={"logitsTableData": {id_suffix: data}},
        )
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@dataclass_json
@dataclass
class SequenceData:
    """
    This contains all the data necessary to make a sequence of tokens in the vis. See diagram in readme:

    https://github.com/callummcdougall/sae_vis#data_storing_fnspy

    Always-visible data:
        token_ids:          List of token IDs in the sequence
        feat_acts:          Sizes of activations on this sequence
        loss_contribution:  Effect on loss of this feature, for this particular token (neg = helpful)

    Data which is visible on hover:
        token_logits:       The logits of the particular token in that sequence (used for line on logits histogram)
        top_token_ids:      List of the top 5 logit-boosted tokens by this feature
        top_logits:         List of the corresponding 5 changes in logits for those tokens
        bottom_token_ids:   List of the bottom 5 logit-boosted tokens by this feature
        bottom_logits:      List of the corresponding 5 changes in logits for those tokens
    """

    original_index: int = 0
    qualifying_token_index: int = 0
    token_ids: list[int] = field(default_factory=list)
    feat_acts: list[float] = field(default_factory=list)
    loss_contribution: list[float] = field(default_factory=list)

    token_logits: list[float] = field(default_factory=list)
    top_token_ids: list[list[int]] = field(default_factory=list)
    top_logits: list[list[float]] = field(default_factory=list)
    bottom_token_ids: list[list[int]] = field(default_factory=list)
    bottom_logits: list[list[float]] = field(default_factory=list)

    def __post_init__(self) -> None:
        """
        Filters the logits & token IDs by removing any elements which are zero (this saves space in the eventual
        JavaScript).
        """
        self.seq_len = len(self.token_ids)
        self.top_logits, self.top_token_ids = self._filter(
            self.top_logits, self.top_token_ids
        )
        self.bottom_logits, self.bottom_token_ids = self._filter(
            self.bottom_logits, self.bottom_token_ids
        )

    def _filter(
        self, float_list: list[list[float]], int_list: list[list[int]]
    ) -> tuple[list[list[float]], list[list[int]]]:
        """
        Filters the list of floats and ints, by removing any elements which are zero. Note - the absolute values of the
        floats are monotonic non-increasing, so we can assume that all the elements we keep will be the first elements
        of their respective lists. Also reduces precisions of feature activations & logits.
        """
        # Next, filter out zero-elements and reduce precision
        float_list = [
            [round(f, PRECISION) for f in floats if abs(f) > 1e-6]
            for floats in float_list
        ]
        # Keep each int list the same length as its (now filtered) float list; this relies
        # on the monotonicity noted above (kept floats are always a prefix).
        int_list = [ints[: len(floats)] for ints, floats in zip(int_list, float_list)]
        return float_list, int_list

    def _get_html_data(
        self,
        cfg: PromptConfig | SequencesConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """
        Builds the HTML + JS data for a single token sequence.

        Returns:
            HTML object whose js_data["tokenData"][seq_group_id] contains, per token, the
            string token, its ID/logit, and (where present) feature-activation and
            loss-effect hover data.
        """
        assert isinstance(
            cfg, (PromptConfig, SequencesConfig)
        ), f"Invalid config type: {type(cfg)}"
        seq_group_id = component_specific_kwargs.get("seq_group_id", None)
        max_feat_act = component_specific_kwargs.get("max_feat_act", None)
        max_loss_contribution = component_specific_kwargs.get(
            "max_loss_contribution", None
        )
        bold_idx = component_specific_kwargs.get("bold_idx", None)
        permanent_line = component_specific_kwargs.get("permanent_line", False)
        first_in_group = component_specific_kwargs.get("first_in_group", True)
        title = component_specific_kwargs.get("title", None)
        hover_above = component_specific_kwargs.get("hover_above", False)

        # If we didn't supply a sequence group ID, then we assume this sequence is on its own, and give it a unique ID
        if seq_group_id is None:
            seq_group_id = f"prompt-{column:03d}"

        # If we didn't specify bold_idx, then set it to be the midpoint
        if bold_idx is None:
            bold_idx = self.seq_len // 2

        # If we only have data for the bold token, we pad out everything with zeros or empty lists
        only_bold = isinstance(cfg, SequencesConfig) and not (cfg.compute_buffer)
        if only_bold:
            assert bold_idx != "max", "Don't know how to deal with this case yet."
            # Activation data belongs to the bold token itself...
            feat_acts = [
                self.feat_acts[0] if (i == bold_idx) else 0.0
                for i in range(self.seq_len)
            ]
            # ...whereas loss/logit-effect data describes the effect on the *next* token,
            # so it belongs at position bold_idx + 1.
            # BUGFIX: the original wrote `(i == bold_idx) + 1`, which is `bool + 1`
            # (always 1 or 2, i.e. always truthy), so the data leaked onto every position.
            loss_contribution = [
                self.loss_contribution[0] if (i == bold_idx + 1) else 0.0
                for i in range(self.seq_len)
            ]
            pos_ids = [
                self.top_token_ids[0] if (i == bold_idx + 1) else []
                for i in range(self.seq_len)
            ]
            neg_ids = [
                self.bottom_token_ids[0] if (i == bold_idx + 1) else []
                for i in range(self.seq_len)
            ]
            pos_val = [
                self.top_logits[0] if (i == bold_idx + 1) else []
                for i in range(self.seq_len)
            ]
            neg_val = [
                self.bottom_logits[0] if (i == bold_idx + 1) else []
                for i in range(self.seq_len)
            ]
        else:
            feat_acts = deepcopy(self.feat_acts)
            loss_contribution = deepcopy(self.loss_contribution)
            pos_ids = deepcopy(self.top_token_ids)
            neg_ids = deepcopy(self.bottom_token_ids)
            pos_val = deepcopy(self.top_logits)
            neg_val = deepcopy(self.bottom_logits)

        # EXPERIMENT: let's just hardcode everything except feature acts to be 0's for now.
        # NOTE: this unconditionally discards the loss/logit hover data computed above.
        loss_contribution = [0.0 for _ in range(self.seq_len)]
        pos_ids = [[] for _ in range(self.seq_len)]
        neg_ids = [[] for _ in range(self.seq_len)]
        pos_val = [[] for _ in range(self.seq_len)]
        neg_val = [[] for _ in range(self.seq_len)]
        ### END EXPERIMENT

        # Get values for converting into colors later
        bg_denom = max_feat_act or max_or_1(self.feat_acts)
        u_denom = max_loss_contribution or max_or_1(self.loss_contribution, abs=True)
        bg_values = (np.maximum(feat_acts, 0.0) / max(1e-4, bg_denom)).tolist()
        u_values = (np.array(loss_contribution) / max(1e-4, u_denom)).tolist()

        # If we sent in a prompt rather than this being sliced from a longer sequence, then the pos_ids etc will be shorter
        # than the token list by 1, so we need to pad it at the first token
        if isinstance(cfg, PromptConfig):
            assert (
                len(pos_ids)
                == len(neg_ids)
                == len(pos_val)
                == len(neg_val)
                == len(self.token_ids) - 1
            ), "If this is a single prompt, these lists must be the same length as token_ids or 1 less"
            pos_ids = [[]] + pos_ids
            neg_ids = [[]] + neg_ids
            pos_val = [[]] + pos_val
            neg_val = [[]] + neg_val

        assert (
            len(pos_ids)
            == len(neg_ids)
            == len(pos_val)
            == len(neg_val)
            == len(self.token_ids)
        ), "If this is part of a sequence group etc are given, they must be the same length as token_ids"

        # Process the tokens to get str toks
        toks = to_str_tokens(decode_fn, self.token_ids)
        pos_toks = [to_str_tokens(decode_fn, pos) for pos in pos_ids]
        neg_toks = [to_str_tokens(decode_fn, neg) for neg in neg_ids]

        # Define the JavaScript object which will be used to populate the HTML string
        js_data_list = []

        for i in range(len(self.token_ids)):
            # We might store a bunch of different case-specific data in the JavaScript object for each token. This is
            # done in the form of a disjoint union over different dictionaries (which can each be empty or not), this
            # minimizes the size of the overall JavaScript object. See function in `tokens_script.js` for more.
            kwargs_bold: dict[str, bool] = {}
            kwargs_hide: dict[str, bool] = {}
            kwargs_this_token_active: dict[str, Any] = {}
            kwargs_prev_token_active: dict[str, Any] = {}
            kwargs_hover_above: dict[str, bool] = {}

            # Get args if this is the bolded token (we make it bold, and maybe add permanent line to histograms)
            if bold_idx is not None:
                kwargs_bold["isBold"] = (bold_idx == i) or (
                    bold_idx == "max" and i == np.argmax(feat_acts).item()
                )
                if kwargs_bold["isBold"] and permanent_line:
                    kwargs_bold["permanentLine"] = True

            # If we only have data for the bold token, we hide all other tokens' hoverdata (and skip other kwargs)
            if (
                only_bold
                and isinstance(bold_idx, int)
                and (i not in {bold_idx, bold_idx + 1})
            ):
                kwargs_hide["hide"] = True

            else:
                # Get args if we're making the tooltip hover above token (default is below)
                if hover_above:
                    kwargs_hover_above["hoverAbove"] = True

                # If feature active on this token, get background color and feature act (for hist line)
                if abs(feat_acts[i]) > 1e-8:
                    kwargs_this_token_active = dict(
                        featAct=round(feat_acts[i], PRECISION),
                        bgColor=bgColorMap(bg_values[i]),
                    )

                # If prev token active, get the top/bottom logits table, underline color, and loss effect (for hist line)
                pos_toks_i, neg_toks_i, pos_val_i, neg_val_i = (
                    pos_toks[i],
                    neg_toks[i],
                    pos_val[i],
                    neg_val[i],
                )
                if len(pos_toks_i) + len(neg_toks_i) > 0:
                    # Create dictionary
                    kwargs_prev_token_active = dict(
                        posToks=pos_toks_i,
                        negToks=neg_toks_i,
                        posVal=pos_val_i,
                        negVal=neg_val_i,
                        lossEffect=round(loss_contribution[i], PRECISION),
                        uColor=uColorMap(u_values[i]),
                    )

            js_data_list.append(
                dict(
                    tok=unprocess_str_tok(toks[i]),
                    tokID=self.token_ids[i],
                    tokenLogit=round(self.token_logits[i], PRECISION),
                    **kwargs_bold,
                    **kwargs_this_token_active,
                    **kwargs_prev_token_active,
                    **kwargs_hover_above,
                )
            )

        # Create HTML string (empty by default since sequences are added by JavaScript) and JS data
        html_str = ""
        js_seq_group_data: dict[str, Any] = {"data": [js_data_list]}

        # Add group-specific stuff if this is the first sequence in the group
        if first_in_group:
            # Read HTML from file, replace placeholders with real ID values
            html_str = (
                Path(__file__).parent / "html" / "sequences_group_template.html"
            ).read_text()
            html_str = html_str.replace("SEQUENCE_GROUP_ID", seq_group_id)

            # Get title of sequence group, and the idSuffix to match up with a histogram
            js_seq_group_data["idSuffix"] = id_suffix
            if title is not None:
                js_seq_group_data["title"] = title

        return HTML(
            html_data={column: html_str},
            js_data={"tokenData": {seq_group_id: js_seq_group_data}},
        )
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
@dataclass_json
@dataclass
class SequenceGroupData:
    """
    Data for a single group of sequences (e.g. one quantile group in the prompt-centric
    visualization). See diagram in readme:

    https://github.com/callummcdougall/sae_vis#data_storing_fnspy

    Inputs:
        title:    Title shown for this sequence group (embedded in the HTML, not the JS data).
        seq_data: The sequences belonging to this group.
    """

    title: str = ""
    seq_data: list[SequenceData] = field(default_factory=list)

    def __len__(self) -> int:
        return len(self.seq_data)

    @property
    def max_feat_act(self) -> float:
        """Largest feature activation across every sequence in this group."""
        all_acts = [act for seq in self.seq_data for act in seq.feat_acts]
        return max_or_1(all_acts)

    @property
    def max_loss_contribution(self) -> float:
        """Largest absolute loss contribution across every sequence in this group."""
        all_losses = [loss for seq in self.seq_data for loss in seq.loss_contribution]
        return max_or_1(all_losses, abs=True)

    def _get_html_data(
        self,
        cfg: SequencesConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
        # These default values should be correct when we only have one sequence group, because when we call this from
        # a SequenceMultiGroupData we'll override them)
    ) -> HTML:
        """
        Builds one group of sequences: a title plus some number of vertically stacked sequences.

        Note, `column` is treated specially here, because the col might overflow (hence column could be a tuple).

        Args (from component-specific kwargs):
            seq_group_id:  The id of the sequence group div, e.g. "seq-group-001".
            group_size:    Max number of sequences rendered (truncates after this many, if supplied).
            max_feat_act:  If supplied, used as the most extreme value when coloring by feature act.

        Returns:
            html_obj: Object containing the HTML and JavaScript data for this seq group.
        """
        seq_group_id = component_specific_kwargs.get("seq_group_id", None)
        group_size = component_specific_kwargs.get("group_size", None)
        max_feat_act = component_specific_kwargs.get("max_feat_act", self.max_feat_act)
        max_loss_contribution = component_specific_kwargs.get(
            "max_loss_contribution", self.max_loss_contribution
        )

        # No group id supplied -> assume this is the only group in the column and name it after the column
        if seq_group_id is None:
            seq_group_id = f"seq-group-{column:03d}"

        # Loop-invariant: which token gets bolded in each sequence
        bold_idx = "max" if cfg.buffer is None else cfg.buffer[0]

        # Accumulate the HTML data for each sequence in this group
        html_obj = HTML()
        for i, seq in enumerate(self.seq_data[:group_size]):
            html_obj += seq._get_html_data(
                cfg=cfg,
                decode_fn=decode_fn,
                id_suffix=id_suffix,
                column=column,
                component_specific_kwargs=dict(
                    bold_idx=bold_idx,
                    permanent_line=False,  # permanent histogram lines are only for single sequences
                    max_feat_act=max_feat_act,
                    max_loss_contribution=max_loss_contribution,
                    seq_group_id=seq_group_id,
                    first_in_group=(i == 0),
                    title=self.title,
                ),
            )

        return html_obj
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
@dataclass_json
@dataclass
class SequenceMultiGroupData:
    """
    Holds every sequence group shown for one feature, e.g. the top-activations group
    plus each of the quantile groups in the prompt-centric visualization. See the
    diagram in the readme:

        https://github.com/callummcdougall/sae_vis#data_storing_fnspy
    """

    seq_group_data: list[SequenceGroupData] = field(default_factory=list)

    def __getitem__(self, idx: int) -> SequenceGroupData:
        return self.seq_group_data[idx]

    @property
    def max_feat_act(self) -> float:
        """Largest feature activation across all sequences in all groups."""
        return max_or_1([group.max_feat_act for group in self.seq_group_data])

    @property
    def max_loss_contribution(self) -> float:
        """Largest loss contribution across all sequences in all groups."""
        return max_or_1(
            [group.max_loss_contribution for group in self.seq_group_data]
        )

    def _get_html_data(
        self,
        cfg: SequencesConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        id_suffix: str,
        column: int | tuple[int, int],
        component_specific_kwargs: dict[str, Any] = {},
    ) -> HTML:
        """
        Build the combined HTML object for all the sequence groups.

        Args:
            decode_fn: Mapping from token IDs to string tokens.
            id_suffix: The suffix for the ID of the div containing the sequences.
            column: The index of this column. An int on entry; overflowing groups are
                placed under tuple keys `(column, x)`.
            component_specific_kwargs: Any kwargs used to customize this component
                (may override `max_feat_act` / `max_loss_contribution`).

        Returns:
            html_obj: Object containing the HTML and JavaScript data for all groups.
        """
        assert isinstance(column, int)

        # Normalization constants shared by every group; the caller may override them.
        max_feat_act = component_specific_kwargs.get("max_feat_act", self.max_feat_act)
        max_loss_contribution = component_specific_kwargs.get(
            "max_loss_contribution", self.max_loss_contribution
        )

        # Assign each group a column key, depending on group_wrap config. Overflow is
        # expressed by extending the key to a tuple `(column, x)` where `x` counts the
        # overflows; e.g. in 'stack-none' mode the keys run
        # `(column, 0), (column, 1), (column, 1), (column, 1), (column, 2), ...`.
        n_groups = len(self.seq_group_data)
        n_quantile_groups = n_groups - 1
        if cfg.stack_mode == "stack-all":
            # All groups stacked into the single original column.
            cols = [column] * n_groups
        elif cfg.stack_mode == "stack-quantiles":
            # Top-acts group alone in the first column, every quantile group in the second.
            cols = [(column, 0)] + [(column, 1)] * n_quantile_groups
        elif cfg.stack_mode == "stack-none":
            # Groups wrap in runs of three so nothing scrolls vertically: [1, 3, 3, ...].
            cols = [(column, 0)] + [
                (column, 1 + q // 3) for q in range(n_quantile_groups)
            ]
        else:
            raise ValueError(
                f"Invalid stack_mode: {cfg.stack_mode}. Expected in 'stack-{{all,quantiles,none}}'."
            )

        # Accumulate HTML from each group, labelling groups with (index, column) ids.
        html_obj = HTML()
        for i, (col, group_size, sequences_group) in enumerate(
            zip(cols, cfg.group_sizes, self.seq_group_data)
        ):
            html_obj += sequences_group._get_html_data(
                cfg=cfg,
                decode_fn=decode_fn,
                id_suffix=id_suffix,
                column=col,
                component_specific_kwargs=dict(
                    group_size=group_size,
                    max_feat_act=max_feat_act,
                    max_loss_contribution=max_loss_contribution,
                    seq_group_id=f"seq-group-{column}-{i}",
                ),
            )

        return html_obj
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
# Union of all per-component data payloads. Functions that can receive any one of the
# dashboard's component-data objects annotate with this alias.
GenericData = (
    FeatureTablesData
    | ActsHistogramData
    | LogitsTableData
    | LogitsHistogramData
    | SequenceMultiGroupData
    | SequenceData
)
|
SAEDashboard/sae_dashboard/components_config.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Any, Iterator, Literal
|
| 3 |
+
|
| 4 |
+
# Per-field help text for SequencesConfig, exposed via its `help_dict` property.
SEQUENCES_CONFIG_HELP = dict(
    buffer="How many tokens to add as context to each sequence, on each side. The tokens chosen for the top acts / \
quantile groups can't be outside the buffer range. If None, we use the entire sequence as context.",
    compute_buffer="If False, then we don't compute the loss effect, activations, or any other data for tokens \
other than the bold tokens in our sequences (saving time).",
    n_quantiles="Number of quantile groups for the sequences. If zero, we only show top activations, no quantile \
groups.",
    top_acts_group_size="Number of sequences in the 'top activating sequences' group.",
    quantile_group_size="Number of sequences in each of the sequence quantile groups.",
    top_logits_hoverdata="Number of top/bottom logits to show in the hoverdata for each token.",
    stack_mode="How to stack the sequence groups.\n 'stack-all' = all groups are stacked in a single column \
(scrolls vertically if it overflows)\n 'stack-quantiles' = first col contains top acts, second col contains all \
quantile groups\n 'stack-none' = we stack in a way which ensures no vertical scrolling.",
    hover_below="Whether the hover information about a token appears below or above the token.",
)

# Per-field help text for ActsHistogramConfig.
ACTIVATIONS_HISTOGRAM_CONFIG_HELP = dict(
    n_bins="Number of bins for the histogram.",
)

# Per-field help text for LogitsHistogramConfig.
LOGITS_HISTOGRAM_CONFIG_HELP = dict(
    n_bins="Number of bins for the histogram.",
)

# Per-field help text for LogitsTableConfig.
LOGITS_TABLE_CONFIG_HELP = dict(
    n_rows="Number of top/bottom logits to show in the table.",
)

# Per-field help text for FeatureTablesConfig.
FEATURE_TABLES_CONFIG_HELP = dict(
    n_rows="Number of rows to show for each feature table.",
    neuron_alignment_table="Whether to show the neuron alignment table.",
    correlated_neurons_table="Whether to show the correlated neurons table.",
    correlated_features_table="Whether to show the (pairwise) correlated features table.",
    correlated_b_features_table="Whether to show the correlated encoder-B features table.",
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class BaseComponentConfig:
    """Common interface shared by every dashboard component config object."""

    def data_is_contained_in(self, other: "BaseComponentConfig") -> bool:
        """
        Report whether data computed under `other` suffices to render `self`.

        Returns False only when the data generated from `other` would not be enough to
        display what `self` requires. For example, a `self` asking for 10 table rows is
        not contained in an `other` computed with only 5; and a histogram config needs
        an exactly matching bin count, since bins can't be changed after generation.
        The base class carries no options, so containment always holds.
        """
        return True

    @property
    def help_dict(self) -> dict[str, str]:
        """
        Mapping from each argument name to a description of what it does, used when
        printing help for a config object. Empty for the option-less base class.
        """
        return {}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass
class PromptConfig(BaseComponentConfig):
    """Config for the prompt component; it has no options of its own, so it inherits
    the base class's trivial containment check and empty help dict."""

    pass
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
class SequencesConfig(BaseComponentConfig):
    """
    Config for the sequence groups in a feature dashboard (the top-activations group
    plus `n_quantiles` quantile groups). See SEQUENCES_CONFIG_HELP for per-field
    descriptions.
    """

    # (left, right) context token counts around each bold token; None = whole sequence.
    buffer: tuple[int, int] | None = (5, 5)
    # If False, skip computing loss effects / activations for non-bold tokens.
    compute_buffer: bool = True
    # Number of quantile groups; 0 means only the top-activations group is shown.
    n_quantiles: int = 10
    # Number of sequences in the 'top activating sequences' group.
    top_acts_group_size: int = 20
    # Number of sequences in each quantile group.
    quantile_group_size: int = 5
    # Number of top/bottom logits in each token's hoverdata.
    top_logits_hoverdata: int = 5
    # How groups are laid out into columns (see SEQUENCES_CONFIG_HELP["stack_mode"]).
    stack_mode: Literal["stack-all", "stack-quantiles", "stack-none"] = "stack-all"
    # Whether the token hover info appears below (True) or above the token.
    hover_below: bool = True

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        # NOTE(review): when self.buffer is None but other.buffer is set, the first two
        # checks pass vacuously and containment is still reported — confirm that's
        # intended, since a full-sequence view appears to need more context than a
        # buffered computation provides.
        assert isinstance(other, self.__class__)
        return all(
            [
                self.buffer is None
                or (
                    other.buffer is not None and self.buffer[0] <= other.buffer[0]
                ),  # the buffer needs to be <=
                self.buffer is None
                or (other.buffer is not None and self.buffer[1] <= other.buffer[1]),
                int(self.compute_buffer)
                <= int(
                    other.compute_buffer
                ),  # we can't compute the buffer if we didn't in `other`
                self.n_quantiles
                in {
                    0,
                    other.n_quantiles,
                },  # we actually need the quantiles identical (or one to be zero)
                self.top_acts_group_size
                <= other.top_acts_group_size,  # group size needs to be <=
                self.quantile_group_size
                <= other.quantile_group_size,  # each quantile group needs to be <=
                self.top_logits_hoverdata
                <= other.top_logits_hoverdata,  # hoverdata rows need to be <=
            ]
        )

    def __post_init__(self) -> None:
        # Derived list of group lengths: the top-acts group followed by one entry per
        # quantile group (length is 1 + n_quantiles).
        self.group_sizes = [self.top_acts_group_size] + [
            self.quantile_group_size
        ] * self.n_quantiles

    @property
    def help_dict(self) -> dict[str, str]:
        return SEQUENCES_CONFIG_HELP
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@dataclass
class ActsHistogramConfig(BaseComponentConfig):
    """Config for the feature-activations histogram; only the bin count is tunable."""

    n_bins: int = 50

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        assert isinstance(other, self.__class__)
        # Bins are fixed once the histogram is generated, so only an exact match works.
        return other.n_bins == self.n_bins

    @property
    def help_dict(self) -> dict[str, str]:
        return ACTIVATIONS_HISTOGRAM_CONFIG_HELP
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass
class LogitsHistogramConfig(BaseComponentConfig):
    """Config for the logits histogram; only the bin count is tunable."""

    # Number of bins for the histogram.
    n_bins: int = 50

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        # Bin count must match exactly: bins can't be changed after generation.
        assert isinstance(other, self.__class__)
        return self.n_bins == other.n_bins

    @property
    def help_dict(self) -> dict[str, str]:
        return LOGITS_HISTOGRAM_CONFIG_HELP
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@dataclass
class LogitsTableConfig(BaseComponentConfig):
    """Config for the top/bottom logits table."""

    n_rows: int = 10

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        assert isinstance(other, self.__class__)
        # A larger precomputed table can always be truncated to fewer rows.
        return other.n_rows >= self.n_rows

    @property
    def help_dict(self) -> dict[str, str]:
        return LOGITS_TABLE_CONFIG_HELP
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@dataclass
class FeatureTablesConfig(BaseComponentConfig):
    """Config for the left-hand feature tables: row count plus per-table on/off flags."""

    n_rows: int = 3
    neuron_alignment_table: bool = True
    correlated_neurons_table: bool = True
    correlated_features_table: bool = True
    correlated_b_features_table: bool = False

    def data_is_contained_in(self, other: BaseComponentConfig) -> bool:
        assert isinstance(other, self.__class__)
        # Row count must not exceed `other`'s, and no table may be enabled here that
        # was disabled there (False <= True, so bool comparison encodes implication).
        return (
            self.n_rows <= other.n_rows
            and self.neuron_alignment_table <= other.neuron_alignment_table
            and self.correlated_neurons_table <= other.correlated_neurons_table
            and self.correlated_features_table <= other.correlated_features_table
            and self.correlated_b_features_table <= other.correlated_b_features_table
        )

    @property
    def help_dict(self) -> dict[str, str]:
        return FEATURE_TABLES_CONFIG_HELP
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# Union of every component config class; a dashboard `Column` holds a list of these.
GenericComponentConfig = (
    PromptConfig
    | SequencesConfig
    | ActsHistogramConfig
    | LogitsHistogramConfig
    | LogitsTableConfig
    | FeatureTablesConfig
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class Column:
    """An ordered collection of component configs making up one dashboard column,
    with an optional explicit width."""

    def __init__(
        self,
        *args: GenericComponentConfig,
        width: int | None = None,
    ):
        self.components = [*args]
        self.width = width

    def __iter__(self) -> Iterator[Any]:
        return iter(self.components)

    def __getitem__(self, idx: int) -> Any:
        return self.components[idx]

    def __len__(self) -> int:
        return len(self.components)
|
SAEDashboard/sae_dashboard/css/dropdown.css
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Styling of the dropdowns */
/* Native <select>: strip default chrome so the .select wrapper controls the look */
select {
    appearance: none;
    border: 0;
    flex: 1;
    padding: 0 1em;
    background-color: #eee;
    cursor: pointer;
}
/* Wrapper around each <select>: fixed-size rounded box with a drop shadow */
.select {
    box-shadow: 0 5px 5px rgba(0, 0, 0, 0.25);
    cursor: pointer;
    display: flex;
    width: 100px;
    height: 25px;
    border-radius: .25em;
    overflow: hidden;
    position: relative;
    margin-right: 15px;
}
/* Custom arrow glyph (▼) drawn in the wrapper's top-right corner */
.select::after {
    position: absolute;
    content: '\25BC';
    font-size: 9px;
    top: 0;
    right: 0;
    padding: 1em;
    background-color: #ddd;
    transition: .25s all ease;
    pointer-events: none;
}
/* Darken the arrow when the wrapper is hovered */
.select:hover::after {
    color: black;
}
/* Flex container holding all dropdowns, wrapping onto new rows as needed */
#dropdown-container {
    margin-left: 10px;
    margin-top: 20px;
    display: flex;
    flex-wrap: wrap;
}
|
SAEDashboard/sae_dashboard/css/general.css
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Styling of the top-level container */
.grid-container {
    font-family: 'system-ui';
    border: 1px solid #e6e6e6;
    background-color: #fff;
    margin: 30px 10px;
    box-shadow: 0 5px 5px rgba(0, 0, 0, 0.25);
    display: grid;
    justify-content: start;
    grid-template-columns: auto;
    overflow-x: auto;
    overflow-y: visible;
    grid-auto-flow: column;
    white-space: nowrap;
    padding-bottom: 12px;
    padding-top: 35px;
    padding-left: 20px;
}
/* Styling each grid column (note, the max-height controls height of grid-container) */
.grid-column {
    margin-left: 20px;
    padding-right: 20px;
    width: max-content;
    overflow-y: auto;
    max-height: 750px;
}
/* Styling the scrollbars (::-webkit-scrollbar rules only apply in WebKit/Blink browsers) */
::-webkit-scrollbar {
    height: 10px;
    width: 10px;
}
::-webkit-scrollbar-track {
    background: #f1f1f1;
}
::-webkit-scrollbar-thumb {
    background: #999;
}
::-webkit-scrollbar-thumb:hover {
    background: #555;
}
/* Margin at the bottom of each histogram */
.plotly-hist {
    margin-bottom: 25px;
}
/* Margins below the titles (most subtitles are h4, except for the prompt-centric view which has h2 titles) */
h4 {
    margin-top: 0px;
    margin-bottom: 10px;
}
/* Some space below the <hr> line in prompt-centric vis */
hr {
    margin-bottom: 35px;
}
|
SAEDashboard/sae_dashboard/css/sequences.css
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Default font & appearance for the words in the sequence, before being hovered over */
code {
    font-family: Consolas, Menlo, Monaco;
}
/* Margin at the bottom of every sequence group, plus handle how overflow works (maybe not necessary) */
.seq-group {
    overflow-x: auto;
    overflow-y: visible;
    padding-top: 5px;
    padding-bottom: 10px;
    margin-bottom: 10px;
}
/* Margin between single sequences */
.seq {
    margin-bottom: 11px;
}
/* Styling for each token in a sequence (top corners rounded so the hover border reads as a tab) */
.token {
    font-family: Consolas, Menlo, Monaco;
    font-size: 0.9em;
    border-top-left-radius: 3px;
    border-top-right-radius: 3px;
    padding: 1px;
    color: black;
    display: inline;
    white-space: pre-wrap;
}
/* All the messy hovering stuff! */
/* Element that owns the tooltip; inline-block so the tooltip positions relative to it */
.hover-text {
    position: relative;
    cursor: pointer;
    display: inline-block; /* Needed to contain the tooltip */
    box-sizing: border-box;
}
/* The tooltip card itself; hidden until its .hover-text sibling is hovered (rule at bottom) */
.tooltip {
    background-color: #fff;
    color: #333;
    text-align: center;
    border-radius: 10px;
    padding: 5px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.25);
    align-items: center;
    justify-content: center;
    overflow: hidden;
    font-family: 'system-ui';
    font-size: 1.1em;
    display: none;
    position: fixed;
    z-index: 1000;
}
/* Mark the hovered token with a black top border */
.token:hover {
    border-top: 3px solid black;
}
/* Positioning shim for the tooltip; ignores the mouse so hover stays on the token */
.tooltip-container {
    position: absolute;
    pointer-events: none;
}
/* Reveal the tooltip when the adjacent .hover-text is hovered */
.hover-text:hover + .tooltip-container .tooltip {
    display: block;
}
|
| 61 |
+
|
SAEDashboard/sae_dashboard/css/tables.css
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Base table style: centered, compact, borderless */
table {
    border: unset;
    color: black;
    border-collapse: collapse;
    width: -moz-fit-content;
    width: -webkit-fit-content;
    width: fit-content;
    margin-left: auto;
    margin-right: auto;
    font-size: 0.8em;
}
/* Left-hand (feature) tables: full width with light row separators */
table.table-left tr {
    border-bottom: 1px solid #eee;
    padding: 15px;
}
table.table-left td {
    padding: 3px 4px;
}
table.table-left {
    width: 100%;
}
/* Clip overly-long content in left-aligned cells */
table.table-left td.left-aligned {
    max-width: 120px;
    overflow-x: hidden;
}
td {
    border: none;
    padding: 2px 4px;
    white-space: nowrap;
}
/* Cell alignment helpers */
.right-aligned {
    text-align: right;
}
.left-aligned {
    text-align: left;
}
.center-aligned {
    text-align: center;
    padding-bottom: 8px;
}
/* Inline code inside tables gets a grey "chip" background */
table code {
    background-color: #ddd;
    padding: 2px;
    border-radius: 3px;
}
.table-container {
    width: 100%;
}
/* Two tables displayed side by side */
.half-width-container {
    display: flex;
}
.half-width {
    width: 50%;
    margin-right: -4px;
}

/* Feature tables should have space below them, also they should have a min column width */
div.feature-tables table {
    margin-bottom: 25px;
    min-width: 250px;
}
/* Configure logits table container (i.e. the thing containing the smaller and larger tables) */
div.logits-table {
    min-width: 375px;
    display: flex;
    overflow-x: hidden;
    margin-bottom: 20px;
}
/* Code is always bold in this table (this is just the neg/pos string tokens) */
div.logits-table code {
    font-weight: bold;
}
/* Set width of the tables inside the container (so they can stack horizontally), also put a gap between them */
div.logits-table > div.positive {
    width: 47%;
}
div.logits-table > div.negative {
    width: 47%;
    margin-right: 5%;
}
|
| 81 |
+
|
SAEDashboard/sae_dashboard/data_parsing_fns.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import einops
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from eindex import eindex
|
| 5 |
+
from jaxtyping import Float, Int
|
| 6 |
+
from sae_lens import SAE
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from transformer_lens import HookedTransformer, utils
|
| 9 |
+
|
| 10 |
+
from sae_dashboard.components import LogitsTableData, SequenceData
|
| 11 |
+
from sae_dashboard.sae_vis_data import SaeVisData
|
| 12 |
+
from sae_dashboard.transformer_lens_wrapper import (
|
| 13 |
+
ActivationConfig,
|
| 14 |
+
TransformerLensWrapper,
|
| 15 |
+
to_resid_direction,
|
| 16 |
+
)
|
| 17 |
+
from sae_dashboard.utils_fns import RollingCorrCoef, TopK
|
| 18 |
+
|
| 19 |
+
# Shorthand alias for numpy arrays, used in annotations below.
Arr = np.ndarray
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_features_table_data(
    feature_out_dir: Float[Tensor, "feats d_out"],
    n_rows: int,
    corrcoef_neurons: RollingCorrCoef | None = None,
    corrcoef_encoder: RollingCorrCoef | None = None,
) -> dict[str, list[list[int]] | list[list[float]]]:
    """
    Assemble all data for the left-hand column of the dashboard, i.e. the 3 tables.

    The neuron-alignment table (from decoder weights) is always computed; the two
    correlation tables are filled only when their rolling-correlation trackers are
    supplied. Results are gathered in one dict keyed by table/column name, which
    makes it easy to switch individual tables on and off.
    """
    tables: dict[str, list[list[int]] | list[list[float]]] = {}

    # Table 1: neuron alignment, based on decoder weights (always included).
    add_neuron_alignment_data(
        feature_out_dir=feature_out_dir,
        feature_tables_data=tables,
        n_rows=n_rows,
    )

    # Table 2: neurons correlated with this feature, based on their activations.
    if corrcoef_neurons is not None:
        add_feature_neuron_correlations(
            corrcoef_neurons=corrcoef_neurons,
            feature_tables_data=tables,
            n_rows=n_rows,
        )

    # Table 3: primary encoder features correlated with this feature.
    if corrcoef_encoder is not None:
        add_intra_encoder_correlations(
            corrcoef_encoder=corrcoef_encoder,
            feature_tables_data=tables,
            n_rows=n_rows,
        )

    return tables
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def add_intra_encoder_correlations(
    corrcoef_encoder: RollingCorrCoef,
    feature_tables_data: dict[str, list[list[int]] | list[list[float]]],
    n_rows: int,
):
    """Add the top-`n_rows` intra-encoder correlations (indices, Pearson coefficients,
    cosine similarities) to `feature_tables_data` in place."""
    indices, pearson, cossim = corrcoef_encoder.topk_pearson(
        k=n_rows,
    )
    feature_tables_data["correlated_features_indices"] = indices
    feature_tables_data["correlated_features_pearson"] = pearson
    feature_tables_data["correlated_features_cossim"] = cossim
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def add_neuron_alignment_data(
    feature_out_dir: Float[Tensor, "feats d_out"],
    feature_tables_data: dict[str, list[list[int]] | list[list[float]]],
    n_rows: int,
):
    """Add the neuron-alignment table to `feature_tables_data` in place: the
    `n_rows` output neurons with the largest decoder weights, those weight values,
    and each value's share of the feature's total L1 norm."""
    top_aligned = TopK(tensor=feature_out_dir.float(), k=n_rows, largest=True)
    l1_norm = feature_out_dir.abs().sum(dim=-1, keepdim=True)
    # Fraction of the feature's L1 norm accounted for by each top weight.
    pct_of_l1: Arr = np.absolute(top_aligned.values) / utils.to_numpy(l1_norm.float())
    feature_tables_data["neuron_alignment_indices"] = top_aligned.indices.tolist()
    feature_tables_data["neuron_alignment_values"] = top_aligned.values.tolist()
    feature_tables_data["neuron_alignment_l1"] = pct_of_l1.tolist()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def add_feature_neuron_correlations(
    corrcoef_neurons: RollingCorrCoef,
    feature_tables_data: dict[str, list[list[int]] | list[list[float]]],
    n_rows: int,
):
    """Add the top-`n_rows` feature-neuron correlations (indices, Pearson
    coefficients, cosine similarities) to `feature_tables_data` in place."""
    indices, pearson, cossim = corrcoef_neurons.topk_pearson(
        k=n_rows,
    )

    feature_tables_data["correlated_neurons_indices"] = indices
    feature_tables_data["correlated_neurons_pearson"] = pearson
    feature_tables_data["correlated_neurons_cossim"] = cossim
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def get_logits_table_data(
    logit_vector: Float[Tensor, "d_vocab"], n_rows: int  # noqa: F821
):
    """
    Compute the data for the top/bottom logits table.

    Returns a LogitsTableData holding the `n_rows` largest and smallest logit values
    together with their token ids, as plain Python lists.
    """
    logits = logit_vector.float()
    largest = TopK(logits, k=n_rows, largest=True)
    smallest = TopK(logits, k=n_rows, largest=False)

    return LogitsTableData(
        bottom_logits=smallest.values.tolist(),
        bottom_token_ids=smallest.indices.tolist(),
        top_logits=largest.values.tolist(),
        top_token_ids=largest.indices.tolist(),
    )
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# @torch.inference_mode()
|
| 130 |
+
# def get_feature_data(
|
| 131 |
+
# encoder: AutoEncoder,
|
| 132 |
+
# model: HookedTransformer,
|
| 133 |
+
# tokens: Int[Tensor, "batch seq"],
|
| 134 |
+
# cfg: SaeVisConfig,
|
| 135 |
+
# ) -> SaeVisData:
|
| 136 |
+
# """
|
| 137 |
+
# This is the main function which users will run to generate the feature visualization data. It batches this
|
| 138 |
+
# computation over features, in accordance with the arguments in the SaeVisConfig object (we don't want to compute all
|
| 139 |
+
# the features at once, since might give OOMs).
|
| 140 |
+
|
| 141 |
+
# See the `_get_feature_data` function for an explanation of the arguments, as well as a more detailed explanation
|
| 142 |
+
# of what this function is doing.
|
| 143 |
+
|
| 144 |
+
# The return object is the merged SaeVisData objects returned by the `_get_feature_data` function.
|
| 145 |
+
# """
|
| 146 |
+
# pass
|
| 147 |
+
|
| 148 |
+
# # return sae_vis_data
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@torch.inference_mode()
def parse_prompt_data(
    tokens: Int[Tensor, "batch seq"],
    str_toks: list[str],
    sae_vis_data: SaeVisData,
    feat_acts: Float[Tensor, "seq feats"],
    feature_resid_dir: Float[Tensor, "feats d_model"],
    resid_post: Float[Tensor, "seq d_model"],
    W_U: Float[Tensor, "d_model d_vocab"],
    feature_idx: list[int] | None = None,
    num_top_features: int = 10,
) -> dict[str, tuple[list[int], list[str]]]:
    """
    Gets data needed to create the sequences in the prompt-centric vis (displaying dashboards for the most relevant
    features on a prompt).

    This function exists so that prompt dashboards can be generated without using our AutoEncoder or
    TransformerLens(Wrapper) classes.

    Args:
        tokens: Int[Tensor, "batch seq"]
            The tokens we'll be using to get the feature activations. Note that we might not be using all of them; the
            number used is determined by `fvp.total_batch_size`.

        str_toks: list[str]
            The tokens as a list of strings, so that they can be visualized in HTML.

        sae_vis_data: SaeVisData
            The object storing all data for each feature. We'll set each `feature_data.prompt_data` to the
            data we get from `prompt`.

        feat_acts: Float[Tensor, "seq feats"]
            The activations values of the features across the sequence.

        feature_resid_dir: Float[Tensor, "feats d_model"]
            The directions that each feature writes to the residual stream.

        resid_post: Float[Tensor, "seq d_model"]
            The activations of the final layer of the model before the unembed.

        W_U: Float[Tensor, "d_model d_vocab"]
            The model's unembed weights for the logit lens.

        feature_idx: list[int] or None
            The features we're actually computing. These might just be a subset of the model's full features.

        num_top_features: int
            The number of top features to display in this view, for any given metric.

    Returns:
        scores_dict: dict[str, tuple[list[int], list[str]]]
            A dictionary mapping keys like "act_quantile|'django' (0)" to a tuple of lists, where the first list is the
            feature indices, and the second list is the string-formatted values of the scores.

    As well as returning this dictionary, this function will also set `FeatureData.prompt_data` for each feature in
    `sae_vis_data` (this is necessary for getting the prompts in the prompt-centric vis). Note this design choice could
    have been done differently (i.e. have this function return a list of the prompt data for each feature). I chose this
    way because it means the FeatureData._get_html_data_prompt_centric can work fundamentally the same way as
    FeatureData._get_html_data_feature_centric, rather than treating the prompt data object as a different kind of
    component in the vis.
    """

    device = sae_vis_data.cfg.device

    # Default to every feature stored in the vis data.
    if feature_idx is None:
        feature_idx = list(sae_vis_data.feature_data_dict.keys())
    n_feats = len(feature_idx)
    assert (
        feature_resid_dir.shape[0] == n_feats
    ), f"The number of features in feature_resid_dir ({feature_resid_dir.shape[0]}) does not match the number of feature indices ({n_feats})"

    assert (
        feat_acts.shape[1] == n_feats
    ), f"The number of features in feat_acts ({feat_acts.shape[1]}) does not match the number of feature indices ({n_feats})"

    # Per-feature, per-position loss deltas; position axis has length seq-1
    # because the last token has no next-token prediction to score.
    feats_loss_contribution = torch.empty(
        size=(n_feats, tokens.shape[1] - 1), device=device
    )
    # Some logit computations which we only need to do once
    # correct_token_unembeddings = model_wrapped.W_U[:, tokens[0, 1:]] # [d_model seq]
    # NOTE(review): dividing by std approximates the final LayerNorm before the
    # unembed (scale-only, no mean subtraction or learned gain) — confirm this
    # matches the model's actual final norm.
    orig_logits = (
        resid_post / resid_post.std(dim=-1, keepdim=True)
    ) @ W_U  # [seq d_vocab]
    raw_logits = feature_resid_dir @ W_U  # [feats d_vocab]

    for i, feat in enumerate(feature_idx):
        # ! Calculate the sequence data for each feature, and store it as FeatureData.prompt_data

        # Get this feature's output vector, using an outer product over the feature activations for all tokens
        resid_post_feature_effect = einops.einsum(
            feat_acts[:, i], feature_resid_dir[i], "seq, d_model -> seq d_model"
        )

        # Ablate the output vector from the residual stream, and get logits post-ablation
        new_resid_post = resid_post - resid_post_feature_effect
        new_logits = (new_resid_post / new_resid_post.std(dim=-1, keepdim=True)) @ W_U

        # Get the top5 & bottom5 changes in logits (don't bother with `efficient_topk` cause it's small)
        contribution_to_logprobs = orig_logits.log_softmax(
            dim=-1
        ) - new_logits.log_softmax(dim=-1)
        top_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5)
        bottom_contribution_to_logits = TopK(
            contribution_to_logprobs[:-1], k=5, largest=False
        )

        # Get the change in loss (which is negative of change of logprobs for correct token)
        loss_contribution = eindex(
            -contribution_to_logprobs[:-1], tokens[0, 1:], "seq [seq]"
        )
        feats_loss_contribution[i, :] = loss_contribution

        # Store the sequence data
        sae_vis_data.feature_data_dict[feat].prompt_data = SequenceData(
            token_ids=tokens.squeeze(0).tolist(),
            feat_acts=[round(f, 4) for f in feat_acts[:, i].tolist()],
            # First position gets 0.0 since there is no prediction for it.
            loss_contribution=[0.0] + loss_contribution.tolist(),
            token_logits=raw_logits[i, tokens.squeeze(0)].tolist(),
            top_token_ids=top_contribution_to_logits.indices.tolist(),
            top_logits=top_contribution_to_logits.values.tolist(),
            bottom_token_ids=bottom_contribution_to_logits.indices.tolist(),
            bottom_logits=bottom_contribution_to_logits.values.tolist(),
        )

    # ! Lastly, return a dictionary mapping each key like 'act_quantile|"django" (0)' to a list of feature indices & scores

    # Get a dict with keys like f"act_quantile|'My' (1)" and values (feature indices list, feature score values list)
    scores_dict: dict[str, tuple[list[int], list[str]]] = {}

    for seq_pos, seq_key in enumerate([f"{t!r} ({i})" for i, t in enumerate(str_toks)]):
        # Filter the feature activations, since we only need the ones that are non-zero
        feat_acts_nonzero_filter = utils.to_numpy(feat_acts[seq_pos] > 0)
        feat_acts_nonzero_locations = np.nonzero(feat_acts_nonzero_filter)[0].tolist()
        _feat_acts = feat_acts[seq_pos, feat_acts_nonzero_filter]  # [feats_filtered,]
        _feature_idx = np.array(feature_idx)[feat_acts_nonzero_filter]

        if feat_acts_nonzero_filter.sum() > 0:
            k = min(num_top_features, _feat_acts.numel())

            # Get the top features by activation size. This is just applying a TopK function to the feat acts (which
            # were stored by the code before this). The feat acts are formatted to 3dp.
            act_size_topk = TopK(_feat_acts, k=k, largest=True)
            top_features = _feature_idx[act_size_topk.indices].tolist()
            formatted_scores = [f"{v:.3f}" for v in act_size_topk.values]
            scores_dict[f"act_size|{seq_key}"] = (top_features, formatted_scores)

            # Get the top features by activation quantile. We do this using the `feature_act_quantiles` object, which
            # was stored `sae_vis_data`. This quantiles object has a method to return quantiles for a given set of
            # data, as well as the precision (we make the precision higher for quantiles closer to 100%, because these
            # are usually the quantiles we're interested in, and it lets us to save space in `feature_act_quantiles`).
            act_quantile, act_precision = sae_vis_data.feature_stats.get_quantile(
                _feat_acts, feat_acts_nonzero_locations
            )
            act_quantile_topk = TopK(act_quantile, k=k, largest=True)
            act_formatting = [
                f".{act_precision[i]-2}%" for i in act_quantile_topk.indices
            ]
            top_features = _feature_idx[act_quantile_topk.indices].tolist()
            formatted_scores = [
                f"{v:{f}}" for v, f in zip(act_quantile_topk.values, act_formatting)
            ]
            scores_dict[f"act_quantile|{seq_key}"] = (top_features, formatted_scores)

        # We don't measure loss effect on the first token
        if seq_pos == 0:
            continue

        # Filter the loss effects, since we only need the ones which have non-zero feature acts on the tokens before them
        prev_feat_acts_nonzero_filter = utils.to_numpy(feat_acts[seq_pos - 1] > 0)
        _loss_contribution = feats_loss_contribution[
            prev_feat_acts_nonzero_filter, seq_pos - 1
        ]  # [feats_filtered,]
        _feature_idx_prev = np.array(feature_idx)[prev_feat_acts_nonzero_filter]

        if prev_feat_acts_nonzero_filter.sum() > 0:
            k = min(num_top_features, _loss_contribution.numel())

            # Get the top features by loss effect. This is just applying a TopK function to the loss effects (which were
            # stored by the code before this). The loss effects are formatted to 3dp. We look for the most negative
            # values, i.e. the most loss-reducing features.
            loss_contribution_topk = TopK(_loss_contribution, k=k, largest=False)
            top_features = _feature_idx_prev[loss_contribution_topk.indices].tolist()
            formatted_scores = [f"{v:+.3f}" for v in loss_contribution_topk.values]
            scores_dict[f"loss_effect|{seq_key}"] = (top_features, formatted_scores)
    return scores_dict
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
@torch.inference_mode()
def get_prompt_data(
    sae_vis_data: SaeVisData,
    prompt: str,
    num_top_features: int,
) -> dict[str, tuple[list[int], list[str]]]:
    """
    Gets data that will be used to create the sequences in the prompt-centric HTML visualisation, i.e. an object of
    type SequenceData for each of our features.

    Args:
        sae_vis_data: The object storing all data for each feature. We'll set each `feature_data.prompt_data` to the
            data we get from `prompt`.
        prompt: The prompt we'll be using to get the feature activations.
        num_top_features: The number of top features we'll be getting data for.

    Returns:
        scores_dict: A dictionary mapping keys like "act_quantile|0" to a tuple of lists, where the first list is
            the feature indices, and the second list is the string-formatted values of the scores.

    As well as returning this dictionary, this function will also set `FeatureData.prompt_data` for each feature in
    `sae_vis_data`. This is because the prompt-centric vis will call `FeatureData._get_html_data_prompt_centric` on each
    feature data object, so it's useful to have all the data in one place! Even if this will get overwritten next
    time we call `get_prompt_data` for this same `sae_vis_data` object.
    """

    # ! Boring setup code
    feature_idx = list(sae_vis_data.feature_data_dict.keys())
    encoder = sae_vis_data.encoder
    assert isinstance(encoder, SAE)
    model = sae_vis_data.model
    assert isinstance(model, HookedTransformer)
    cfg = sae_vis_data.cfg
    assert isinstance(cfg.hook_point, str), f"{cfg.hook_point=}, expected a string"

    str_toks: list[str] = model.tokenizer.tokenize(prompt)  # type: ignore
    tokens = model.tokenizer.encode(prompt, return_tensors="pt").to(  # type: ignore
        sae_vis_data.cfg.device
    )
    assert isinstance(tokens, torch.Tensor)

    # Wrap the model so we can fetch activations at cfg.hook_point only.
    model_wrapped = TransformerLensWrapper(model, ActivationConfig(cfg.hook_point, []))  # type: ignore

    feature_act_dir = encoder.W_enc[:, feature_idx]  # [d_in feats]
    feature_out_dir = encoder.W_dec[feature_idx]  # [feats d_in]
    feature_resid_dir = to_resid_direction(
        feature_out_dir, model_wrapped
    )  # [feats d_model]
    # Sanity-check encoder/decoder weights agree on (n_features, d_in).
    assert (
        feature_act_dir.T.shape
        == feature_out_dir.shape
        == (len(feature_idx), encoder.cfg.d_in)
    )

    # ! Define hook functions to cache all the info required for feature ablation, then run those hook fns
    # NOTE(review): assumes the wrapper returns (residual stream, hook-point
    # activations) in that order when return_logits=False — verify against
    # TransformerLensWrapper.forward.
    resid_post, act_post = model_wrapped(tokens, return_logits=False)
    resid_post: Tensor = resid_post.squeeze(0)  # type: ignore
    feat_acts = encoder.get_feature_acts_subset(act_post, feature_idx).squeeze(  # type: ignore
        0
    )  # [seq feats] # type: ignore

    # ! Use the data we've collected to make the scores_dict and update the sae_vis_data
    scores_dict = parse_prompt_data(
        tokens=tokens,
        str_toks=str_toks,
        sae_vis_data=sae_vis_data,
        feat_acts=feat_acts,
        feature_resid_dir=feature_resid_dir,
        resid_post=resid_post,
        W_U=model.W_U,
        feature_idx=feature_idx,
        num_top_features=num_top_features,
    )

    return scores_dict
|
SAEDashboard/sae_dashboard/data_writing_fns.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
|
| 7 |
+
from sae_dashboard.data_parsing_fns import get_prompt_data
|
| 8 |
+
from sae_dashboard.html_fns import HTML
|
| 9 |
+
from sae_dashboard.sae_vis_data import SaeVisData
|
| 10 |
+
from sae_dashboard.utils_fns import get_decode_html_safe_fn
|
| 11 |
+
|
| 12 |
+
# Human-readable column titles for each scoring-metric key used in the
# prompt-centric visualisation (keys match those built by `save_prompt_centric_vis`).
METRIC_TITLES = {
    "act_size": "Activation Size",
    "act_quantile": "Activation Quantile",
    "loss_effect": "Loss Effect",
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def save_feature_centric_vis(
    sae_vis_data: SaeVisData,
    filename: str | Path,
    feature_idx: int | None = None,
    include_only: list[int] | None = None,
    separate_files: bool = False,
) -> None:
    """
    Saves the feature-centric visualisation (the view which lets you navigate between different features) to HTML.

    Args:
        sae_vis_data: Object containing visualization data.
        filename: The HTML filepath we'll save the visualization to. If separate_files is True, this is used as a base name.
        feature_idx: This is the default feature index we'll start on. If None, we use the first feature.
        include_only: Optional list of specific features to include.
        separate_files: If True, saves each feature to a separate HTML file.
    """
    # Set the default argument for the dropdown (i.e. when the page first loads)
    first_feature = (
        next(iter(sae_vis_data.feature_data_dict))
        if (feature_idx is None)
        else feature_idx
    )

    # Get tokenize function (we only need to define it once)
    assert sae_vis_data.model is not None
    assert sae_vis_data.model.tokenizer is not None
    decode_fn = get_decode_html_safe_fn(sae_vis_data.model.tokenizer)

    # Create iterator
    if include_only is not None:
        iterator = [(i, sae_vis_data.feature_data_dict[i]) for i in include_only]
    else:
        iterator = list(sae_vis_data.feature_data_dict.items())
    if sae_vis_data.cfg.verbose:
        iterator = tqdm(iterator, desc="Saving feature-centric vis")

    HTML_OBJ = HTML()  # Initialize HTML object for combined file

    # For each FeatureData object, we get the html_obj for its feature-centric vis
    for feature, feature_data in iterator:
        html_obj = feature_data._get_html_data_feature_centric(
            sae_vis_data.cfg.feature_centric_layout, decode_fn
        )

        if separate_files:
            feature_HTML_OBJ = HTML()  # Initialize a new HTML object for each feature
            # deepcopy so later mutation of html_obj can't leak into the saved file
            feature_HTML_OBJ.js_data[str(feature)] = deepcopy(html_obj.js_data)
            feature_HTML_OBJ.html_data = deepcopy(html_obj.html_data)

            # Add the aggdata
            feature_HTML_OBJ.js_data = {
                "AGGDATA": sae_vis_data.feature_stats.aggdata,
                "DASHBOARD_DATA": feature_HTML_OBJ.js_data,
            }

            # Generate filename for this feature
            feature_filename = Path(filename).with_stem(
                f"{Path(filename).stem}_feature_{feature}"
            )

            # Save the HTML for this feature
            feature_HTML_OBJ.get_html(
                layout_columns=sae_vis_data.cfg.feature_centric_layout.columns,
                layout_height=sae_vis_data.cfg.feature_centric_layout.height,
                filename=feature_filename,
                first_key=str(feature),
            )
        else:
            # Original behavior: accumulate all features in one HTML object
            HTML_OBJ.js_data[str(feature)] = deepcopy(html_obj.js_data)
            # Only the default feature's HTML skeleton is kept; the rest is
            # swapped in client-side from js_data.
            if feature == first_feature:
                HTML_OBJ.html_data = deepcopy(html_obj.html_data)

    if not separate_files:
        # Add the aggdata
        HTML_OBJ.js_data = {
            "AGGDATA": sae_vis_data.feature_stats.aggdata,
            "DASHBOARD_DATA": HTML_OBJ.js_data,
        }

        # Save our full HTML
        HTML_OBJ.get_html(
            layout_columns=sae_vis_data.cfg.feature_centric_layout.columns,
            layout_height=sae_vis_data.cfg.feature_centric_layout.height,
            filename=filename,
            first_key=str(first_feature),
        )
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def save_prompt_centric_vis(
    sae_vis_data: SaeVisData,
    prompt: str,
    filename: str | Path,
    metric: str | None = None,
    seq_pos: int | None = None,
    num_top_features: int = 10,
):
    """
    Saves the prompt-centric visualisation (the view which lets you navigate between (metric, seq_pos) keys) to HTML.

    Args:
        sae_vis_data: Object containing visualization data (must hold a model with a tokenizer).
        prompt: The user-input prompt.
        filename: The HTML filepath we'll save the visualization to.
        metric: This is the default scoring metric we'll start on. If None, we use 'act_quantile'.
        seq_pos: This is the default seq pos we'll start on. If None, we use 0.
        num_top_features: The number of top features shown for each (metric, seq_pos) key.
    """
    # Initialize the object we'll eventually get_html from
    HTML_OBJ = HTML()

    # Run forward passes on our prompt, and store the data within each FeatureData object as `self.prompt_data` as
    # well as returning the scores_dict (which maps from score hash to a list of feature indices & formatted scores)
    scores_dict = get_prompt_data(
        sae_vis_data=sae_vis_data,
        prompt=prompt,
        num_top_features=num_top_features,
    )

    # Get all possible values for dropdowns
    str_toks = sae_vis_data.model.tokenizer.tokenize(prompt)  # type: ignore
    str_toks = [
        t.replace("|", "│") for t in str_toks
    ]  # vertical line -> pipe (hacky, so key splitting on | works)
    str_toks_list = [f"{t!r} ({i})" for i, t in enumerate(str_toks)]
    metric_list = ["act_quantile", "act_size", "loss_effect"]

    # Get default values for dropdowns
    # BUGFIX: this was previously `"act_quantile" or metric`, which always
    # short-circuits to "act_quantile" and silently ignored the caller's
    # `metric` argument. The documented behavior (default only when None) is
    # `metric or "act_quantile"`.
    first_metric = metric or "act_quantile"
    first_seq_pos = str_toks_list[0 if seq_pos is None else seq_pos]
    first_key = f"{first_metric}|{first_seq_pos}"

    # Get tokenize function (we only need to define it once)
    assert sae_vis_data.model is not None
    assert sae_vis_data.model.tokenizer is not None
    decode_fn = get_decode_html_safe_fn(sae_vis_data.model.tokenizer)

    # For each (metric, seqpos) object, we merge the prompt-centric views of each of the top features, then we merge
    # these all together into our HTML_OBJ
    for _metric, _seq_pos in itertools.product(metric_list, range(len(str_toks))):
        # Create the key for this given combination of metric & seqpos, and get our top features & scores
        key = f"{_metric}|{str_toks_list[_seq_pos]}"
        if key not in scores_dict:
            continue
        feature_idx_list, scores_formatted = scores_dict[key]

        # Create HTML object, to store each feature column for all the top features for this particular key
        html_obj = HTML()

        for i, (feature_idx, score_formatted) in enumerate(
            zip(feature_idx_list, scores_formatted)
        ):
            # Get HTML object at this column (which includes JavaScript to dynamically set the title)
            html_obj += sae_vis_data.feature_data_dict[
                feature_idx
            ]._get_html_data_prompt_centric(
                layout=sae_vis_data.cfg.prompt_centric_layout,
                decode_fn=decode_fn,
                column_idx=i,
                bold_idx=_seq_pos,
                title=f"<h3>#{feature_idx}<br>{METRIC_TITLES[_metric]} = {score_formatted}</h3><hr>",
            )

        # Add the JavaScript (which includes the titles for each column)
        HTML_OBJ.js_data[key] = deepcopy(html_obj.js_data)

        # Set the HTML data to be the one with the most columns (since different options might have fewer cols)
        if len(HTML_OBJ.html_data) < len(html_obj.html_data):
            HTML_OBJ.html_data = deepcopy(html_obj.html_data)

    # Check our first key is in the scores_dict (if not, we should pick a different key)
    assert first_key in scores_dict, "\n".join(
        [
            f"Key {first_key} not found in {scores_dict.keys()=}.",
            "This means that there are no features with a nontrivial score for this choice of key & metric.",
        ]
    )

    # Add the aggdata
    HTML_OBJ.js_data = {
        "AGGDATA": sae_vis_data.feature_stats.aggdata,
        "DASHBOARD_DATA": HTML_OBJ.js_data,
    }

    # Save our full HTML
    HTML_OBJ.get_html(
        layout_columns=sae_vis_data.cfg.prompt_centric_layout.columns,
        layout_height=sae_vis_data.cfg.prompt_centric_layout.height,
        filename=filename,
        first_key=first_key,
    )
|
SAEDashboard/sae_dashboard/dfa_calculator.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Union
|
| 2 |
+
|
| 3 |
+
import einops
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from sae_lens import SAE
|
| 7 |
+
from transformer_lens import ActivationCache, HookedTransformer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DFACalculator:
    """Calculate DFA values for a given layer and set of feature indices."""

    def __init__(self, model: HookedTransformer, sae: SAE[Any]):
        """Store the model/SAE pair and detect grouped-query attention (GQA).

        GQA is assumed when the model declares fewer key/value heads than
        query heads; in that case the chunked GQA code path is used.
        """
        self.model = model
        self.sae = sae
        if (
            hasattr(model.cfg, "n_key_value_heads")
            and model.cfg.n_key_value_heads is not None
            and model.cfg.n_key_value_heads < model.cfg.n_heads
        ):
            print("Using GQA")
            self.use_gqa = True
        else:
            self.use_gqa = False

    def calculate(
        self,
        activations: Union[Dict[str, torch.Tensor], ActivationCache],
        layer_num: int,
        feature_indices: List[int],
        max_value_indices: torch.Tensor,
    ) -> Dict[int, Any]:  # type: ignore
        """Calculate DFA values for a given layer and set of feature indices.

        Args:
            activations: cache holding at least `blocks.{layer_num}.attn.hook_v`
                and `blocks.{layer_num}.attn.hook_pattern` for this layer.
            layer_num: which transformer block's attention to decompose.
            feature_indices: SAE feature indices to compute DFA for; empty
                list short-circuits to an empty dict.
            max_value_indices: per-(prompt, feature) destination positions at
                which to read the per-source attribution.
                # NOTE(review): assumed shape [n_prompts, n_features] — the
                # advanced indexing below relies on it; confirm at call site.

        Returns:
            Dict mapping each feature index to a structured numpy array of
            shape [n_prompts] with fields "dfa_values" (per-source-position
            attributions), "dfa_target_index" and "dfa_max_value".
        """
        if not feature_indices:
            return {}

        v = activations[f"blocks.{layer_num}.attn.hook_v"]
        attn_weights = activations[f"blocks.{layer_num}.attn.hook_pattern"]

        if self.use_gqa:
            per_src_pos_dfa = self.calculate_gqa_intermediate_tensor(
                attn_weights, v, feature_indices
            )
        else:
            per_src_pos_dfa = self.calculate_standard_intermediate_tensor(
                attn_weights, v, feature_indices
            )

        # per_src_pos_dfa: [n_prompts, dest_pos, src_pos, n_features]
        n_prompts, seq_len, _, n_features = per_src_pos_dfa.shape

        # Use advanced indexing to get per_src_dfa
        # (select, per prompt and feature, the dest_pos given by max_value_indices)
        prompt_indices = torch.arange(n_prompts)[:, None, None]
        src_pos_indices = torch.arange(seq_len)[None, :, None]
        feature_indices_tensor = torch.arange(n_features)[None, None, :]
        max_value_indices_expanded = max_value_indices[:, None, :]

        per_src_dfa = per_src_pos_dfa[
            prompt_indices,
            max_value_indices_expanded,
            src_pos_indices,
            feature_indices_tensor,
        ]

        # Max attribution over source positions, per (prompt, feature).
        max_values, _ = per_src_dfa.max(dim=1)

        # Create a structured numpy array to hold all the data
        dtype = np.dtype(
            [
                ("dfa_values", np.float32, (seq_len,)),
                ("dfa_target_index", np.int32),
                ("dfa_max_value", np.float32),
            ]
        )
        results = np.zeros((len(feature_indices), n_prompts), dtype=dtype)

        # Fill the numpy array with data
        # transpose/.T reorder axes from (prompts, ..., features) to (features, prompts)
        results["dfa_values"] = per_src_dfa.detach().cpu().numpy().transpose(2, 0, 1)
        results["dfa_target_index"] = max_value_indices.detach().cpu().numpy().T
        results["dfa_max_value"] = max_values.detach().cpu().numpy().T

        # Create a dictionary mapping feature indices to their respective data
        final_results = {
            feat_idx: results[i] for i, feat_idx in enumerate(feature_indices)
        }

        return final_results

    def calculate_standard_intermediate_tensor(
        self,
        attn_weights: torch.Tensor,  # [batch n_heads dest_pos src_pos]
        v: torch.Tensor,  # [batch src_pos n_heads d_head]
        feature_indices: List[int],
    ) -> torch.Tensor:
        """Per-source-position DFA for standard multi-head attention.

        Returns a tensor of shape [batch, dest_pos, src_pos, n_features]:
        each source position's (attention-weighted value) contribution,
        projected onto the selected SAE encoder directions.
        """
        # Concatenate heads: [batch src_pos (n_heads*d_head)]
        v_cat = einops.rearrange(
            v, "batch src_pos n_heads d_head -> batch src_pos (n_heads d_head)"
        )

        # Broadcast each head's attention weight across that head's d_head slice.
        attn_weights_bcast = einops.repeat(
            attn_weights,
            "batch n_heads dest_pos src_pos -> batch dest_pos src_pos (n_heads d_head)",
            d_head=self.model.cfg.d_head,
        )

        # z decomposed by source position (before summing over src_pos).
        decomposed_z_cat = attn_weights_bcast * v_cat.unsqueeze(1)

        W_enc_selected = self.sae.W_enc[:, feature_indices]  # [d_model, num_indices]

        per_src_pos_dfa = einops.einsum(
            decomposed_z_cat,
            W_enc_selected,
            "batch dest_pos src_pos d_model, d_model num_features -> batch dest_pos src_pos num_features",
        )

        return per_src_pos_dfa

    def calculate_gqa_intermediate_tensor(
        self, attn_weights: torch.Tensor, v: torch.Tensor, feature_indices: List[int]
    ) -> torch.Tensor:
        """Per-source-position DFA for grouped-query attention.

        Same output as the standard path ([batch, dest_pos, src_pos,
        n_features]), but first expands the shared K/V heads to match the
        query-head count, and processes destination positions in chunks to
        bound peak memory.
        """
        n_query_heads = attn_weights.shape[1]
        n_kv_heads = v.shape[2]
        # Duplicate each KV head so every query head has its own value slice.
        expansion_factor = n_query_heads // n_kv_heads
        v = v.repeat_interleave(expansion_factor, dim=2)

        v_cat = einops.rearrange(
            v, "batch src_pos n_heads d_head -> batch src_pos (n_heads d_head)"
        )

        W_enc_selected = self.sae.W_enc[:, feature_indices]  # [d_model, num_indices]

        # Initialize the result tensor
        n_prompts, seq_len, _ = v_cat.shape
        n_features = len(feature_indices)
        per_src_pos_dfa = torch.zeros(
            (n_prompts, seq_len, seq_len, n_features), device=v_cat.device
        )

        # Process in chunks
        chunk_size = 16  # Adjust this based on your memory constraints
        for i in range(0, seq_len, chunk_size):
            chunk_end = min(i + chunk_size, seq_len)

            # Process a chunk of destination positions
            attn_weights_chunk = attn_weights[:, :, i:chunk_end, :]
            attn_weights_bcast_chunk = einops.repeat(
                attn_weights_chunk,
                "batch n_heads dest_pos src_pos -> batch dest_pos src_pos (n_heads d_head)",
                d_head=self.model.cfg.d_head,
            )
            decomposed_z_cat_chunk = attn_weights_bcast_chunk * v_cat.unsqueeze(1)

            per_src_pos_dfa_chunk = einops.einsum(
                decomposed_z_cat_chunk,
                W_enc_selected,
                "batch dest_pos src_pos d_model, d_model num_features -> batch dest_pos src_pos num_features",
            )

            per_src_pos_dfa[:, i:chunk_end, :, :] = per_src_pos_dfa_chunk

        return per_src_pos_dfa
|
SAEDashboard/sae_dashboard/feature_data.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Any, Callable, List, Literal, Optional
|
| 3 |
+
|
| 4 |
+
from sae_dashboard.components import (
|
| 5 |
+
ActsHistogramData,
|
| 6 |
+
DecoderWeightsDistribution,
|
| 7 |
+
FeatureTablesData,
|
| 8 |
+
GenericData,
|
| 9 |
+
LogitsHistogramData,
|
| 10 |
+
LogitsTableData,
|
| 11 |
+
SequenceData,
|
| 12 |
+
SequenceMultiGroupData,
|
| 13 |
+
)
|
| 14 |
+
from sae_dashboard.components_config import (
|
| 15 |
+
ActsHistogramConfig,
|
| 16 |
+
FeatureTablesConfig,
|
| 17 |
+
GenericComponentConfig,
|
| 18 |
+
LogitsHistogramConfig,
|
| 19 |
+
LogitsTableConfig,
|
| 20 |
+
PromptConfig,
|
| 21 |
+
SequencesConfig,
|
| 22 |
+
)
|
| 23 |
+
from sae_dashboard.html_fns import HTML
|
| 24 |
+
from sae_dashboard.layout import SaeVisLayoutConfig
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class DFAData:
    """Direct feature attribution (DFA) results for one feature.

    Field names are camelCase rather than snake_case — presumably they mirror
    the JSON keys consumed by the front-end JavaScript; verify before renaming.
    """

    # Per-prompt lists of DFA values (one inner list per prompt/sequence).
    dfaValues: List[List[float]] = field(default_factory=list)
    # Per-prompt index of the target (max-activation) token the DFA is taken at.
    dfaTargetIndex: List[int] = field(default_factory=list)
    # Largest DFA value observed, used for normalization/scaling downstream.
    dfaMaxValue: float = 0.0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class FeatureData:
    """
    Holds every piece of data needed to render the visualization for a single
    feature. See the diagram in the readme:

    https://github.com/callummcdougall/sae_vis#data_storing_fnspy

    The fields below are the possible components of the feature-centric vis,
    i.e. this is where the actual data lives. Note that `prompt_data` is only
    applicable in the prompt-centric view.

    This class serves both views: in the feature-centric view, one instance
    produces the HTML for a whole feature (a full screen); in the prompt-centric
    view, one instance produces a single column of the full screen vis.
    """

    feature_tables_data: FeatureTablesData = field(default_factory=FeatureTablesData)
    acts_histogram_data: ActsHistogramData = field(default_factory=ActsHistogramData)
    logits_table_data: LogitsTableData = field(default_factory=LogitsTableData)
    logits_histogram_data: LogitsHistogramData = field(
        default_factory=LogitsHistogramData
    )
    sequence_data: SequenceMultiGroupData = field(
        default_factory=SequenceMultiGroupData
    )
    prompt_data: SequenceData = field(default_factory=SequenceData)
    dfa_data: Optional[dict[int, dict[str, Any]]] = None
    decoder_weights_data: Optional[DecoderWeightsDistribution] = None

    def __post_init__(self):
        # Normalize a missing dfa_data into an empty mapping so callers can
        # always index it.
        self.dfa_data = {} if self.dfa_data is None else self.dfa_data

    def get_component_from_config(self, config: GenericComponentConfig) -> GenericData:
        """
        Map a component config object onto the matching data object stored on
        this instance, e.g. a `FeatureTablesConfig` yields
        `self.feature_tables_data`.
        """
        lookup = {
            FeatureTablesConfig.__name__: self.feature_tables_data,
            ActsHistogramConfig.__name__: self.acts_histogram_data,
            LogitsTableConfig.__name__: self.logits_table_data,
            LogitsHistogramConfig.__name__: self.logits_histogram_data,
            SequencesConfig.__name__: self.sequence_data,
            PromptConfig.__name__: self.prompt_data,
            # Add DFA config here if we create a specific config for it
        }
        cfg_name = config.__class__.__name__
        assert cfg_name in lookup, f"Invalid component config: {cfg_name}"
        return lookup[cfg_name]

    def _get_html_data_feature_centric(
        self,
        layout: SaeVisLayoutConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
    ) -> HTML:
        """
        Build the HTML object for a single feature-centric view; these get
        assembled together into the full feature-centric vis.

        Args:
            decode_fn: Used to decode token IDs into string tokens.

        Returns:
            html_obj.html_data:
                Dict mapping column index to HTML string; each value becomes a
                grid-column element, and they are concatenated.
            html_obj.js_data:
                Dict mapping component names to the JavaScript data consumed by
                the scripts dumped in at the end.
        """
        html_obj = HTML()

        # Walk every column and accumulate the HTML of each component in it.
        for col, components in layout.columns.items():
            for comp_cfg in components:
                comp = self.get_component_from_config(comp_cfg)
                html_obj += comp._get_html_data(
                    cfg=comp_cfg,
                    decode_fn=decode_fn,
                    column=col,
                    id_suffix="0",  # we only use this if we have >1 set of histograms, i.e. prompt-centric vis
                )

        return html_obj

    def _get_html_data_prompt_centric(
        self,
        layout: SaeVisLayoutConfig,
        decode_fn: Callable[[int | list[int]], str | list[str]],
        column_idx: int,
        bold_idx: int | Literal["max"],
        title: str,
    ) -> HTML:
        """
        Build the HTML object for one column of the prompt-centric view.
        Columns are assembled into a full screen, and screens into the complete
        prompt-centric vis.

        Args:
            decode_fn: Used to decode token IDs into string tokens.
            column_idx: Index of the (single) column this call produces; needed
                for the JavaScript data.
            bold_idx: Which token index to bold in the sequence data. "max"
                bolds the max-activation token in each sequence.
            title: Title for this column, recorded in the JavaScript data.

        Returns:
            html_obj.html_data:
                Dict with the single key `str(column_idx)`, representing this
                column; it becomes one grid-column element, concatenated with
                the others.
            html_obj.js_data:
                Dict mapping component names to the JavaScript data consumed by
                the scripts dumped in at the end.
        """
        html_obj = HTML()

        # Sanity checks: exactly one column, and a prompt config must exist.
        assert layout.columns.keys() == {
            0
        }, f"prompt_centric_layout should only have 1 column, instead found cols {layout.columns.keys()}"
        assert (
            layout.prompt_cfg is not None
        ), "prompt_centric_cfg should include a PromptConfig, but found None"
        if layout.seq_cfg is not None:
            assert (layout.seq_cfg.n_quantiles == 0) or (
                layout.seq_cfg.stack_mode == "stack-all"
            ), "prompt_centric_layout should have stack_mode='stack-all' if n_quantiles > 0, so that it fits in 1 col"

        # Color scales must cover both the prompt and the stored sequences.
        max_feat_act = max(
            max(self.prompt_data.feat_acts), self.sequence_data.max_feat_act
        )
        max_loss_contribution = max(
            max(self.prompt_data.loss_contribution),
            self.sequence_data.max_loss_contribution,
        )

        # Accumulate HTML for every component in the single column.
        for comp_cfg in layout.columns[0]:
            comp = self.get_component_from_config(comp_cfg)
            html_obj += comp._get_html_data(
                cfg=comp_cfg,
                decode_fn=decode_fn,
                column=column_idx,
                id_suffix=str(column_idx),
                component_specific_kwargs=dict(  # only used by SequenceData (the prompt)
                    bold_idx=bold_idx,
                    permanent_line=True,
                    hover_above=True,
                    max_feat_act=max_feat_act,
                    max_loss_contribution=max_loss_contribution,
                ),
            )

        # Prepend the empty title element and record the title in the JS data.
        html_obj.html_data[column_idx] = (
            f"<div id='column-{column_idx}-title'></div>\n{html_obj.html_data[column_idx]}"
        )
        html_obj.js_data["gridColumnTitlesData"] = {str(column_idx): title}

        return html_obj
|
SAEDashboard/sae_dashboard/feature_data_generator.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Any, Dict, List
|
| 3 |
+
|
| 4 |
+
import einops
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from jaxtyping import Float, Int
|
| 8 |
+
from sae_lens import SAE, HookedSAETransformer
|
| 9 |
+
from sae_lens.config import DTYPE_MAP as DTYPES
|
| 10 |
+
from sae_lens.saes.topk_sae import TopK
|
| 11 |
+
from torch import Tensor, nn
|
| 12 |
+
from tqdm.auto import tqdm
|
| 13 |
+
|
| 14 |
+
from sae_dashboard.dfa_calculator import DFACalculator
|
| 15 |
+
from sae_dashboard.sae_vis_data import SaeVisConfig
|
| 16 |
+
from sae_dashboard.transformer_lens_wrapper import to_resid_direction
|
| 17 |
+
from sae_dashboard.utils_fns import RollingCorrCoef
|
| 18 |
+
|
| 19 |
+
# Module-level alias: shorthand for numpy arrays in annotations within this module.
Arr = np.ndarray
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class FeatureDataGenerator:
    """
    Runs the model over minibatches of tokens and gathers everything needed to
    build feature dashboards: feature activations, rolling correlation stats,
    encoder/decoder feature directions, and (optionally) per-feature DFA
    attribution results.
    """

    def __init__(
        self,
        cfg: SaeVisConfig,
        tokens: Int[Tensor, "batch seq"],
        model: HookedSAETransformer,
        encoder: SAE[Any],
    ):
        self.cfg = cfg
        self.model = model
        self.encoder = encoder
        # Split the token tensor into device-resident minibatches up front.
        self.token_minibatches = self.batch_tokens(tokens)
        # Direct feature attribution is only set up when requested.
        self.dfa_calculator = (
            DFACalculator(model.model, encoder) if cfg.use_dfa else None  # type: ignore
        )

        if cfg.use_dfa:
            assert (
                "hook_z" in encoder.cfg.hook_name
            ), f"DFAs are only supported for hook_z, but got {encoder.cfg.hook_name}"

    @torch.inference_mode()
    def batch_tokens(
        self, tokens: Int[Tensor, "batch seq"]
    ) -> list[Int[Tensor, "batch seq"]]:
        """Split `tokens` into minibatches for the fwd pass, each moved to cfg.device."""
        token_minibatches = (
            (tokens,)
            if self.cfg.minibatch_size_tokens is None
            else tokens.split(self.cfg.minibatch_size_tokens)
        )
        token_minibatches = [tok.to(self.cfg.device) for tok in token_minibatches]

        return token_minibatches

    @torch.inference_mode()
    def get_feature_data(  # type: ignore
        self,
        feature_indices: list[int],
        progress: list[tqdm] | None = None,  # type: ignore
    ):  # type: ignore
        """
        Compute feature activations (and optional DFA results) for the given
        features across all token minibatches.

        Args:
            feature_indices: SAE feature indices to gather data for.
            progress: Optional list of tqdm bars; the first one is advanced once
                per minibatch (fwd passes dominate the runtime here).

        Returns:
            Tuple of (all_feat_acts, empty placeholder tensor, feature_resid_dir,
            feature_out_dir, corrcoef_neurons, corrcoef_encoder, all_dfa_results).
        """
        # Create lists to store the feature activations & final values of the residual stream
        all_feat_acts = []
        all_dfa_results: dict[int, dict[int, dict[str, Any]]] = {
            feature_idx: {} for feature_idx in feature_indices
        }
        total_prompts = 0

        # Rolling stats: feature-vs-neuron and pairwise feature-vs-feature correlations.
        corrcoef_neurons = RollingCorrCoef()
        corrcoef_encoder = RollingCorrCoef(indices=feature_indices, with_self=True)

        # Get encoder & decoder directions
        feature_out_dir = self.encoder.W_dec[feature_indices]  # [feats d_autoencoder]
        feature_resid_dir = to_resid_direction(
            feature_out_dir, self.model  # type: ignore
        )  # [feats d_model]

        # ! Compute & concatenate together all feature activations & post-activation function values
        for i, minibatch in enumerate(self.token_minibatches):
            # FIX: the original called `minibatch.to(...)` without assigning the
            # result — Tensor.to is not in-place, so that line was a no-op.
            # batch_tokens() already moved minibatches to cfg.device, so this
            # rebind is a cheap safeguard rather than a behavior change.
            minibatch = minibatch.to(self.cfg.device)
            model_activation_dict = self.get_model_acts(i, minibatch)
            primary_acts = model_activation_dict[
                self.model.activation_config.primary_hook_point  # type: ignore
            ].to(
                self.encoder.device
            )  # make sure acts are on the correct device

            # For TopK, compute all activations first, then select features:
            # the top-k is taken over the full feature set, so masking W_enc
            # beforehand would change which features survive.
            if isinstance(self.encoder.activation_fn, TopK):
                all_features_acts = self.encoder.encode(primary_acts)
                feature_acts = all_features_acts[:, :, feature_indices].to(
                    DTYPES[self.cfg.dtype]
                )
            else:
                # For other activation functions, temporarily mask the SAE's
                # parameters down to the requested features.
                with FeatureMaskingContext(self.encoder, feature_indices):
                    feature_acts = self.encoder.encode(primary_acts).to(
                        DTYPES[self.cfg.dtype]
                    )

            self.update_rolling_coefficients(
                model_acts=primary_acts,
                feature_acts=feature_acts,
                corrcoef_neurons=corrcoef_neurons,
                corrcoef_encoder=corrcoef_encoder,
            )

            # Add these to the lists (we'll eventually concat)
            all_feat_acts.append(feature_acts)

            # Calculate DFA at each feature's max-activation position.
            if self.cfg.use_dfa and self.dfa_calculator:
                max_value_indices = torch.argmax(feature_acts, dim=1)
                batch_dfa_results = self.dfa_calculator.calculate(
                    model_activation_dict,
                    self.model.hook_layer,  # type: ignore
                    feature_indices,
                    max_value_indices,
                )
                # NOTE(review): feature_data[prompt_idx] is indexed like a
                # mapping with "dfa_values"/"dfa_target_index"/"dfa_max_value"
                # keys — this depends on DFACalculator.calculate's return
                # structure; confirm against that module.
                for feature_idx, feature_data in batch_dfa_results.items():
                    for prompt_idx in range(feature_data.shape[0]):
                        # Offset by prompts from previous minibatches so keys
                        # are global prompt indices.
                        global_prompt_idx = total_prompts + prompt_idx
                        all_dfa_results[feature_idx][global_prompt_idx] = {
                            "dfaValues": feature_data[prompt_idx][
                                "dfa_values"
                            ].tolist(),
                            "dfaTargetIndex": int(
                                feature_data[prompt_idx]["dfa_target_index"]
                            ),
                            "dfaMaxValue": float(
                                feature_data[prompt_idx]["dfa_max_value"]
                            ),
                        }

            total_prompts += len(minibatch)

            # Update the 1st progress bar (fwd passes & getting sequence data dominates the runtime of these computations)
            if progress is not None:
                progress[0].update(1)

        all_feat_acts = torch.cat(all_feat_acts, dim=0)

        return (
            all_feat_acts,
            torch.tensor([]),  # all_resid_post, no longer used
            feature_resid_dir,
            feature_out_dir,
            corrcoef_neurons,
            corrcoef_encoder,
            all_dfa_results,
        )

    @torch.inference_mode()
    def get_model_acts(
        self,
        minibatch_index: int,
        minibatch_tokens: torch.Tensor,
        use_cache: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        Get the model activations for a given minibatch of tokens.

        When cfg.cache_dir is set, activations are cached on disk with
        torch.save / torch.load, one file per minibatch index. (The original
        docstring claimed np.memmap was used; the code has always used torch
        serialization.)
        """
        if self.cfg.cache_dir is not None:
            cache_path = self.cfg.cache_dir / f"model_activations_{minibatch_index}.pt"
            if use_cache and cache_path.exists():
                activation_dict = load_tensor_dict_torch(cache_path, self.cfg.device)
            else:
                activation_dict = self.model.forward(
                    minibatch_tokens.to("cpu"), return_logits=False  # type: ignore
                )
                save_tensor_dict_torch(activation_dict, cache_path)
        else:
            activation_dict = self.model.forward(
                minibatch_tokens.to("cpu"), return_logits=False  # type: ignore
            )

        return activation_dict

    @torch.inference_mode()
    def update_rolling_coefficients(
        self,
        model_acts: Float[Tensor, "batch seq d_in"],
        feature_acts: Float[Tensor, "batch seq feats"],
        corrcoef_neurons: RollingCorrCoef | None,
        corrcoef_encoder: RollingCorrCoef | None,
    ) -> None:
        """
        Fold one minibatch of activations into the rolling correlation objects.

        Args:
            model_acts: Float[Tensor, "batch seq d_in"]
                The activations of the model, which the SAE was trained on.
            feature_acts: Float[Tensor, "batch seq feats"]
                The SAE feature activations for this minibatch.
            corrcoef_neurons: Optional[RollingCorrCoef]
                The object storing the minimal data necessary to compute corrcoef between feature activations & neurons.
            corrcoef_encoder: Optional[RollingCorrCoef]
                The object storing the minimal data necessary to compute corrcoef between pairwise feature activations.
        """
        # Update the CorrCoef object between feature activation & neurons
        if corrcoef_neurons is not None:
            corrcoef_neurons.update(
                einops.rearrange(feature_acts, "batch seq feats -> feats (batch seq)"),
                einops.rearrange(model_acts, "batch seq d_in -> d_in (batch seq)"),
            )

        # Update the CorrCoef object between pairwise feature activations
        if corrcoef_encoder is not None:
            corrcoef_encoder.update(
                einops.rearrange(feature_acts, "batch seq feats -> feats (batch seq)"),
                einops.rearrange(feature_acts, "batch seq feats -> feats (batch seq)"),
            )
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def save_tensor_dict_torch(tensor_dict: Dict[str, torch.Tensor], filename: Path):
    """Serialize a name -> tensor mapping to `filename` using torch.save."""
    torch.save(tensor_dict, filename)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def load_tensor_dict_torch(filename: Path, device: str) -> Dict[str, torch.Tensor]:
    """Deserialize a name -> tensor mapping, materializing tensors on `device`."""
    target_device = torch.device(device)
    return torch.load(filename, map_location=target_device)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class FeatureMaskingContext:
    """
    Context manager that temporarily restricts an SAE's parameters to a subset
    of features, restoring the full parameters on exit.

    On enter, every per-feature parameter of the SAE (selected according to its
    architecture) is replaced by a copy sliced down to `feature_idxs`; on exit,
    the original full parameters are restored. This lets `encoder.encode` run
    over only the requested features.

    NOTE: the original implementation duplicated the clone/slice/setattr
    boilerplate once per parameter (the jumprelu branch literally repeated the
    standard branch's b_enc code); it is factored into `_mask` here.
    """

    def __init__(self, sae: "SAE[Any]", feature_idxs: List[int]):
        self.sae = sae
        self.feature_idxs = feature_idxs
        # name -> clone of the full (unmasked) parameter data, for restoration.
        self.original_weight: Dict[str, torch.Tensor] = {}

    def _mask(self, name: str, mask_columns: bool = False) -> None:
        """Save a clone of parameter `name`, then replace it with the slice
        restricted to self.feature_idxs.

        Args:
            mask_columns: index the second dimension (used for W_enc, whose
                feature axis is dim 1); otherwise index the first dimension.
        """
        param = getattr(self.sae, name)
        # Clone so restoration is unaffected by anything done to the masked copy.
        self.original_weight[name] = param.data.clone()
        masked = (
            param[:, self.feature_idxs] if mask_columns else param[self.feature_idxs]
        )
        setattr(self.sae, name, nn.Parameter(masked))

    def __enter__(self):
        # Decoder rows and encoder columns are per-feature in every architecture.
        self._mask("W_dec")
        self._mask("W_enc", mask_columns=True)

        # Handle architecture as either attribute or method.
        architecture = self.sae.cfg.architecture
        if callable(architecture):
            architecture = architecture()

        # Each architecture carries different per-feature bias/threshold params.
        if architecture in [
            "standard",
            "standard_transcoder",
            "transcoder",
            "skip_transcoder",
        ]:
            self._mask("b_enc")
        elif architecture in ["jumprelu", "jumprelu_transcoder"]:
            self._mask("b_enc")
            self._mask("threshold")
        elif architecture in ["gated", "gated_transcoder"]:
            self._mask("b_gate")
            self._mask("r_mag")
            self._mask("b_mag")
        else:
            raise ValueError("Invalid architecture")

        return self

    def __exit__(self, exc_type, exc_value, traceback):  # type: ignore
        # Restore every masked parameter to its full, original data.
        for key, value in self.original_weight.items():
            setattr(self.sae, key, nn.Parameter(value))
|
SAEDashboard/sae_dashboard/html/acts_histogram_template.html
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- Activation densities histogram -->
|
| 2 |
+
<div class="plotly-hist" id="HISTOGRAM_ACTS_ID" style="height: 150px; margin-top: 0px;"></div>
|
SAEDashboard/sae_dashboard/html/feature_tables_template.html
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- Feature Info Tables -->
|
| 2 |
+
<div class="feature-tables" id="FEATURE_TABLES_ID"></div>
|
SAEDashboard/sae_dashboard/html/logits_histogram_template.html
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- Logits histogram -->
|
| 2 |
+
<div class="plotly-hist" id="HISTOGRAM_LOGITS_ID" style="height: 150px; margin-top: 0px;"></div>
|
SAEDashboard/sae_dashboard/html/logits_table_template.html
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- Logits table -->
|
| 2 |
+
<div class="logits-table" id="LOGITS_TABLE_ID"></div>
|
SAEDashboard/sae_dashboard/html/sequences_group_template.html
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- Sequence group -->
|
| 2 |
+
<div class="seq-group" id="SEQUENCE_GROUP_ID"></div>
|