ESMFold2-Fast / configuration_esmc_sae.py
lhallee's picture
Upload folder using huggingface_hub
fb8a87c verified
Raw
History Blame Contribute Delete
2.69 kB
# Copyright 2026 Biohub. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ESMC sparse autoencoder (SAE) configuration."""
from dataclasses import dataclass
from transformers.configuration_utils import PretrainedConfig
@dataclass
class ESMCSAEParams:
"""Parameters for one backbone layer's SAE inside :class:`ESMCSAEModel`.
The SAE itself is an internal ``nn.Module``; this dataclass just bundles
the handful of fields needed to instantiate one.
"""
d_model: int = 2560
codebook_dim: int = 65536
k: int = 64
layer: int = 0
class ESMCSAEConfig(PretrainedConfig):
"""
Configuration class for [`ESMCSAEModel`] — a container that holds one
SAE per backbone layer for a fixed ``(model, codebook_dim, k)`` group.
All SAEs in a container share ``d_model``, ``codebook_dim``, and ``k``;
they differ only in the backbone layer they were trained on.
``available_layers`` lists the backbone-layer indices the repo ships;
each entry ``i`` is stored on disk as ``layer_{i}.safetensors`` (the
filename index *is* the backbone layer, so a single-layer repo for
layer 23 stores ``layer_23.safetensors``).
Args:
d_model (`int`, *optional*, defaults to 2560):
Dimensionality of the ESMC hidden states fed into the SAEs.
codebook_dim (`int`, *optional*, defaults to 65536):
Number of sparse features in each SAE's codebook.
k (`int`, *optional*, defaults to 64):
Top-k sparsity per SAE.
available_layers (`list[int]`, *optional*, defaults to ``[0]``):
Which backbone-layer indices the repo ships.
"""
model_type = "esmc_sae"
def __init__(
self,
d_model: int = 2560,
codebook_dim: int = 65536,
k: int = 64,
available_layers: list[int] | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.d_model = d_model
self.codebook_dim = codebook_dim
self.k = k
self.available_layers = (
list(available_layers) if available_layers is not None else [0]
)
__all__ = ["ESMCSAEConfig", "ESMCSAEParams"]