Feature Extraction
Transformers
Safetensors
esmfold2
biology
protein-structure
multimodal-protein-model
custom_code
Instructions to use Synthyra/ESMFold2-Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/ESMFold2-Fast with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="Synthyra/ESMFold2-Fast", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Synthyra/ESMFold2-Fast", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files- README.md +22 -22
- __init__.py +5 -5
- configuration_esmfold2.py +19 -19
- modeling_esmfold2.py +257 -257
- modeling_esmfold2_common.py +81 -81
- modeling_esmfold2_experimental.py +31 -15
README.md
CHANGED
|
@@ -221,27 +221,27 @@ with torch.inference_mode():
|
|
| 221 |
decoded = model.input_builder.decode(output, features, chain_infos)
|
| 222 |
```
|
| 223 |
|
| 224 |
-
Set `load_esmc=False` when loading if you want to provide precomputed `lm_hidden_states` manually or run folding-trunk tests without loading the 6B ESM++ backbone:
|
| 225 |
-
|
| 226 |
-
```python
|
| 227 |
-
model = AutoModel.from_pretrained(
|
| 228 |
-
"Synthyra/ESMFold2-Fast",
|
| 229 |
trust_remote_code=True,
|
| 230 |
load_esmc=False,
|
| 231 |
-
).cuda().eval()
|
| 232 |
-
```
|
| 233 |
-
|
| 234 |
-
For FP8 LM inference, install `transformer_engine.pytorch` in a CUDA
|
| 235 |
-
environment with FP8-capable hardware and load the shared FastPLMs ESM++
|
| 236 |
-
backbone with:
|
| 237 |
-
|
| 238 |
-
```python
|
| 239 |
-
model = AutoModel.from_pretrained(
|
| 240 |
-
"Synthyra/ESMFold2-Fast",
|
| 241 |
-
trust_remote_code=True,
|
| 242 |
-
esmc_precision="fp8",
|
| 243 |
-
).cuda().eval()
|
| 244 |
-
```
|
| 245 |
-
|
| 246 |
-
FP8 is inference-only for the ESMFold2 LM backbone. TTT remains a bf16/fp32
|
| 247 |
-
path.
|
|
|
|
| 221 |
decoded = model.input_builder.decode(output, features, chain_infos)
|
| 222 |
```
|
| 223 |
|
| 224 |
+
Set `load_esmc=False` when loading if you want to provide precomputed `lm_hidden_states` manually or run folding-trunk tests without loading the 6B ESM++ backbone:
|
| 225 |
+
|
| 226 |
+
```python
|
| 227 |
+
model = AutoModel.from_pretrained(
|
| 228 |
+
"Synthyra/ESMFold2-Fast",
|
| 229 |
trust_remote_code=True,
|
| 230 |
load_esmc=False,
|
| 231 |
+
).cuda().eval()
|
| 232 |
+
```
|
| 233 |
+
|
| 234 |
+
For FP8 LM inference, install `transformer_engine.pytorch` in a CUDA
|
| 235 |
+
environment with FP8-capable hardware and load the shared FastPLMs ESM++
|
| 236 |
+
backbone with:
|
| 237 |
+
|
| 238 |
+
```python
|
| 239 |
+
model = AutoModel.from_pretrained(
|
| 240 |
+
"Synthyra/ESMFold2-Fast",
|
| 241 |
+
trust_remote_code=True,
|
| 242 |
+
esmc_precision="fp8",
|
| 243 |
+
).cuda().eval()
|
| 244 |
+
```
|
| 245 |
+
|
| 246 |
+
FP8 is inference-only for the ESMFold2 LM backbone. TTT remains a bf16/fp32
|
| 247 |
+
path.
|
__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
from .configuration_esmfold2 import ESMFold2Config
|
| 2 |
-
from .modeling_esmfold2_experimental import ESMFold2ExperimentalModel
|
| 3 |
-
from .modeling_esmfold2 import ESMFold2Model
|
| 4 |
-
|
| 5 |
-
__all__ = ["ESMFold2Config", "ESMFold2ExperimentalModel", "ESMFold2Model"]
|
|
|
|
| 1 |
+
from .configuration_esmfold2 import ESMFold2Config
|
| 2 |
+
from .modeling_esmfold2_experimental import ESMFold2ExperimentalModel
|
| 3 |
+
from .modeling_esmfold2 import ESMFold2Model
|
| 4 |
+
|
| 5 |
+
__all__ = ["ESMFold2Config", "ESMFold2ExperimentalModel", "ESMFold2Model"]
|
configuration_esmfold2.py
CHANGED
|
@@ -201,19 +201,19 @@ class ESMFold2Config(PretrainedConfig):
|
|
| 201 |
Number of trunk loops for iterative refinement.
|
| 202 |
num_diffusion_samples (`int`, defaults to 8):
|
| 203 |
Number of parallel structure predictions to generate.
|
| 204 |
-
lm_dropout (`float`, defaults to 0.0):
|
| 205 |
-
Dropout probability on LM pair embeddings. When > 0, dropout is
|
| 206 |
-
applied with ``training=True`` (including at inference) to match
|
| 207 |
-
the experimental training recipe used by binder design.
|
| 208 |
-
force_lm_dropout_during_inference (`bool`, defaults to False):
|
| 209 |
-
When True, apply ``lm_dropout`` even when ``model.eval()`` and
|
| 210 |
-
``lm_dropout`` > 0. Binder-design loads set this to True.
|
| 211 |
-
lm_mask_pct (`float`, defaults to 0.0):
|
| 212 |
-
Fraction of LM residue tokens randomly replaced with the LM mask
|
| 213 |
-
token before running the PLM backbone.
|
| 214 |
-
disable_msa_features (`bool`, defaults to False):
|
| 215 |
-
When True, zero out MSA-derived ``profile`` and ``deletion_mean``
|
| 216 |
-
before the inputs embedder (experimental medium/large checkpoints).
|
| 217 |
inputs (`InputsEmbedderConfig`):
|
| 218 |
Configuration for the inputs embedder module.
|
| 219 |
folding_trunk (`FoldingTrunkConfig`):
|
|
@@ -263,12 +263,12 @@ class ESMFold2Config(PretrainedConfig):
|
|
| 263 |
# embedder.
|
| 264 |
self.disable_msa_features: bool = kwargs.get("disable_msa_features", False)
|
| 265 |
self.lm_dropout: float = kwargs.get("lm_dropout", 0.0)
|
| 266 |
-
self.force_lm_dropout_during_inference: bool = kwargs.get(
|
| 267 |
-
"force_lm_dropout_during_inference", False
|
| 268 |
-
)
|
| 269 |
-
self.lm_mask_pct: float = kwargs.get("lm_mask_pct", 0.0)
|
| 270 |
-
|
| 271 |
-
self.lm_d_model: int = kwargs.get("lm_d_model", 2560)
|
| 272 |
self.lm_num_layers: int = kwargs.get("lm_num_layers", 80)
|
| 273 |
# Backward-compatible field name; values now point to FastPLMs ESM++.
|
| 274 |
raw_esmc_id = (
|
|
|
|
| 201 |
Number of trunk loops for iterative refinement.
|
| 202 |
num_diffusion_samples (`int`, defaults to 8):
|
| 203 |
Number of parallel structure predictions to generate.
|
| 204 |
+
lm_dropout (`float`, defaults to 0.0):
|
| 205 |
+
Dropout probability on LM pair embeddings. When > 0, dropout is
|
| 206 |
+
applied with ``training=True`` (including at inference) to match
|
| 207 |
+
the experimental training recipe used by binder design.
|
| 208 |
+
force_lm_dropout_during_inference (`bool`, defaults to False):
|
| 209 |
+
When True, apply ``lm_dropout`` even when ``model.eval()`` and
|
| 210 |
+
``lm_dropout`` > 0. Binder-design loads set this to True.
|
| 211 |
+
lm_mask_pct (`float`, defaults to 0.0):
|
| 212 |
+
Fraction of LM residue tokens randomly replaced with the LM mask
|
| 213 |
+
token before running the PLM backbone.
|
| 214 |
+
disable_msa_features (`bool`, defaults to False):
|
| 215 |
+
When True, zero out MSA-derived ``profile`` and ``deletion_mean``
|
| 216 |
+
before the inputs embedder (experimental medium/large checkpoints).
|
| 217 |
inputs (`InputsEmbedderConfig`):
|
| 218 |
Configuration for the inputs embedder module.
|
| 219 |
folding_trunk (`FoldingTrunkConfig`):
|
|
|
|
| 263 |
# embedder.
|
| 264 |
self.disable_msa_features: bool = kwargs.get("disable_msa_features", False)
|
| 265 |
self.lm_dropout: float = kwargs.get("lm_dropout", 0.0)
|
| 266 |
+
self.force_lm_dropout_during_inference: bool = kwargs.get(
|
| 267 |
+
"force_lm_dropout_during_inference", False
|
| 268 |
+
)
|
| 269 |
+
self.lm_mask_pct: float = kwargs.get("lm_mask_pct", 0.0)
|
| 270 |
+
|
| 271 |
+
self.lm_d_model: int = kwargs.get("lm_d_model", 2560)
|
| 272 |
self.lm_num_layers: int = kwargs.get("lm_num_layers", 80)
|
| 273 |
# Backward-compatible field name; values now point to FastPLMs ESM++.
|
| 274 |
raw_esmc_id = (
|
modeling_esmfold2.py
CHANGED
|
@@ -59,12 +59,12 @@ from .modeling_esmfold2_common import (
|
|
| 59 |
TriangleMultiplicativeUpdate,
|
| 60 |
_categorical_mean,
|
| 61 |
_compute_intra_token_idx,
|
| 62 |
-
compute_lm_hidden_states,
|
| 63 |
-
gather_rep_atom_coords,
|
| 64 |
-
gather_token_to_atom,
|
| 65 |
-
maybe_apply_msa_column_masking,
|
| 66 |
-
maybe_subsample_msa,
|
| 67 |
-
)
|
| 68 |
from .esmfold2_affine3d import Affine3D as _FastPLMSESMFold2Affine3D
|
| 69 |
from .esmfold2_aligner import Aligner as _FastPLMSESMFold2Aligner
|
| 70 |
from .esmfold2_atom_indexer import AtomIndexer as _FastPLMSESMFold2AtomIndexer
|
|
@@ -699,27 +699,27 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 699 |
self.post_init()
|
| 700 |
self.init_ttt({"lora_target_replace_module": "MultiHeadAttention"})
|
| 701 |
|
| 702 |
-
def load_esmc(self, esmc_model_path: str, precision: str = "bf16") -> None:
|
| 703 |
-
"""Load the FastPLMs ESM++ LM used as the ESMFold2 PLM backbone.
|
| 704 |
-
|
| 705 |
-
``precision``: ``"bf16"`` (default), ``"fp32"``, or opt-in ``"fp8"``.
|
| 706 |
-
"""
|
| 707 |
-
dtype_map = {
|
| 708 |
-
"bf16": torch.bfloat16,
|
| 709 |
-
"fp32": torch.float32,
|
| 710 |
-
"fp8": torch.bfloat16,
|
| 711 |
-
}
|
| 712 |
-
if precision not in dtype_map:
|
| 713 |
-
raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}")
|
| 714 |
-
if precision == "fp8" and not TE_AVAILABLE:
|
| 715 |
-
raise RuntimeError(
|
| 716 |
-
"esmc_precision='fp8' requires transformer_engine.pytorch."
|
| 717 |
-
)
|
| 718 |
-
dtype = dtype_map[precision]
|
| 719 |
-
|
| 720 |
-
esmc = _load_fastplms_esmplusplus_for_esmfold2(
|
| 721 |
-
esmc_model_path=esmc_model_path,
|
| 722 |
-
attn_backend=self.config.esmc_attn_backend,
|
| 723 |
device=self.device,
|
| 724 |
dtype=dtype,
|
| 725 |
)
|
|
@@ -730,24 +730,24 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 730 |
assert esmc.config.num_hidden_layers == self.config.lm_num_layers, (
|
| 731 |
f"ESMFold2 expected lm_num_layers={self.config.lm_num_layers}, "
|
| 732 |
f"but loaded ESM++ num_hidden_layers={esmc.config.num_hidden_layers}."
|
| 733 |
-
)
|
| 734 |
-
for p in esmc.parameters():
|
| 735 |
-
p.requires_grad_(False)
|
| 736 |
-
|
| 737 |
-
if precision == "fp8":
|
| 738 |
-
with torch.no_grad():
|
| 739 |
-
_convert_te_modules_to_fp8_inplace(esmc)
|
| 740 |
-
|
| 741 |
-
self._esmc_fp8 = precision == "fp8"
|
| 742 |
-
self._esmc = esmc
|
| 743 |
-
self._ttt_lm_head = None
|
| 744 |
-
|
| 745 |
-
def _ensure_ttt_lm_head(self) -> None:
|
| 746 |
-
assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
|
| 747 |
-
if self._esmc_fp8:
|
| 748 |
-
raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
|
| 749 |
-
if self._ttt_lm_head is not None:
|
| 750 |
-
return
|
| 751 |
try:
|
| 752 |
from fastplms.esm_plusplus.modeling_esm_plusplus import (
|
| 753 |
ESMplusplusConfig,
|
|
@@ -781,11 +781,11 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 781 |
self._ttt_lm_head.requires_grad_(False)
|
| 782 |
del mlm
|
| 783 |
|
| 784 |
-
def _ttt_get_trainable_modules(self) -> list[nn.Module]:
|
| 785 |
-
assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
|
| 786 |
-
if self._esmc_fp8:
|
| 787 |
-
raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
|
| 788 |
-
return [self._esmc]
|
| 789 |
|
| 790 |
def _ttt_tokenize(
|
| 791 |
self,
|
|
@@ -846,13 +846,13 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 846 |
**kwargs,
|
| 847 |
) -> torch.Tensor:
|
| 848 |
del kwargs
|
| 849 |
-
assert isinstance(batch, torch.Tensor), (
|
| 850 |
-
"ESMFold2 TTT expects input_ids tensors."
|
| 851 |
-
)
|
| 852 |
-
assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
|
| 853 |
-
if self._esmc_fp8:
|
| 854 |
-
raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
|
| 855 |
-
self._ensure_ttt_lm_head()
|
| 856 |
assert self._ttt_lm_head is not None
|
| 857 |
attention_mask = batch.ne(SEQUENCE_PAD_TOKEN)
|
| 858 |
output = self._esmc(
|
|
@@ -947,30 +947,30 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 947 |
if self.msa_encoder is not None:
|
| 948 |
self.msa_encoder.set_chunk_size(chunk_size)
|
| 949 |
|
| 950 |
-
def _compute_lm_hidden_states(
|
| 951 |
-
self,
|
| 952 |
-
input_ids: Tensor,
|
| 953 |
-
asym_id: Tensor,
|
| 954 |
-
residue_index: Tensor,
|
| 955 |
-
mol_type: Tensor,
|
| 956 |
-
tok_mask: Tensor,
|
| 957 |
-
lm_mask_pct: float = 0.0,
|
| 958 |
-
) -> Tensor:
|
| 959 |
-
assert self._esmc is not None
|
| 960 |
-
# fp8 TE kernels require prod(shape[:-1]) % 8 == 0.
|
| 961 |
-
pad_to = 8 if self._esmc_fp8 else None
|
| 962 |
-
with _lm_precision_context(self._esmc_fp8):
|
| 963 |
return compute_lm_hidden_states(
|
| 964 |
self._esmc,
|
| 965 |
input_ids,
|
| 966 |
asym_id,
|
| 967 |
residue_index,
|
| 968 |
-
mol_type,
|
| 969 |
-
tok_mask,
|
| 970 |
-
pad_to_multiple=pad_to,
|
| 971 |
-
lm_mask_pct=lm_mask_pct,
|
| 972 |
-
mask_token_id=SEQUENCE_MASK_TOKEN,
|
| 973 |
-
)
|
| 974 |
|
| 975 |
def _discretized_dynamics(self) -> tuple[Tensor, Tensor]:
|
| 976 |
delta = F.softplus(self.parcae_log_delta)
|
|
@@ -985,17 +985,17 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 985 |
return state.to(dtype=ref.dtype)
|
| 986 |
|
| 987 |
def _run_one_loop(
|
| 988 |
-
self,
|
| 989 |
-
z: Tensor,
|
| 990 |
-
z_init: Tensor,
|
| 991 |
-
lm_z: Tensor | None,
|
| 992 |
-
_msa_inputs: dict | None,
|
| 993 |
-
pair_mask: Tensor,
|
| 994 |
-
a: Tensor,
|
| 995 |
-
b_mat: Tensor,
|
| 996 |
-
tok_mask: Tensor,
|
| 997 |
-
total_steps: int,
|
| 998 |
-
) -> Tensor:
|
| 999 |
# Helper method (not inline) so per-iter locals free on return —
|
| 1000 |
# otherwise leaks ~2 GB L²×c_z into distogram/sample scope.
|
| 1001 |
# training=True forces dropout under eval(), matching the per-loop
|
|
@@ -1025,49 +1025,49 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 1025 |
if lm_z_i is not None and self.lm_encoder is None:
|
| 1026 |
z_inject_pair = z_inject_pair + lm_z_i.to(z_inject_pair.dtype)
|
| 1027 |
|
| 1028 |
-
if self.msa_encoder is not None and _msa_inputs is not None:
|
| 1029 |
-
msa_i, mask_i, hd_i, dv_i = maybe_subsample_msa(
|
| 1030 |
-
_msa_inputs["msa"],
|
| 1031 |
-
_msa_inputs["msa_attention_mask"],
|
| 1032 |
-
_msa_inputs["has_deletion"],
|
| 1033 |
-
_msa_inputs["deletion_value"],
|
| 1034 |
-
max_depth=_msa_inputs["max_depth"],
|
| 1035 |
-
enabled=_msa_inputs["subsample_enabled"],
|
| 1036 |
-
)
|
| 1037 |
-
B_msa, M, L_msa = msa_i.shape
|
| 1038 |
-
msa_oh = F.one_hot(
|
| 1039 |
-
msa_i.permute(0, 2, 1).long(), num_classes=NUM_RES_TYPES
|
| 1040 |
-
).float()
|
| 1041 |
-
msa_attn = (
|
| 1042 |
-
mask_i.permute(0, 2, 1).float()
|
| 1043 |
-
if mask_i is not None
|
| 1044 |
-
else tok_mask[:, :, None].expand(-1, -1, M).float()
|
| 1045 |
-
)
|
| 1046 |
-
# Bias-free MSAEncoder.embed requires zeroed padding.
|
| 1047 |
-
msa_oh = msa_oh * msa_attn.unsqueeze(-1)
|
| 1048 |
-
hd = (
|
| 1049 |
-
hd_i.permute(0, 2, 1).float()
|
| 1050 |
-
if hd_i is not None
|
| 1051 |
-
else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
|
| 1052 |
-
)
|
| 1053 |
-
dv = (
|
| 1054 |
-
dv_i.permute(0, 2, 1).float()
|
| 1055 |
-
if dv_i is not None
|
| 1056 |
-
else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
|
| 1057 |
-
)
|
| 1058 |
-
msa_pair = self.msa_encoder(
|
| 1059 |
-
x_pair=z_inject_pair,
|
| 1060 |
-
x_inputs=_msa_inputs["x_inputs"],
|
| 1061 |
-
msa_oh=msa_oh,
|
| 1062 |
-
has_deletion=hd,
|
| 1063 |
-
deletion_value=dv,
|
| 1064 |
-
msa_attention_mask=msa_attn,
|
| 1065 |
-
).to(z_inject_pair.dtype)
|
| 1066 |
-
z_inject_pair = (
|
| 1067 |
-
msa_pair
|
| 1068 |
-
if self.config.msa_encoder_overwrite
|
| 1069 |
-
else (z_inject_pair + msa_pair)
|
| 1070 |
-
)
|
| 1071 |
|
| 1072 |
if refined_lm_z is not None:
|
| 1073 |
z_inject_pair = z_inject_pair + refined_lm_z.to(z_inject_pair.dtype)
|
|
@@ -1104,16 +1104,16 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 1104 |
deletion_value: Tensor | None = None,
|
| 1105 |
msa_attention_mask: Tensor | None = None,
|
| 1106 |
input_ids: Tensor | None = None,
|
| 1107 |
-
lm_hidden_states: Tensor | None = None,
|
| 1108 |
-
num_loops: int | None = None,
|
| 1109 |
-
num_diffusion_samples: int | None = None,
|
| 1110 |
-
num_sampling_steps: int | None = None,
|
| 1111 |
-
lm_mask_pct: float | None = None,
|
| 1112 |
-
msa_max_depth: int = 1024,
|
| 1113 |
-
msa_column_mask_rate: float = 0.1,
|
| 1114 |
-
msa_subsample_at_inference: bool = True,
|
| 1115 |
-
**kwargs,
|
| 1116 |
-
) -> dict[str, Tensor]:
|
| 1117 |
tok_mask = token_attention_mask
|
| 1118 |
atm_mask = atom_attention_mask
|
| 1119 |
disto_idx = distogram_atom_idx
|
|
@@ -1196,19 +1196,19 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 1196 |
lm_hidden_states is None
|
| 1197 |
and input_ids is not None
|
| 1198 |
and self._esmc is not None
|
| 1199 |
-
):
|
| 1200 |
-
lm_hidden_states = self._compute_lm_hidden_states(
|
| 1201 |
-
input_ids,
|
| 1202 |
-
asym_id,
|
| 1203 |
-
residue_index,
|
| 1204 |
-
mol_type,
|
| 1205 |
-
tok_mask,
|
| 1206 |
-
lm_mask_pct=(
|
| 1207 |
-
self.config.lm_mask_pct
|
| 1208 |
-
if lm_mask_pct is None
|
| 1209 |
-
else lm_mask_pct
|
| 1210 |
-
),
|
| 1211 |
-
)
|
| 1212 |
lm_z: Tensor | None = None
|
| 1213 |
if lm_hidden_states is not None:
|
| 1214 |
lm_z = self.language_model(lm_hidden_states.detach())
|
|
@@ -1222,35 +1222,35 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 1222 |
a = a.view(1, 1, 1, -1).to(device=z.device, dtype=z.dtype)
|
| 1223 |
b_mat = b.to(device=z.device, dtype=z.dtype)
|
| 1224 |
|
| 1225 |
-
_msa_inputs: dict | None = None
|
| 1226 |
-
if self.msa_encoder is not None and msa is not None:
|
| 1227 |
-
msa_attention_mask = maybe_apply_msa_column_masking(
|
| 1228 |
-
msa_attention_mask,
|
| 1229 |
-
msa_column_mask_rate,
|
| 1230 |
-
)
|
| 1231 |
-
_msa_inputs = dict(
|
| 1232 |
-
x_inputs=x_inputs,
|
| 1233 |
-
msa=msa,
|
| 1234 |
-
msa_attention_mask=msa_attention_mask,
|
| 1235 |
-
has_deletion=has_deletion,
|
| 1236 |
-
deletion_value=deletion_value,
|
| 1237 |
-
max_depth=msa_max_depth,
|
| 1238 |
-
subsample_enabled=msa_subsample_at_inference,
|
| 1239 |
-
)
|
| 1240 |
|
| 1241 |
# Method call (not inline loop) frees per-iter L²×c_z locals.
|
| 1242 |
z = self._run_one_loop(
|
| 1243 |
-
z=z,
|
| 1244 |
-
z_init=z_init,
|
| 1245 |
-
lm_z=lm_z,
|
| 1246 |
-
_msa_inputs=_msa_inputs,
|
| 1247 |
-
pair_mask=pair_mask,
|
| 1248 |
-
a=a,
|
| 1249 |
-
b_mat=b_mat,
|
| 1250 |
-
tok_mask=tok_mask,
|
| 1251 |
-
total_steps=total_steps,
|
| 1252 |
-
)
|
| 1253 |
-
del z_init, lm_z, _msa_inputs, a, b_mat
|
| 1254 |
|
| 1255 |
z = self.parcae_readout(z)
|
| 1256 |
z = self.parcae_coda(z, pair_attention_mask=pair_mask)
|
|
@@ -1362,38 +1362,38 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 1362 |
complex_id=complex_id,
|
| 1363 |
)
|
| 1364 |
|
| 1365 |
-
def _fold_protein_no_ttt(
|
| 1366 |
-
self,
|
| 1367 |
-
sequence: str,
|
| 1368 |
-
*,
|
| 1369 |
-
chain_id: str = "A",
|
| 1370 |
-
msa: Any | None = None,
|
| 1371 |
-
msa_path: str | Path | None = None,
|
| 1372 |
-
msa_max_sequences: int | None = None,
|
| 1373 |
-
num_loops: int = 3,
|
| 1374 |
-
num_sampling_steps: int = 50,
|
| 1375 |
-
num_diffusion_samples: int = 1,
|
| 1376 |
-
seed: int | None = None,
|
| 1377 |
-
complex_id: str = "pred",
|
| 1378 |
-
):
|
| 1379 |
-
from .esmfold2_types import MSA, ProteinInput, StructurePredictionInput
|
| 1380 |
-
|
| 1381 |
-
assert not (
|
| 1382 |
-
msa is not None and msa_path is not None
|
| 1383 |
-
), "Pass at most one of msa or msa_path."
|
| 1384 |
-
if msa_path is not None:
|
| 1385 |
-
msa = MSA.from_a3m(msa_path, max_sequences=msa_max_sequences)
|
| 1386 |
-
if msa is not None:
|
| 1387 |
-
query = str(msa.query).replace("-", "").upper()
|
| 1388 |
-
assert query == sequence.upper(), (
|
| 1389 |
-
f"MSA query does not match sequence: expected {sequence.upper()!r}, got {query!r}"
|
| 1390 |
-
)
|
| 1391 |
-
|
| 1392 |
-
input = StructurePredictionInput(
|
| 1393 |
-
sequences=[ProteinInput(id=chain_id, sequence=sequence, msa=msa)]
|
| 1394 |
-
)
|
| 1395 |
-
return self.fold(
|
| 1396 |
-
input,
|
| 1397 |
num_loops=num_loops,
|
| 1398 |
num_sampling_steps=num_sampling_steps,
|
| 1399 |
num_diffusion_samples=num_diffusion_samples,
|
|
@@ -1442,15 +1442,15 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 1442 |
|
| 1443 |
def fold_protein(
|
| 1444 |
self,
|
| 1445 |
-
sequence: str,
|
| 1446 |
-
*,
|
| 1447 |
-
chain_id: str = "A",
|
| 1448 |
-
msa: Any | None = None,
|
| 1449 |
-
msa_path: str | Path | None = None,
|
| 1450 |
-
msa_max_sequences: int | None = None,
|
| 1451 |
-
num_loops: int = 3,
|
| 1452 |
-
num_sampling_steps: int = 50,
|
| 1453 |
-
num_diffusion_samples: int = 1,
|
| 1454 |
seed: int | None = None,
|
| 1455 |
complex_id: str = "pred",
|
| 1456 |
ttt: bool = False,
|
|
@@ -1458,57 +1458,57 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
|
|
| 1458 |
):
|
| 1459 |
if ttt:
|
| 1460 |
return self.fold_protein_ttt(
|
| 1461 |
-
sequence=sequence,
|
| 1462 |
-
chain_id=chain_id,
|
| 1463 |
-
msa=msa,
|
| 1464 |
-
msa_path=msa_path,
|
| 1465 |
-
msa_max_sequences=msa_max_sequences,
|
| 1466 |
-
num_loops=num_loops,
|
| 1467 |
-
num_sampling_steps=num_sampling_steps,
|
| 1468 |
-
num_diffusion_samples=num_diffusion_samples,
|
| 1469 |
seed=seed,
|
| 1470 |
complex_id=complex_id,
|
| 1471 |
ttt_config=ttt_config,
|
| 1472 |
)
|
| 1473 |
return self._fold_protein_no_ttt(
|
| 1474 |
-
sequence=sequence,
|
| 1475 |
-
chain_id=chain_id,
|
| 1476 |
-
msa=msa,
|
| 1477 |
-
msa_path=msa_path,
|
| 1478 |
-
msa_max_sequences=msa_max_sequences,
|
| 1479 |
-
num_loops=num_loops,
|
| 1480 |
-
num_sampling_steps=num_sampling_steps,
|
| 1481 |
-
num_diffusion_samples=num_diffusion_samples,
|
| 1482 |
seed=seed,
|
| 1483 |
complex_id=complex_id,
|
| 1484 |
)
|
| 1485 |
|
| 1486 |
def fold_protein_ttt(
|
| 1487 |
self,
|
| 1488 |
-
sequence: str,
|
| 1489 |
-
*,
|
| 1490 |
-
chain_id: str = "A",
|
| 1491 |
-
msa: Any | None = None,
|
| 1492 |
-
msa_path: str | Path | None = None,
|
| 1493 |
-
msa_max_sequences: int | None = None,
|
| 1494 |
-
num_loops: int = 3,
|
| 1495 |
-
num_sampling_steps: int = 50,
|
| 1496 |
-
num_diffusion_samples: int = 1,
|
| 1497 |
seed: int | None = None,
|
| 1498 |
complex_id: str = "pred",
|
| 1499 |
ttt_config: TTTConfig | dict[str, Any] | None = None,
|
| 1500 |
-
):
|
| 1501 |
-
assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
|
| 1502 |
-
if self._esmc_fp8:
|
| 1503 |
-
raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
|
| 1504 |
-
fold_kwargs = {
|
| 1505 |
-
"chain_id": chain_id,
|
| 1506 |
-
"msa": msa,
|
| 1507 |
-
"msa_path": msa_path,
|
| 1508 |
-
"msa_max_sequences": msa_max_sequences,
|
| 1509 |
-
"num_loops": num_loops,
|
| 1510 |
-
"num_sampling_steps": num_sampling_steps,
|
| 1511 |
-
"num_diffusion_samples": num_diffusion_samples,
|
| 1512 |
"seed": seed,
|
| 1513 |
"complex_id": complex_id,
|
| 1514 |
}
|
|
|
|
| 59 |
TriangleMultiplicativeUpdate,
|
| 60 |
_categorical_mean,
|
| 61 |
_compute_intra_token_idx,
|
| 62 |
+
compute_lm_hidden_states,
|
| 63 |
+
gather_rep_atom_coords,
|
| 64 |
+
gather_token_to_atom,
|
| 65 |
+
maybe_apply_msa_column_masking,
|
| 66 |
+
maybe_subsample_msa,
|
| 67 |
+
)
|
| 68 |
from .esmfold2_affine3d import Affine3D as _FastPLMSESMFold2Affine3D
|
| 69 |
from .esmfold2_aligner import Aligner as _FastPLMSESMFold2Aligner
|
| 70 |
from .esmfold2_atom_indexer import AtomIndexer as _FastPLMSESMFold2AtomIndexer
|
|
|
|
| 699 |
self.post_init()
|
| 700 |
self.init_ttt({"lora_target_replace_module": "MultiHeadAttention"})
|
| 701 |
|
| 702 |
+
def load_esmc(self, esmc_model_path: str, precision: str = "bf16") -> None:
|
| 703 |
+
"""Load the FastPLMs ESM++ LM used as the ESMFold2 PLM backbone.
|
| 704 |
+
|
| 705 |
+
``precision``: ``"bf16"`` (default), ``"fp32"``, or opt-in ``"fp8"``.
|
| 706 |
+
"""
|
| 707 |
+
dtype_map = {
|
| 708 |
+
"bf16": torch.bfloat16,
|
| 709 |
+
"fp32": torch.float32,
|
| 710 |
+
"fp8": torch.bfloat16,
|
| 711 |
+
}
|
| 712 |
+
if precision not in dtype_map:
|
| 713 |
+
raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}")
|
| 714 |
+
if precision == "fp8" and not TE_AVAILABLE:
|
| 715 |
+
raise RuntimeError(
|
| 716 |
+
"esmc_precision='fp8' requires transformer_engine.pytorch."
|
| 717 |
+
)
|
| 718 |
+
dtype = dtype_map[precision]
|
| 719 |
+
|
| 720 |
+
esmc = _load_fastplms_esmplusplus_for_esmfold2(
|
| 721 |
+
esmc_model_path=esmc_model_path,
|
| 722 |
+
attn_backend=self.config.esmc_attn_backend,
|
| 723 |
device=self.device,
|
| 724 |
dtype=dtype,
|
| 725 |
)
|
|
|
|
| 730 |
assert esmc.config.num_hidden_layers == self.config.lm_num_layers, (
|
| 731 |
f"ESMFold2 expected lm_num_layers={self.config.lm_num_layers}, "
|
| 732 |
f"but loaded ESM++ num_hidden_layers={esmc.config.num_hidden_layers}."
|
| 733 |
+
)
|
| 734 |
+
for p in esmc.parameters():
|
| 735 |
+
p.requires_grad_(False)
|
| 736 |
+
|
| 737 |
+
if precision == "fp8":
|
| 738 |
+
with torch.no_grad():
|
| 739 |
+
_convert_te_modules_to_fp8_inplace(esmc)
|
| 740 |
+
|
| 741 |
+
self._esmc_fp8 = precision == "fp8"
|
| 742 |
+
self._esmc = esmc
|
| 743 |
+
self._ttt_lm_head = None
|
| 744 |
+
|
| 745 |
+
def _ensure_ttt_lm_head(self) -> None:
|
| 746 |
+
assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
|
| 747 |
+
if self._esmc_fp8:
|
| 748 |
+
raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
|
| 749 |
+
if self._ttt_lm_head is not None:
|
| 750 |
+
return
|
| 751 |
try:
|
| 752 |
from fastplms.esm_plusplus.modeling_esm_plusplus import (
|
| 753 |
ESMplusplusConfig,
|
|
|
|
| 781 |
self._ttt_lm_head.requires_grad_(False)
|
| 782 |
del mlm
|
| 783 |
|
| 784 |
+
def _ttt_get_trainable_modules(self) -> list[nn.Module]:
|
| 785 |
+
assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
|
| 786 |
+
if self._esmc_fp8:
|
| 787 |
+
raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
|
| 788 |
+
return [self._esmc]
|
| 789 |
|
| 790 |
def _ttt_tokenize(
|
| 791 |
self,
|
|
|
|
| 846 |
**kwargs,
|
| 847 |
) -> torch.Tensor:
|
| 848 |
del kwargs
|
| 849 |
+
assert isinstance(batch, torch.Tensor), (
|
| 850 |
+
"ESMFold2 TTT expects input_ids tensors."
|
| 851 |
+
)
|
| 852 |
+
assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
|
| 853 |
+
if self._esmc_fp8:
|
| 854 |
+
raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
|
| 855 |
+
self._ensure_ttt_lm_head()
|
| 856 |
assert self._ttt_lm_head is not None
|
| 857 |
attention_mask = batch.ne(SEQUENCE_PAD_TOKEN)
|
| 858 |
output = self._esmc(
|
|
|
|
| 947 |
if self.msa_encoder is not None:
|
| 948 |
self.msa_encoder.set_chunk_size(chunk_size)
|
| 949 |
|
| 950 |
+
def _compute_lm_hidden_states(
|
| 951 |
+
self,
|
| 952 |
+
input_ids: Tensor,
|
| 953 |
+
asym_id: Tensor,
|
| 954 |
+
residue_index: Tensor,
|
| 955 |
+
mol_type: Tensor,
|
| 956 |
+
tok_mask: Tensor,
|
| 957 |
+
lm_mask_pct: float = 0.0,
|
| 958 |
+
) -> Tensor:
|
| 959 |
+
assert self._esmc is not None
|
| 960 |
+
# fp8 TE kernels require prod(shape[:-1]) % 8 == 0.
|
| 961 |
+
pad_to = 8 if self._esmc_fp8 else None
|
| 962 |
+
with _lm_precision_context(self._esmc_fp8):
|
| 963 |
return compute_lm_hidden_states(
|
| 964 |
self._esmc,
|
| 965 |
input_ids,
|
| 966 |
asym_id,
|
| 967 |
residue_index,
|
| 968 |
+
mol_type,
|
| 969 |
+
tok_mask,
|
| 970 |
+
pad_to_multiple=pad_to,
|
| 971 |
+
lm_mask_pct=lm_mask_pct,
|
| 972 |
+
mask_token_id=SEQUENCE_MASK_TOKEN,
|
| 973 |
+
)
|
| 974 |
|
| 975 |
def _discretized_dynamics(self) -> tuple[Tensor, Tensor]:
|
| 976 |
delta = F.softplus(self.parcae_log_delta)
|
|
|
|
| 985 |
return state.to(dtype=ref.dtype)
|
| 986 |
|
| 987 |
def _run_one_loop(
|
| 988 |
+
self,
|
| 989 |
+
z: Tensor,
|
| 990 |
+
z_init: Tensor,
|
| 991 |
+
lm_z: Tensor | None,
|
| 992 |
+
_msa_inputs: dict | None,
|
| 993 |
+
pair_mask: Tensor,
|
| 994 |
+
a: Tensor,
|
| 995 |
+
b_mat: Tensor,
|
| 996 |
+
tok_mask: Tensor,
|
| 997 |
+
total_steps: int,
|
| 998 |
+
) -> Tensor:
|
| 999 |
# Helper method (not inline) so per-iter locals free on return —
|
| 1000 |
# otherwise leaks ~2 GB L²×c_z into distogram/sample scope.
|
| 1001 |
# training=True forces dropout under eval(), matching the per-loop
|
|
|
|
| 1025 |
if lm_z_i is not None and self.lm_encoder is None:
|
| 1026 |
z_inject_pair = z_inject_pair + lm_z_i.to(z_inject_pair.dtype)
|
| 1027 |
|
| 1028 |
+
if self.msa_encoder is not None and _msa_inputs is not None:
|
| 1029 |
+
msa_i, mask_i, hd_i, dv_i = maybe_subsample_msa(
|
| 1030 |
+
_msa_inputs["msa"],
|
| 1031 |
+
_msa_inputs["msa_attention_mask"],
|
| 1032 |
+
_msa_inputs["has_deletion"],
|
| 1033 |
+
_msa_inputs["deletion_value"],
|
| 1034 |
+
max_depth=_msa_inputs["max_depth"],
|
| 1035 |
+
enabled=_msa_inputs["subsample_enabled"],
|
| 1036 |
+
)
|
| 1037 |
+
B_msa, M, L_msa = msa_i.shape
|
| 1038 |
+
msa_oh = F.one_hot(
|
| 1039 |
+
msa_i.permute(0, 2, 1).long(), num_classes=NUM_RES_TYPES
|
| 1040 |
+
).float()
|
| 1041 |
+
msa_attn = (
|
| 1042 |
+
mask_i.permute(0, 2, 1).float()
|
| 1043 |
+
if mask_i is not None
|
| 1044 |
+
else tok_mask[:, :, None].expand(-1, -1, M).float()
|
| 1045 |
+
)
|
| 1046 |
+
# Bias-free MSAEncoder.embed requires zeroed padding.
|
| 1047 |
+
msa_oh = msa_oh * msa_attn.unsqueeze(-1)
|
| 1048 |
+
hd = (
|
| 1049 |
+
hd_i.permute(0, 2, 1).float()
|
| 1050 |
+
if hd_i is not None
|
| 1051 |
+
else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
|
| 1052 |
+
)
|
| 1053 |
+
dv = (
|
| 1054 |
+
dv_i.permute(0, 2, 1).float()
|
| 1055 |
+
if dv_i is not None
|
| 1056 |
+
else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
|
| 1057 |
+
)
|
| 1058 |
+
msa_pair = self.msa_encoder(
|
| 1059 |
+
x_pair=z_inject_pair,
|
| 1060 |
+
x_inputs=_msa_inputs["x_inputs"],
|
| 1061 |
+
msa_oh=msa_oh,
|
| 1062 |
+
has_deletion=hd,
|
| 1063 |
+
deletion_value=dv,
|
| 1064 |
+
msa_attention_mask=msa_attn,
|
| 1065 |
+
).to(z_inject_pair.dtype)
|
| 1066 |
+
z_inject_pair = (
|
| 1067 |
+
msa_pair
|
| 1068 |
+
if self.config.msa_encoder_overwrite
|
| 1069 |
+
else (z_inject_pair + msa_pair)
|
| 1070 |
+
)
|
| 1071 |
|
| 1072 |
if refined_lm_z is not None:
|
| 1073 |
z_inject_pair = z_inject_pair + refined_lm_z.to(z_inject_pair.dtype)
|
|
|
|
| 1104 |
deletion_value: Tensor | None = None,
|
| 1105 |
msa_attention_mask: Tensor | None = None,
|
| 1106 |
input_ids: Tensor | None = None,
|
| 1107 |
+
lm_hidden_states: Tensor | None = None,
|
| 1108 |
+
num_loops: int | None = None,
|
| 1109 |
+
num_diffusion_samples: int | None = None,
|
| 1110 |
+
num_sampling_steps: int | None = None,
|
| 1111 |
+
lm_mask_pct: float | None = None,
|
| 1112 |
+
msa_max_depth: int = 1024,
|
| 1113 |
+
msa_column_mask_rate: float = 0.1,
|
| 1114 |
+
msa_subsample_at_inference: bool = True,
|
| 1115 |
+
**kwargs,
|
| 1116 |
+
) -> dict[str, Tensor]:
|
| 1117 |
tok_mask = token_attention_mask
|
| 1118 |
atm_mask = atom_attention_mask
|
| 1119 |
disto_idx = distogram_atom_idx
|
|
|
|
| 1196 |
lm_hidden_states is None
|
| 1197 |
and input_ids is not None
|
| 1198 |
and self._esmc is not None
|
| 1199 |
+
):
|
| 1200 |
+
lm_hidden_states = self._compute_lm_hidden_states(
|
| 1201 |
+
input_ids,
|
| 1202 |
+
asym_id,
|
| 1203 |
+
residue_index,
|
| 1204 |
+
mol_type,
|
| 1205 |
+
tok_mask,
|
| 1206 |
+
lm_mask_pct=(
|
| 1207 |
+
self.config.lm_mask_pct
|
| 1208 |
+
if lm_mask_pct is None
|
| 1209 |
+
else lm_mask_pct
|
| 1210 |
+
),
|
| 1211 |
+
)
|
| 1212 |
lm_z: Tensor | None = None
|
| 1213 |
if lm_hidden_states is not None:
|
| 1214 |
lm_z = self.language_model(lm_hidden_states.detach())
|
|
|
|
| 1222 |
a = a.view(1, 1, 1, -1).to(device=z.device, dtype=z.dtype)
|
| 1223 |
b_mat = b.to(device=z.device, dtype=z.dtype)
|
| 1224 |
|
| 1225 |
+
_msa_inputs: dict | None = None
|
| 1226 |
+
if self.msa_encoder is not None and msa is not None:
|
| 1227 |
+
msa_attention_mask = maybe_apply_msa_column_masking(
|
| 1228 |
+
msa_attention_mask,
|
| 1229 |
+
msa_column_mask_rate,
|
| 1230 |
+
)
|
| 1231 |
+
_msa_inputs = dict(
|
| 1232 |
+
x_inputs=x_inputs,
|
| 1233 |
+
msa=msa,
|
| 1234 |
+
msa_attention_mask=msa_attention_mask,
|
| 1235 |
+
has_deletion=has_deletion,
|
| 1236 |
+
deletion_value=deletion_value,
|
| 1237 |
+
max_depth=msa_max_depth,
|
| 1238 |
+
subsample_enabled=msa_subsample_at_inference,
|
| 1239 |
+
)
|
| 1240 |
|
| 1241 |
# Method call (not inline loop) frees per-iter L²×c_z locals.
|
| 1242 |
z = self._run_one_loop(
|
| 1243 |
+
z=z,
|
| 1244 |
+
z_init=z_init,
|
| 1245 |
+
lm_z=lm_z,
|
| 1246 |
+
_msa_inputs=_msa_inputs,
|
| 1247 |
+
pair_mask=pair_mask,
|
| 1248 |
+
a=a,
|
| 1249 |
+
b_mat=b_mat,
|
| 1250 |
+
tok_mask=tok_mask,
|
| 1251 |
+
total_steps=total_steps,
|
| 1252 |
+
)
|
| 1253 |
+
del z_init, lm_z, _msa_inputs, a, b_mat
|
| 1254 |
|
| 1255 |
z = self.parcae_readout(z)
|
| 1256 |
z = self.parcae_coda(z, pair_attention_mask=pair_mask)
|
|
|
|
| 1362 |
complex_id=complex_id,
|
| 1363 |
)
|
| 1364 |
|
| 1365 |
+
def _fold_protein_no_ttt(
|
| 1366 |
+
self,
|
| 1367 |
+
sequence: str,
|
| 1368 |
+
*,
|
| 1369 |
+
chain_id: str = "A",
|
| 1370 |
+
msa: Any | None = None,
|
| 1371 |
+
msa_path: str | Path | None = None,
|
| 1372 |
+
msa_max_sequences: int | None = None,
|
| 1373 |
+
num_loops: int = 3,
|
| 1374 |
+
num_sampling_steps: int = 50,
|
| 1375 |
+
num_diffusion_samples: int = 1,
|
| 1376 |
+
seed: int | None = None,
|
| 1377 |
+
complex_id: str = "pred",
|
| 1378 |
+
):
|
| 1379 |
+
from .esmfold2_types import MSA, ProteinInput, StructurePredictionInput
|
| 1380 |
+
|
| 1381 |
+
assert not (
|
| 1382 |
+
msa is not None and msa_path is not None
|
| 1383 |
+
), "Pass at most one of msa or msa_path."
|
| 1384 |
+
if msa_path is not None:
|
| 1385 |
+
msa = MSA.from_a3m(msa_path, max_sequences=msa_max_sequences)
|
| 1386 |
+
if msa is not None:
|
| 1387 |
+
query = str(msa.query).replace("-", "").upper()
|
| 1388 |
+
assert query == sequence.upper(), (
|
| 1389 |
+
f"MSA query does not match sequence: expected {sequence.upper()!r}, got {query!r}"
|
| 1390 |
+
)
|
| 1391 |
+
|
| 1392 |
+
input = StructurePredictionInput(
|
| 1393 |
+
sequences=[ProteinInput(id=chain_id, sequence=sequence, msa=msa)]
|
| 1394 |
+
)
|
| 1395 |
+
return self.fold(
|
| 1396 |
+
input,
|
| 1397 |
num_loops=num_loops,
|
| 1398 |
num_sampling_steps=num_sampling_steps,
|
| 1399 |
num_diffusion_samples=num_diffusion_samples,
|
|
|
|
| 1442 |
|
| 1443 |
def fold_protein(
|
| 1444 |
self,
|
| 1445 |
+
sequence: str,
|
| 1446 |
+
*,
|
| 1447 |
+
chain_id: str = "A",
|
| 1448 |
+
msa: Any | None = None,
|
| 1449 |
+
msa_path: str | Path | None = None,
|
| 1450 |
+
msa_max_sequences: int | None = None,
|
| 1451 |
+
num_loops: int = 3,
|
| 1452 |
+
num_sampling_steps: int = 50,
|
| 1453 |
+
num_diffusion_samples: int = 1,
|
| 1454 |
seed: int | None = None,
|
| 1455 |
complex_id: str = "pred",
|
| 1456 |
ttt: bool = False,
|
|
|
|
| 1458 |
):
|
| 1459 |
if ttt:
|
| 1460 |
return self.fold_protein_ttt(
|
| 1461 |
+
sequence=sequence,
|
| 1462 |
+
chain_id=chain_id,
|
| 1463 |
+
msa=msa,
|
| 1464 |
+
msa_path=msa_path,
|
| 1465 |
+
msa_max_sequences=msa_max_sequences,
|
| 1466 |
+
num_loops=num_loops,
|
| 1467 |
+
num_sampling_steps=num_sampling_steps,
|
| 1468 |
+
num_diffusion_samples=num_diffusion_samples,
|
| 1469 |
seed=seed,
|
| 1470 |
complex_id=complex_id,
|
| 1471 |
ttt_config=ttt_config,
|
| 1472 |
)
|
| 1473 |
return self._fold_protein_no_ttt(
|
| 1474 |
+
sequence=sequence,
|
| 1475 |
+
chain_id=chain_id,
|
| 1476 |
+
msa=msa,
|
| 1477 |
+
msa_path=msa_path,
|
| 1478 |
+
msa_max_sequences=msa_max_sequences,
|
| 1479 |
+
num_loops=num_loops,
|
| 1480 |
+
num_sampling_steps=num_sampling_steps,
|
| 1481 |
+
num_diffusion_samples=num_diffusion_samples,
|
| 1482 |
seed=seed,
|
| 1483 |
complex_id=complex_id,
|
| 1484 |
)
|
| 1485 |
|
| 1486 |
def fold_protein_ttt(
|
| 1487 |
self,
|
| 1488 |
+
sequence: str,
|
| 1489 |
+
*,
|
| 1490 |
+
chain_id: str = "A",
|
| 1491 |
+
msa: Any | None = None,
|
| 1492 |
+
msa_path: str | Path | None = None,
|
| 1493 |
+
msa_max_sequences: int | None = None,
|
| 1494 |
+
num_loops: int = 3,
|
| 1495 |
+
num_sampling_steps: int = 50,
|
| 1496 |
+
num_diffusion_samples: int = 1,
|
| 1497 |
seed: int | None = None,
|
| 1498 |
complex_id: str = "pred",
|
| 1499 |
ttt_config: TTTConfig | dict[str, Any] | None = None,
|
| 1500 |
+
):
|
| 1501 |
+
assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
|
| 1502 |
+
if self._esmc_fp8:
|
| 1503 |
+
raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
|
| 1504 |
+
fold_kwargs = {
|
| 1505 |
+
"chain_id": chain_id,
|
| 1506 |
+
"msa": msa,
|
| 1507 |
+
"msa_path": msa_path,
|
| 1508 |
+
"msa_max_sequences": msa_max_sequences,
|
| 1509 |
+
"num_loops": num_loops,
|
| 1510 |
+
"num_sampling_steps": num_sampling_steps,
|
| 1511 |
+
"num_diffusion_samples": num_diffusion_samples,
|
| 1512 |
"seed": seed,
|
| 1513 |
"complex_id": complex_id,
|
| 1514 |
}
|
modeling_esmfold2_common.py
CHANGED
|
@@ -140,61 +140,61 @@ _EPS = 1e-5
|
|
| 140 |
# chunk=64 leaves headroom for the largest foldbench targets). Override via
|
| 141 |
# ``model.set_chunk_size(...)``; pass None to disable chunking (faster for
|
| 142 |
# short L but OOM-prone past ~600).
|
| 143 |
-
_DEFAULT_CHUNK_SIZE = 64
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
# ===========================================================================
|
| 147 |
-
# MSA inference-time diversity augmentations
|
| 148 |
-
# ===========================================================================
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def maybe_subsample_msa(
|
| 152 |
-
msa: Tensor,
|
| 153 |
-
msa_attention_mask: Tensor | None,
|
| 154 |
-
has_deletion: Tensor | None,
|
| 155 |
-
deletion_value: Tensor | None,
|
| 156 |
-
*,
|
| 157 |
-
max_depth: int | None,
|
| 158 |
-
enabled: bool,
|
| 159 |
-
) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None]:
|
| 160 |
-
if not enabled or max_depth is None:
|
| 161 |
-
return msa, msa_attention_mask, has_deletion, deletion_value
|
| 162 |
-
|
| 163 |
-
depth = msa.size(1)
|
| 164 |
-
if depth <= 1 or depth <= max_depth:
|
| 165 |
-
return msa, msa_attention_mask, has_deletion, deletion_value
|
| 166 |
-
|
| 167 |
-
indices = torch.zeros(max_depth, dtype=torch.long, device=msa.device)
|
| 168 |
-
indices[1:] = torch.randperm(depth - 1, device=msa.device)[: max_depth - 1] + 1
|
| 169 |
-
indices = indices.sort().values
|
| 170 |
-
|
| 171 |
-
msa = msa[:, indices]
|
| 172 |
-
if msa_attention_mask is not None:
|
| 173 |
-
msa_attention_mask = msa_attention_mask[:, indices]
|
| 174 |
-
if has_deletion is not None:
|
| 175 |
-
has_deletion = has_deletion[:, indices]
|
| 176 |
-
if deletion_value is not None:
|
| 177 |
-
deletion_value = deletion_value[:, indices]
|
| 178 |
-
return msa, msa_attention_mask, has_deletion, deletion_value
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
def maybe_apply_msa_column_masking(
|
| 182 |
-
msa_attention_mask: Tensor | None,
|
| 183 |
-
rate: float,
|
| 184 |
-
) -> Tensor | None:
|
| 185 |
-
if msa_attention_mask is None or rate <= 0.0 or msa_attention_mask.size(1) <= 1:
|
| 186 |
-
return msa_attention_mask
|
| 187 |
-
|
| 188 |
-
batch_size, _, length = msa_attention_mask.shape
|
| 189 |
-
col_keep = torch.rand(batch_size, length, device=msa_attention_mask.device) >= rate
|
| 190 |
-
col_keep = col_keep.unsqueeze(1).expand_as(msa_attention_mask).clone()
|
| 191 |
-
col_keep[:, 0, :] = True
|
| 192 |
-
return msa_attention_mask.bool() & col_keep
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
# ===========================================================================
|
| 196 |
-
# Atom-token utilities
|
| 197 |
-
# ===========================================================================
|
| 198 |
|
| 199 |
|
| 200 |
def gather_token_to_atom(token_features: Tensor, atom_to_token_idx: Tensor) -> Tensor:
|
|
@@ -2182,17 +2182,17 @@ def _seed_context(seed: int | None, *, cuda: bool = True):
|
|
| 2182 |
# ===========================================================================
|
| 2183 |
|
| 2184 |
|
| 2185 |
-
def compute_lm_hidden_states(
|
| 2186 |
-
esmc: nn.Module,
|
| 2187 |
-
input_ids: Tensor,
|
| 2188 |
-
asym_id: Tensor,
|
| 2189 |
-
residue_index: Tensor,
|
| 2190 |
-
mol_type: Tensor,
|
| 2191 |
-
token_mask: Tensor,
|
| 2192 |
-
pad_to_multiple: int | None = None,
|
| 2193 |
-
lm_mask_pct: float = 0.0,
|
| 2194 |
-
mask_token_id: int = 32,
|
| 2195 |
-
) -> Tensor:
|
| 2196 |
"""Run ESMC with BOS/EOS wrapping, return hidden states [B, L, N, D] with N=81 layers.
|
| 2197 |
|
| 2198 |
Atom-tokenized modified residues (HYP, MSE, ACE, NH2, ...) span multiple
|
|
@@ -2277,21 +2277,21 @@ def compute_lm_hidden_states(
|
|
| 2277 |
for b in range(B):
|
| 2278 |
lm_input_ids[b, : lm_lengths[b]] = lm_input_list[b]
|
| 2279 |
|
| 2280 |
-
# sequence_id for chain-aware attention; PAD tokens get -1 (no attention).
|
| 2281 |
-
sequence_id = (lm_input_ids == 0).cumsum(dim=1) - 1 # BOS=0
|
| 2282 |
-
sequence_id = sequence_id.masked_fill(lm_input_ids == 1, -1) # PAD=1
|
| 2283 |
-
|
| 2284 |
-
if lm_mask_pct > 0.0:
|
| 2285 |
-
special = (lm_input_ids == 0) | (lm_input_ids == 1) | (lm_input_ids == 2)
|
| 2286 |
-
do_mask = (
|
| 2287 |
-
torch.rand(lm_input_ids.shape, device=device) < lm_mask_pct
|
| 2288 |
-
) & ~special
|
| 2289 |
-
lm_input_ids = lm_input_ids.masked_fill(do_mask, mask_token_id)
|
| 2290 |
-
|
| 2291 |
-
with torch.inference_mode():
|
| 2292 |
-
esmc_out = esmc(
|
| 2293 |
-
input_ids=lm_input_ids, sequence_id=sequence_id, output_hidden_states=True
|
| 2294 |
-
)
|
| 2295 |
|
| 2296 |
hs = esmc_out.hidden_states # [n_layers+1, B, max_len, D]
|
| 2297 |
n_layers_plus_1, _, _, D = hs.shape
|
|
|
|
| 140 |
# chunk=64 leaves headroom for the largest foldbench targets). Override via
|
| 141 |
# ``model.set_chunk_size(...)``; pass None to disable chunking (faster for
|
| 142 |
# short L but OOM-prone past ~600).
|
| 143 |
+
_DEFAULT_CHUNK_SIZE = 64
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ===========================================================================
|
| 147 |
+
# MSA inference-time diversity augmentations
|
| 148 |
+
# ===========================================================================
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def maybe_subsample_msa(
|
| 152 |
+
msa: Tensor,
|
| 153 |
+
msa_attention_mask: Tensor | None,
|
| 154 |
+
has_deletion: Tensor | None,
|
| 155 |
+
deletion_value: Tensor | None,
|
| 156 |
+
*,
|
| 157 |
+
max_depth: int | None,
|
| 158 |
+
enabled: bool,
|
| 159 |
+
) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None]:
|
| 160 |
+
if not enabled or max_depth is None:
|
| 161 |
+
return msa, msa_attention_mask, has_deletion, deletion_value
|
| 162 |
+
|
| 163 |
+
depth = msa.size(1)
|
| 164 |
+
if depth <= 1 or depth <= max_depth:
|
| 165 |
+
return msa, msa_attention_mask, has_deletion, deletion_value
|
| 166 |
+
|
| 167 |
+
indices = torch.zeros(max_depth, dtype=torch.long, device=msa.device)
|
| 168 |
+
indices[1:] = torch.randperm(depth - 1, device=msa.device)[: max_depth - 1] + 1
|
| 169 |
+
indices = indices.sort().values
|
| 170 |
+
|
| 171 |
+
msa = msa[:, indices]
|
| 172 |
+
if msa_attention_mask is not None:
|
| 173 |
+
msa_attention_mask = msa_attention_mask[:, indices]
|
| 174 |
+
if has_deletion is not None:
|
| 175 |
+
has_deletion = has_deletion[:, indices]
|
| 176 |
+
if deletion_value is not None:
|
| 177 |
+
deletion_value = deletion_value[:, indices]
|
| 178 |
+
return msa, msa_attention_mask, has_deletion, deletion_value
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def maybe_apply_msa_column_masking(
|
| 182 |
+
msa_attention_mask: Tensor | None,
|
| 183 |
+
rate: float,
|
| 184 |
+
) -> Tensor | None:
|
| 185 |
+
if msa_attention_mask is None or rate <= 0.0 or msa_attention_mask.size(1) <= 1:
|
| 186 |
+
return msa_attention_mask
|
| 187 |
+
|
| 188 |
+
batch_size, _, length = msa_attention_mask.shape
|
| 189 |
+
col_keep = torch.rand(batch_size, length, device=msa_attention_mask.device) >= rate
|
| 190 |
+
col_keep = col_keep.unsqueeze(1).expand_as(msa_attention_mask).clone()
|
| 191 |
+
col_keep[:, 0, :] = True
|
| 192 |
+
return msa_attention_mask.bool() & col_keep
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ===========================================================================
|
| 196 |
+
# Atom-token utilities
|
| 197 |
+
# ===========================================================================
|
| 198 |
|
| 199 |
|
| 200 |
def gather_token_to_atom(token_features: Tensor, atom_to_token_idx: Tensor) -> Tensor:
|
|
|
|
| 2182 |
# ===========================================================================
|
| 2183 |
|
| 2184 |
|
| 2185 |
+
def compute_lm_hidden_states(
|
| 2186 |
+
esmc: nn.Module,
|
| 2187 |
+
input_ids: Tensor,
|
| 2188 |
+
asym_id: Tensor,
|
| 2189 |
+
residue_index: Tensor,
|
| 2190 |
+
mol_type: Tensor,
|
| 2191 |
+
token_mask: Tensor,
|
| 2192 |
+
pad_to_multiple: int | None = None,
|
| 2193 |
+
lm_mask_pct: float = 0.0,
|
| 2194 |
+
mask_token_id: int = 32,
|
| 2195 |
+
) -> Tensor:
|
| 2196 |
"""Run ESMC with BOS/EOS wrapping, return hidden states [B, L, N, D] with N=81 layers.
|
| 2197 |
|
| 2198 |
Atom-tokenized modified residues (HYP, MSE, ACE, NH2, ...) span multiple
|
|
|
|
| 2277 |
for b in range(B):
|
| 2278 |
lm_input_ids[b, : lm_lengths[b]] = lm_input_list[b]
|
| 2279 |
|
| 2280 |
+
# sequence_id for chain-aware attention; PAD tokens get -1 (no attention).
|
| 2281 |
+
sequence_id = (lm_input_ids == 0).cumsum(dim=1) - 1 # BOS=0
|
| 2282 |
+
sequence_id = sequence_id.masked_fill(lm_input_ids == 1, -1) # PAD=1
|
| 2283 |
+
|
| 2284 |
+
if lm_mask_pct > 0.0:
|
| 2285 |
+
special = (lm_input_ids == 0) | (lm_input_ids == 1) | (lm_input_ids == 2)
|
| 2286 |
+
do_mask = (
|
| 2287 |
+
torch.rand(lm_input_ids.shape, device=device) < lm_mask_pct
|
| 2288 |
+
) & ~special
|
| 2289 |
+
lm_input_ids = lm_input_ids.masked_fill(do_mask, mask_token_id)
|
| 2290 |
+
|
| 2291 |
+
with torch.inference_mode():
|
| 2292 |
+
esmc_out = esmc(
|
| 2293 |
+
input_ids=lm_input_ids, sequence_id=sequence_id, output_hidden_states=True
|
| 2294 |
+
)
|
| 2295 |
|
| 2296 |
hs = esmc_out.hidden_states # [n_layers+1, B, max_len, D]
|
| 2297 |
n_layers_plus_1, _, _, D = hs.shape
|
modeling_esmfold2_experimental.py
CHANGED
|
@@ -521,14 +521,14 @@ class ESMFold2ExperimentalModel(PreTrainedModel):
|
|
| 521 |
"bf16": torch.bfloat16,
|
| 522 |
"fp32": torch.float32,
|
| 523 |
}
|
| 524 |
-
if precision not in dtype_map:
|
| 525 |
-
if precision == "fp8":
|
| 526 |
-
raise RuntimeError(
|
| 527 |
-
"esmc_precision='fp8' is supported only by the standard "
|
| 528 |
-
"released ESMFold2 model. The experimental binder-design "
|
| 529 |
-
"model keeps the FastPLMs ESM++ backbone in bf16 or fp32."
|
| 530 |
-
)
|
| 531 |
-
raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}")
|
| 532 |
esmc = _load_fastplms_esmplusplus_for_esmfold2(
|
| 533 |
esmc_model_path=esmc_model_path,
|
| 534 |
attn_backend=self.config.esmc_attn_backend,
|
|
@@ -852,13 +852,29 @@ class ESMFold2ExperimentalModel(PreTrainedModel):
|
|
| 852 |
return_atom_repr=False,
|
| 853 |
denoising_early_exit_rmsd=(0.10 if early_exit else None),
|
| 854 |
)
|
| 855 |
-
sample_coords = structure_output["sample_atom_coords"]
|
| 856 |
-
assert sample_coords is not None
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 862 |
if calculate_confidence and self.confidence_head is not None:
|
| 863 |
confidence_output = self.confidence_head(
|
| 864 |
s_inputs=x_inputs.detach(),
|
|
|
|
| 521 |
"bf16": torch.bfloat16,
|
| 522 |
"fp32": torch.float32,
|
| 523 |
}
|
| 524 |
+
if precision not in dtype_map:
|
| 525 |
+
if precision == "fp8":
|
| 526 |
+
raise RuntimeError(
|
| 527 |
+
"esmc_precision='fp8' is supported only by the standard "
|
| 528 |
+
"released ESMFold2 model. The experimental binder-design "
|
| 529 |
+
"model keeps the FastPLMs ESM++ backbone in bf16 or fp32."
|
| 530 |
+
)
|
| 531 |
+
raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}")
|
| 532 |
esmc = _load_fastplms_esmplusplus_for_esmfold2(
|
| 533 |
esmc_model_path=esmc_model_path,
|
| 534 |
attn_backend=self.config.esmc_attn_backend,
|
|
|
|
| 852 |
return_atom_repr=False,
|
| 853 |
denoising_early_exit_rmsd=(0.10 if early_exit else None),
|
| 854 |
)
|
| 855 |
+
sample_coords = structure_output["sample_atom_coords"]
|
| 856 |
+
assert sample_coords is not None
|
| 857 |
+
if sample_coords.ndim == 4:
|
| 858 |
+
batch, sample_count, atom_count, coord_dim = sample_coords.shape
|
| 859 |
+
sample_coords_for_gather = sample_coords.reshape(
|
| 860 |
+
batch * sample_count,
|
| 861 |
+
atom_count,
|
| 862 |
+
coord_dim,
|
| 863 |
+
)
|
| 864 |
+
rep_idx = distogram_atom_idx.repeat_interleave(sample_count, 0).long()
|
| 865 |
+
else:
|
| 866 |
+
sample_coords_for_gather = sample_coords
|
| 867 |
+
rep_idx = distogram_atom_idx.long()
|
| 868 |
+
representative_atom_coords = gather_rep_atom_coords(
|
| 869 |
+
sample_coords_for_gather,
|
| 870 |
+
rep_idx,
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
output: dict[str, Tensor] = {
|
| 874 |
+
"distogram_logits": distogram_logits,
|
| 875 |
+
"sample_atom_coords": sample_coords,
|
| 876 |
+
"representative_atom_coords": representative_atom_coords,
|
| 877 |
+
}
|
| 878 |
if calculate_confidence and self.confidence_head is not None:
|
| 879 |
confidence_output = self.confidence_head(
|
| 880 |
s_inputs=x_inputs.detach(),
|