# ESMFold2-Fast / configuration_esmfold2.py
# lhallee's picture
# Upload folder using huggingface_hub
# 407875b verified
# Raw
# History Blame Contribute Delete
# 11.7 kB
# Copyright 2026 Biohub. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ESMFold2 model configuration."""
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from transformers.configuration_utils import PretrainedConfig
# ---------------------------------------------------------------------------
# Nested dataclass configs
# ---------------------------------------------------------------------------
# Value stored in ``ESMFold2Config.esmc_id`` when the caller passes no
# ``esmc_id`` keyword.
_DEFAULT_ESMC_HF_REPO = "Synthyra/ESMplusplus_6B"
# Value stored in ``ESMFold2Config.esmc_attn_backend`` when the caller passes
# no ``esmc_attn_backend`` keyword.
_DEFAULT_ESMC_ATTN_BACKEND = "flex"
# Lookup table used by ``normalize_esmc_id``: a key found here is rewritten to
# its value; identifiers that are not keys pass through unchanged.
_ESMC_ID_ALIASES = {
"biohub/ESMC-300M": "Synthyra/ESMplusplus_small",
"biohub/ESMC-600M": "Synthyra/ESMplusplus_large",
"biohub/ESMC-6B": "Synthyra/ESMplusplus_6B",
}
def normalize_esmc_id(esmc_id: str) -> str:
    """Translate an aliased identifier to its replacement.

    Args:
        esmc_id: Identifier to look up in ``_ESMC_ID_ALIASES``.

    Returns:
        The mapped value when ``esmc_id`` is a key of ``_ESMC_ID_ALIASES``;
        otherwise ``esmc_id`` itself, unchanged.
    """
    # A single lookup that falls back to the input covers both outcomes.
    return _ESMC_ID_ALIASES.get(esmc_id, esmc_id)
@dataclass
class MSAEncoderConfig:
    """Settings for the optional MSA encoder module (Large MSA models only).

    Every field has a default, so ``MSAEncoderConfig()`` is valid and leaves
    the module disabled. The declaration order below is also the positional
    order accepted by the generated constructor.

    NOTE(review): the size fields are consumed outside this file; their exact
    meaning is presumed from their names and should be confirmed against the
    encoder implementation.
    """

    # Master switch; off unless explicitly enabled.
    enabled: bool = False

    # Size hyper-parameters (see the review note in the docstring).
    d_msa: int = 128
    d_hidden: int = 32
    n_layers: int = 4
    n_heads_msa: int = 8
    msa_head_width: int = 32
@dataclass
class ParcaeConfig:
"""Release-only config for the parcae diffusion-loop scheduler."""
enabled: bool = True
poisson_mean: float = 3.0
min_steps: int = 1
max_steps: int | None = 6
coda_n_layers: int = 2
@dataclass
class LMEncoderConfig:
    """Release-only settings for the LM-side pair encoder.

    All fields are defaulted, and the encoder is enabled by default.

    NOTE(review): this class declares its own ``lm_dropout`` (default 0.25),
    separate from the top-level ``ESMFold2Config.lm_dropout`` (default 0.0);
    which one applies where is decided by the model code — confirm there.
    """

    # Master switch; on unless explicitly disabled.
    enabled: bool = True

    n_layers: int = 4

    # Dropout settings for this encoder.
    lm_dropout: float = 0.25
    per_loop_lm_dropout: bool = True
@dataclass
class AtomAttentionConfig:
    """Settings for the SWA atom encoder/decoder with 3D RoPE.

    All fields are defaulted. The declaration order below is also the
    positional order accepted by the generated constructor.

    NOTE(review): field meanings are presumed from their names; confirm
    against the atom attention implementation.
    """

    # Widths and depth.
    d_atom: int = 128
    d_token: int = 768
    n_blocks: int = 3
    n_heads: int = 4

    # Attention window and expansion ratio.
    swa_window_size: int = 128
    expansion_ratio: int = 2

    # 3D RoPE settings.
    spatial_rope_base_frequency: float = 20.0
    n_spatial_rope_pairs_per_axis: int = 2
    n_uid_rope_pairs: int = 10
    uid_rope_base_frequency: float = 10000.0
@dataclass
class FoldingTrunkConfig:
    """Settings for a folding trunk.

    All fields are defaulted: 24 layers, 8 heads and no dropout.
    """

    n_layers: int = 24
    n_heads: int = 8
    dropout: float = 0.0
@dataclass
class InputsEmbedderConfig:
    """Settings for the inputs embedder module.

    ``atom_encoder`` may be supplied either as an ``AtomAttentionConfig`` or
    as a plain ``dict`` of its fields (for example after a JSON round-trip);
    a ``dict`` is converted in ``__post_init__``.
    """

    d_inputs: int = 451
    atom_encoder: AtomAttentionConfig = field(default_factory=AtomAttentionConfig)

    def __post_init__(self):
        """Promote a ``dict``-valued ``atom_encoder`` to ``AtomAttentionConfig``."""
        encoder = self.atom_encoder
        if isinstance(encoder, dict):
            # Unknown keys surface as a TypeError from the dataclass constructor.
            self.atom_encoder = AtomAttentionConfig(**encoder)
@dataclass
class DiffusionModuleConfig:
    """Settings for the DiffusionModule.

    All fields are defaulted. The declaration order below is also the
    positional order accepted by the generated constructor.

    NOTE(review): field meanings are presumed from their names; confirm
    against the DiffusionModule implementation.
    """

    sigma_data: float = 16.0

    # Channel widths.
    c_atom: int = 128
    c_token: int = 768
    c_z: int = 256
    c_s_inputs: int = 451
    fourier_dim: int = 256

    # Relative-position limits.
    relpos_r_max: int = 32
    relpos_s_max: int = 2

    # Atom-level and token-level block/head counts.
    atom_num_blocks: int = 3
    atom_num_heads: int = 4
    token_num_blocks: int = 12
    token_num_heads: int = 16

    transition_multiplier: int = 2
@dataclass
class DiffusionStructureHeadConfig:
    """Settings for the diffusion-based structure prediction head.

    ``diffusion_module`` may be supplied either as a ``DiffusionModuleConfig``
    or as a plain ``dict`` of its fields; a ``dict`` is converted in
    ``__post_init__``.
    """

    diffusion_module: DiffusionModuleConfig = field(
        default_factory=DiffusionModuleConfig
    )
    distogram_bins: int = 128

    # Training noise: sigma ~ sigma_data * exp(mu + sigma * N(0,1)),
    # with mu = train_noise_log_mean and sigma = train_noise_log_std.
    train_noise_log_mean: float = -1.2
    train_noise_log_std: float = 1.5

    # Sampling defaults (ODE).
    gamma_0: float = 0.605
    gamma_min: float = 1.107
    noise_scale: float = 0.0
    step_scale: float = 1.0

    # Inference schedule defaults.
    inference_s_max: float = 160.0
    inference_s_min: float = 4e-4
    inference_p: float = 8.0
    inference_num_steps: int = 68

    def __post_init__(self):
        """Promote a ``dict``-valued ``diffusion_module`` to its dataclass."""
        module_cfg = self.diffusion_module
        if isinstance(module_cfg, dict):
            # Unknown keys surface as a TypeError from the dataclass constructor.
            self.diffusion_module = DiffusionModuleConfig(**module_cfg)
@dataclass
class ConfidenceHeadConfig:
enabled: bool = True
num_plddt_bins: int = 50
num_pde_bins: int = 64
num_pae_bins: int = 64
min_dist: float = 2.0
max_dist: float = 52.0
distogram_bins: int = 128
folding_trunk: FoldingTrunkConfig = field(
default_factory=lambda: FoldingTrunkConfig(n_layers=4)
)
def __post_init__(self):
if isinstance(self.folding_trunk, dict):
self.folding_trunk = FoldingTrunkConfig(**self.folding_trunk)
# ---------------------------------------------------------------------------
# Top-level config
# ---------------------------------------------------------------------------
class ESMFold2Config(PretrainedConfig):
    """
    Configuration for the ESMFold2 structure prediction model.

    Uses SWA atom encoders with 3D RoPE, a diffusion transformer,
    a folding trunk, and an ESMC 6B PLM backbone.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control
    the model outputs. Read the documentation from [`PretrainedConfig`] for more
    information.

    Every setting is read from ``**kwargs``, which are also forwarded in full
    to ``PretrainedConfig.__init__``. Each nested sub-config may be given as
    an instance of its dataclass or as a plain ``dict`` of that dataclass's
    fields; any other value (including leaving the key out) produces the
    dataclass defaults.

    Args:
        type (`str`, defaults to `"release"`):
            Model variant. Must be `"release"` or `"experimental"`.
        d_single (`int`, defaults to 384):
            Dimensionality of single (per-residue) representations.
        d_pair (`int`, defaults to 256):
            Dimensionality of pair (residue-residue) representations.
        n_relative_residx_bins (`int`, defaults to 32):
            Number of bins for relative residue index encoding.
        n_relative_chain_bins (`int`, defaults to 2):
            Number of bins for relative chain encoding.
        num_loops (`int`, defaults to 10):
            Number of trunk loops for iterative refinement.
        num_diffusion_samples (`int`, defaults to 8):
            Number of parallel structure predictions to generate.
        lm_dropout (`float`, defaults to 0.0):
            Dropout probability on LM pair embeddings. When > 0, dropout is
            applied with ``training=True`` (including at inference) to match
            the experimental training recipe used by binder design.
        force_lm_dropout_during_inference (`bool`, defaults to False):
            When True, apply ``lm_dropout`` even when ``model.eval()`` and
            ``lm_dropout`` > 0. Binder-design loads set this to True.
        lm_mask_pct (`float`, defaults to 0.0):
            Fraction of LM residue tokens randomly replaced with the LM mask
            token before running the PLM backbone.
        lm_d_model (`int`, defaults to 2560):
            Stored as given; presumably the hidden width of the PLM backbone.
        lm_num_layers (`int`, defaults to 80):
            Stored as given; presumably the layer count of the PLM backbone.
        esmc_id (`str`, defaults to `"Synthyra/ESMplusplus_6B"`):
            Backbone identifier. Backward-compatible field name; the value is
            passed through `normalize_esmc_id`, so aliased identifiers are
            rewritten before being stored.
        esmc_attn_backend (`str`, defaults to `"flex"`):
            Stored as given; presumably selects the backbone's attention
            implementation.
        disable_msa_features (`bool`, defaults to False):
            When True, zero out MSA-derived ``profile`` and ``deletion_mean``
            before the inputs embedder (experimental medium/large checkpoints).
        inputs (`InputsEmbedderConfig`):
            Configuration for the inputs embedder module.
        folding_trunk (`FoldingTrunkConfig`):
            Configuration for the folding trunk.
        structure_head (`DiffusionStructureHeadConfig`):
            Configuration for the diffusion-based structure prediction head.
        confidence_head (`ConfidenceHeadConfig`):
            Configuration for the confidence prediction head.
        msa_encoder (`MSAEncoderConfig`):
            Configuration for the optional MSA encoder module.
        parcae (`ParcaeConfig`):
            Release-only scheduler configuration; ignored when
            ``type == "experimental"``.
        lm_encoder (`LMEncoderConfig`):
            Release-only LM-side pair encoder configuration; ignored when
            ``type == "experimental"``.
        msa_encoder_overwrite (`bool`, defaults to True):
            If True, MSA encoder output replaces the pair stream; if False,
            it is added. The supplied value is coerced with ``bool()``.

    Raises:
        ValueError: If ``type`` is neither `"release"` nor `"experimental"`.

    Examples:

    ```python
    >>> from transformers import ESMFold2Config, ESMFold2ExperimentalModel

    >>> # Initializing an ESMFold2 configuration
    >>> configuration = ESMFold2Config(type="experimental")

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = ESMFold2ExperimentalModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "esmfold2"
    has_no_defaults_at_init = True

    def __init__(self, **kwargs):
        """Populate the configuration from keyword arguments.

        See the class docstring for the accepted keys and their defaults.
        """
        super().__init__(**kwargs)

        # Model variant; anything other than the two known values is rejected.
        self.type: str = kwargs.get("type", "release")
        if self.type not in ("release", "experimental"):
            raise ValueError(
                f"ESMFold2Config.type must be 'release' or 'experimental', "
                f"got {self.type!r}"
            )

        # Top-level scalar fields.
        self.d_single: int = kwargs.get("d_single", 384)
        self.d_pair: int = kwargs.get("d_pair", 256)
        self.n_relative_residx_bins: int = kwargs.get("n_relative_residx_bins", 32)
        self.n_relative_chain_bins: int = kwargs.get("n_relative_chain_bins", 2)
        self.num_loops: int = kwargs.get("num_loops", 10)
        self.num_diffusion_samples: int = kwargs.get("num_diffusion_samples", 8)

        # If True, ``profile`` / ``deletion_mean`` are zeroed before the inputs
        # embedder.
        self.disable_msa_features: bool = kwargs.get("disable_msa_features", False)

        # Language-model related settings.
        # NOTE(review): the class docstring says ``lm_dropout`` also applies at
        # inference, while ``force_lm_dropout_during_inference`` reads as an
        # opt-in for exactly that — confirm against the model code.
        self.lm_dropout: float = kwargs.get("lm_dropout", 0.0)
        self.force_lm_dropout_during_inference: bool = kwargs.get(
            "force_lm_dropout_during_inference", False
        )
        self.lm_mask_pct: float = kwargs.get("lm_mask_pct", 0.0)
        self.lm_d_model: int = kwargs.get("lm_d_model", 2560)
        self.lm_num_layers: int = kwargs.get("lm_num_layers", 80)

        # Backward-compatible field name; values now point to FastPLMs ESM++.
        # Aliased identifiers are rewritten before being stored.
        self.esmc_id: str = normalize_esmc_id(
            kwargs.get("esmc_id", _DEFAULT_ESMC_HF_REPO)
        )
        self.esmc_attn_backend: str = kwargs.get(
            "esmc_attn_backend", _DEFAULT_ESMC_ATTN_BACKEND
        )

        def _nested(config_cls, key):
            """Build the sub-config stored under ``key`` in ``kwargs``."""
            supplied = kwargs.get(key)
            if isinstance(supplied, config_cls):
                # Already the right dataclass: keep the caller's instance.
                return supplied
            if isinstance(supplied, dict):
                # Typically a deserialized config: rebuild the dataclass.
                return config_cls(**supplied)
            # Missing, None, or any other type: fall back to the defaults.
            return config_cls()

        self.inputs = _nested(InputsEmbedderConfig, "inputs")
        self.folding_trunk = _nested(FoldingTrunkConfig, "folding_trunk")
        self.structure_head = _nested(DiffusionStructureHeadConfig, "structure_head")
        self.confidence_head = _nested(ConfidenceHeadConfig, "confidence_head")
        self.msa_encoder = _nested(MSAEncoderConfig, "msa_encoder")

        # Release-only modules — ignored when ``type == "experimental"``.
        self.parcae = _nested(ParcaeConfig, "parcae")
        self.lm_encoder = _nested(LMEncoderConfig, "lm_encoder")

        # If True, MSA encoder output replaces the pair stream; if False, it is added.
        self.msa_encoder_overwrite: bool = bool(
            kwargs.get("msa_encoder_overwrite", True)
        )

    def to_dict(self):
        """Serialize the config, expanding nested dataclasses to plain dicts.

        Returns:
            The dictionary from ``PretrainedConfig.to_dict`` with each nested
            sub-config replaced by its ``dataclasses.asdict`` form.
        """
        serialized = super().to_dict()
        nested_fields = (
            "inputs",
            "folding_trunk",
            "structure_head",
            "confidence_head",
            "msa_encoder",
            "parcae",
            "lm_encoder",
        )
        for name in nested_fields:
            serialized[name] = asdict(getattr(self, name))
        return serialized
# Names exported by a star-import of this module. The other dataclasses
# defined in this file (AtomAttentionConfig, FoldingTrunkConfig,
# InputsEmbedderConfig, DiffusionModuleConfig, DiffusionStructureHeadConfig,
# ConfidenceHeadConfig) are not listed here and must be imported by name.
# NOTE(review): confirm that leaving those out is intentional.
__all__ = [
"ESMFold2Config",
"MSAEncoderConfig",
"ParcaeConfig",
"LMEncoderConfig",
"normalize_esmc_id",
]