lhallee commited on
Commit
407875b
·
verified ·
1 Parent(s): 410f70b

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -221,27 +221,27 @@ with torch.inference_mode():
221
  decoded = model.input_builder.decode(output, features, chain_infos)
222
  ```
223
 
224
- Set `load_esmc=False` when loading if you want to provide precomputed `lm_hidden_states` manually or run folding-trunk tests without loading the 6B ESM++ backbone:
225
-
226
- ```python
227
- model = AutoModel.from_pretrained(
228
- "Synthyra/ESMFold2-Fast",
229
  trust_remote_code=True,
230
  load_esmc=False,
231
- ).cuda().eval()
232
- ```
233
-
234
- For FP8 LM inference, install `transformer_engine.pytorch` in a CUDA
235
- environment with FP8-capable hardware and load the shared FastPLMs ESM++
236
- backbone with:
237
-
238
- ```python
239
- model = AutoModel.from_pretrained(
240
- "Synthyra/ESMFold2-Fast",
241
- trust_remote_code=True,
242
- esmc_precision="fp8",
243
- ).cuda().eval()
244
- ```
245
-
246
- FP8 is inference-only for the ESMFold2 LM backbone. TTT remains a bf16/fp32
247
- path.
 
221
  decoded = model.input_builder.decode(output, features, chain_infos)
222
  ```
223
 
224
+ Set `load_esmc=False` when loading if you want to provide precomputed `lm_hidden_states` manually or run folding-trunk tests without loading the 6B ESM++ backbone:
225
+
226
+ ```python
227
+ model = AutoModel.from_pretrained(
228
+ "Synthyra/ESMFold2-Fast",
229
  trust_remote_code=True,
230
  load_esmc=False,
231
+ ).cuda().eval()
232
+ ```
233
+
234
+ For FP8 LM inference, install `transformer_engine.pytorch` in a CUDA
235
+ environment with FP8-capable hardware and load the shared FastPLMs ESM++
236
+ backbone with:
237
+
238
+ ```python
239
+ model = AutoModel.from_pretrained(
240
+ "Synthyra/ESMFold2-Fast",
241
+ trust_remote_code=True,
242
+ esmc_precision="fp8",
243
+ ).cuda().eval()
244
+ ```
245
+
246
+ FP8 is inference-only for the ESMFold2 LM backbone. TTT remains a bf16/fp32
247
+ path.
__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- from .configuration_esmfold2 import ESMFold2Config
2
- from .modeling_esmfold2_experimental import ESMFold2ExperimentalModel
3
- from .modeling_esmfold2 import ESMFold2Model
4
-
5
- __all__ = ["ESMFold2Config", "ESMFold2ExperimentalModel", "ESMFold2Model"]
 
1
+ from .configuration_esmfold2 import ESMFold2Config
2
+ from .modeling_esmfold2_experimental import ESMFold2ExperimentalModel
3
+ from .modeling_esmfold2 import ESMFold2Model
4
+
5
+ __all__ = ["ESMFold2Config", "ESMFold2ExperimentalModel", "ESMFold2Model"]
configuration_esmfold2.py CHANGED
@@ -201,19 +201,19 @@ class ESMFold2Config(PretrainedConfig):
201
  Number of trunk loops for iterative refinement.
202
  num_diffusion_samples (`int`, defaults to 8):
203
  Number of parallel structure predictions to generate.
204
- lm_dropout (`float`, defaults to 0.0):
205
- Dropout probability on LM pair embeddings. When > 0, dropout is
206
- applied with ``training=True`` (including at inference) to match
207
- the experimental training recipe used by binder design.
208
- force_lm_dropout_during_inference (`bool`, defaults to False):
209
- When True, apply ``lm_dropout`` even when ``model.eval()`` and
210
- ``lm_dropout`` > 0. Binder-design loads set this to True.
211
- lm_mask_pct (`float`, defaults to 0.0):
212
- Fraction of LM residue tokens randomly replaced with the LM mask
213
- token before running the PLM backbone.
214
- disable_msa_features (`bool`, defaults to False):
215
- When True, zero out MSA-derived ``profile`` and ``deletion_mean``
216
- before the inputs embedder (experimental medium/large checkpoints).
217
  inputs (`InputsEmbedderConfig`):
218
  Configuration for the inputs embedder module.
219
  folding_trunk (`FoldingTrunkConfig`):
@@ -263,12 +263,12 @@ class ESMFold2Config(PretrainedConfig):
263
  # embedder.
264
  self.disable_msa_features: bool = kwargs.get("disable_msa_features", False)
265
  self.lm_dropout: float = kwargs.get("lm_dropout", 0.0)
266
- self.force_lm_dropout_during_inference: bool = kwargs.get(
267
- "force_lm_dropout_during_inference", False
268
- )
269
- self.lm_mask_pct: float = kwargs.get("lm_mask_pct", 0.0)
270
-
271
- self.lm_d_model: int = kwargs.get("lm_d_model", 2560)
272
  self.lm_num_layers: int = kwargs.get("lm_num_layers", 80)
273
  # Backward-compatible field name; values now point to FastPLMs ESM++.
274
  raw_esmc_id = (
 
201
  Number of trunk loops for iterative refinement.
202
  num_diffusion_samples (`int`, defaults to 8):
203
  Number of parallel structure predictions to generate.
204
+ lm_dropout (`float`, defaults to 0.0):
205
+ Dropout probability on LM pair embeddings. When > 0, dropout is
206
+ applied with ``training=True`` (including at inference) to match
207
+ the experimental training recipe used by binder design.
208
+ force_lm_dropout_during_inference (`bool`, defaults to False):
209
+ When True, apply ``lm_dropout`` even when ``model.eval()`` and
210
+ ``lm_dropout`` > 0. Binder-design loads set this to True.
211
+ lm_mask_pct (`float`, defaults to 0.0):
212
+ Fraction of LM residue tokens randomly replaced with the LM mask
213
+ token before running the PLM backbone.
214
+ disable_msa_features (`bool`, defaults to False):
215
+ When True, zero out MSA-derived ``profile`` and ``deletion_mean``
216
+ before the inputs embedder (experimental medium/large checkpoints).
217
  inputs (`InputsEmbedderConfig`):
218
  Configuration for the inputs embedder module.
219
  folding_trunk (`FoldingTrunkConfig`):
 
263
  # embedder.
264
  self.disable_msa_features: bool = kwargs.get("disable_msa_features", False)
265
  self.lm_dropout: float = kwargs.get("lm_dropout", 0.0)
266
+ self.force_lm_dropout_during_inference: bool = kwargs.get(
267
+ "force_lm_dropout_during_inference", False
268
+ )
269
+ self.lm_mask_pct: float = kwargs.get("lm_mask_pct", 0.0)
270
+
271
+ self.lm_d_model: int = kwargs.get("lm_d_model", 2560)
272
  self.lm_num_layers: int = kwargs.get("lm_num_layers", 80)
273
  # Backward-compatible field name; values now point to FastPLMs ESM++.
274
  raw_esmc_id = (
modeling_esmfold2.py CHANGED
@@ -59,12 +59,12 @@ from .modeling_esmfold2_common import (
59
  TriangleMultiplicativeUpdate,
60
  _categorical_mean,
61
  _compute_intra_token_idx,
62
- compute_lm_hidden_states,
63
- gather_rep_atom_coords,
64
- gather_token_to_atom,
65
- maybe_apply_msa_column_masking,
66
- maybe_subsample_msa,
67
- )
68
  from .esmfold2_affine3d import Affine3D as _FastPLMSESMFold2Affine3D
69
  from .esmfold2_aligner import Aligner as _FastPLMSESMFold2Aligner
70
  from .esmfold2_atom_indexer import AtomIndexer as _FastPLMSESMFold2AtomIndexer
@@ -699,27 +699,27 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
699
  self.post_init()
700
  self.init_ttt({"lora_target_replace_module": "MultiHeadAttention"})
701
 
702
- def load_esmc(self, esmc_model_path: str, precision: str = "bf16") -> None:
703
- """Load the FastPLMs ESM++ LM used as the ESMFold2 PLM backbone.
704
-
705
- ``precision``: ``"bf16"`` (default), ``"fp32"``, or opt-in ``"fp8"``.
706
- """
707
- dtype_map = {
708
- "bf16": torch.bfloat16,
709
- "fp32": torch.float32,
710
- "fp8": torch.bfloat16,
711
- }
712
- if precision not in dtype_map:
713
- raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}")
714
- if precision == "fp8" and not TE_AVAILABLE:
715
- raise RuntimeError(
716
- "esmc_precision='fp8' requires transformer_engine.pytorch."
717
- )
718
- dtype = dtype_map[precision]
719
-
720
- esmc = _load_fastplms_esmplusplus_for_esmfold2(
721
- esmc_model_path=esmc_model_path,
722
- attn_backend=self.config.esmc_attn_backend,
723
  device=self.device,
724
  dtype=dtype,
725
  )
@@ -730,24 +730,24 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
730
  assert esmc.config.num_hidden_layers == self.config.lm_num_layers, (
731
  f"ESMFold2 expected lm_num_layers={self.config.lm_num_layers}, "
732
  f"but loaded ESM++ num_hidden_layers={esmc.config.num_hidden_layers}."
733
- )
734
- for p in esmc.parameters():
735
- p.requires_grad_(False)
736
-
737
- if precision == "fp8":
738
- with torch.no_grad():
739
- _convert_te_modules_to_fp8_inplace(esmc)
740
-
741
- self._esmc_fp8 = precision == "fp8"
742
- self._esmc = esmc
743
- self._ttt_lm_head = None
744
-
745
- def _ensure_ttt_lm_head(self) -> None:
746
- assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
747
- if self._esmc_fp8:
748
- raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
749
- if self._ttt_lm_head is not None:
750
- return
751
  try:
752
  from fastplms.esm_plusplus.modeling_esm_plusplus import (
753
  ESMplusplusConfig,
@@ -781,11 +781,11 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
781
  self._ttt_lm_head.requires_grad_(False)
782
  del mlm
783
 
784
- def _ttt_get_trainable_modules(self) -> list[nn.Module]:
785
- assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
786
- if self._esmc_fp8:
787
- raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
788
- return [self._esmc]
789
 
790
  def _ttt_tokenize(
791
  self,
@@ -846,13 +846,13 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
846
  **kwargs,
847
  ) -> torch.Tensor:
848
  del kwargs
849
- assert isinstance(batch, torch.Tensor), (
850
- "ESMFold2 TTT expects input_ids tensors."
851
- )
852
- assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
853
- if self._esmc_fp8:
854
- raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
855
- self._ensure_ttt_lm_head()
856
  assert self._ttt_lm_head is not None
857
  attention_mask = batch.ne(SEQUENCE_PAD_TOKEN)
858
  output = self._esmc(
@@ -947,30 +947,30 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
947
  if self.msa_encoder is not None:
948
  self.msa_encoder.set_chunk_size(chunk_size)
949
 
950
- def _compute_lm_hidden_states(
951
- self,
952
- input_ids: Tensor,
953
- asym_id: Tensor,
954
- residue_index: Tensor,
955
- mol_type: Tensor,
956
- tok_mask: Tensor,
957
- lm_mask_pct: float = 0.0,
958
- ) -> Tensor:
959
- assert self._esmc is not None
960
- # fp8 TE kernels require prod(shape[:-1]) % 8 == 0.
961
- pad_to = 8 if self._esmc_fp8 else None
962
- with _lm_precision_context(self._esmc_fp8):
963
  return compute_lm_hidden_states(
964
  self._esmc,
965
  input_ids,
966
  asym_id,
967
  residue_index,
968
- mol_type,
969
- tok_mask,
970
- pad_to_multiple=pad_to,
971
- lm_mask_pct=lm_mask_pct,
972
- mask_token_id=SEQUENCE_MASK_TOKEN,
973
- )
974
 
975
  def _discretized_dynamics(self) -> tuple[Tensor, Tensor]:
976
  delta = F.softplus(self.parcae_log_delta)
@@ -985,17 +985,17 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
985
  return state.to(dtype=ref.dtype)
986
 
987
  def _run_one_loop(
988
- self,
989
- z: Tensor,
990
- z_init: Tensor,
991
- lm_z: Tensor | None,
992
- _msa_inputs: dict | None,
993
- pair_mask: Tensor,
994
- a: Tensor,
995
- b_mat: Tensor,
996
- tok_mask: Tensor,
997
- total_steps: int,
998
- ) -> Tensor:
999
  # Helper method (not inline) so per-iter locals free on return —
1000
  # otherwise leaks ~2 GB L²×c_z into distogram/sample scope.
1001
  # training=True forces dropout under eval(), matching the per-loop
@@ -1025,49 +1025,49 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
1025
  if lm_z_i is not None and self.lm_encoder is None:
1026
  z_inject_pair = z_inject_pair + lm_z_i.to(z_inject_pair.dtype)
1027
 
1028
- if self.msa_encoder is not None and _msa_inputs is not None:
1029
- msa_i, mask_i, hd_i, dv_i = maybe_subsample_msa(
1030
- _msa_inputs["msa"],
1031
- _msa_inputs["msa_attention_mask"],
1032
- _msa_inputs["has_deletion"],
1033
- _msa_inputs["deletion_value"],
1034
- max_depth=_msa_inputs["max_depth"],
1035
- enabled=_msa_inputs["subsample_enabled"],
1036
- )
1037
- B_msa, M, L_msa = msa_i.shape
1038
- msa_oh = F.one_hot(
1039
- msa_i.permute(0, 2, 1).long(), num_classes=NUM_RES_TYPES
1040
- ).float()
1041
- msa_attn = (
1042
- mask_i.permute(0, 2, 1).float()
1043
- if mask_i is not None
1044
- else tok_mask[:, :, None].expand(-1, -1, M).float()
1045
- )
1046
- # Bias-free MSAEncoder.embed requires zeroed padding.
1047
- msa_oh = msa_oh * msa_attn.unsqueeze(-1)
1048
- hd = (
1049
- hd_i.permute(0, 2, 1).float()
1050
- if hd_i is not None
1051
- else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
1052
- )
1053
- dv = (
1054
- dv_i.permute(0, 2, 1).float()
1055
- if dv_i is not None
1056
- else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
1057
- )
1058
- msa_pair = self.msa_encoder(
1059
- x_pair=z_inject_pair,
1060
- x_inputs=_msa_inputs["x_inputs"],
1061
- msa_oh=msa_oh,
1062
- has_deletion=hd,
1063
- deletion_value=dv,
1064
- msa_attention_mask=msa_attn,
1065
- ).to(z_inject_pair.dtype)
1066
- z_inject_pair = (
1067
- msa_pair
1068
- if self.config.msa_encoder_overwrite
1069
- else (z_inject_pair + msa_pair)
1070
- )
1071
 
1072
  if refined_lm_z is not None:
1073
  z_inject_pair = z_inject_pair + refined_lm_z.to(z_inject_pair.dtype)
@@ -1104,16 +1104,16 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
1104
  deletion_value: Tensor | None = None,
1105
  msa_attention_mask: Tensor | None = None,
1106
  input_ids: Tensor | None = None,
1107
- lm_hidden_states: Tensor | None = None,
1108
- num_loops: int | None = None,
1109
- num_diffusion_samples: int | None = None,
1110
- num_sampling_steps: int | None = None,
1111
- lm_mask_pct: float | None = None,
1112
- msa_max_depth: int = 1024,
1113
- msa_column_mask_rate: float = 0.1,
1114
- msa_subsample_at_inference: bool = True,
1115
- **kwargs,
1116
- ) -> dict[str, Tensor]:
1117
  tok_mask = token_attention_mask
1118
  atm_mask = atom_attention_mask
1119
  disto_idx = distogram_atom_idx
@@ -1196,19 +1196,19 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
1196
  lm_hidden_states is None
1197
  and input_ids is not None
1198
  and self._esmc is not None
1199
- ):
1200
- lm_hidden_states = self._compute_lm_hidden_states(
1201
- input_ids,
1202
- asym_id,
1203
- residue_index,
1204
- mol_type,
1205
- tok_mask,
1206
- lm_mask_pct=(
1207
- self.config.lm_mask_pct
1208
- if lm_mask_pct is None
1209
- else lm_mask_pct
1210
- ),
1211
- )
1212
  lm_z: Tensor | None = None
1213
  if lm_hidden_states is not None:
1214
  lm_z = self.language_model(lm_hidden_states.detach())
@@ -1222,35 +1222,35 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
1222
  a = a.view(1, 1, 1, -1).to(device=z.device, dtype=z.dtype)
1223
  b_mat = b.to(device=z.device, dtype=z.dtype)
1224
 
1225
- _msa_inputs: dict | None = None
1226
- if self.msa_encoder is not None and msa is not None:
1227
- msa_attention_mask = maybe_apply_msa_column_masking(
1228
- msa_attention_mask,
1229
- msa_column_mask_rate,
1230
- )
1231
- _msa_inputs = dict(
1232
- x_inputs=x_inputs,
1233
- msa=msa,
1234
- msa_attention_mask=msa_attention_mask,
1235
- has_deletion=has_deletion,
1236
- deletion_value=deletion_value,
1237
- max_depth=msa_max_depth,
1238
- subsample_enabled=msa_subsample_at_inference,
1239
- )
1240
 
1241
  # Method call (not inline loop) frees per-iter L²×c_z locals.
1242
  z = self._run_one_loop(
1243
- z=z,
1244
- z_init=z_init,
1245
- lm_z=lm_z,
1246
- _msa_inputs=_msa_inputs,
1247
- pair_mask=pair_mask,
1248
- a=a,
1249
- b_mat=b_mat,
1250
- tok_mask=tok_mask,
1251
- total_steps=total_steps,
1252
- )
1253
- del z_init, lm_z, _msa_inputs, a, b_mat
1254
 
1255
  z = self.parcae_readout(z)
1256
  z = self.parcae_coda(z, pair_attention_mask=pair_mask)
@@ -1362,38 +1362,38 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
1362
  complex_id=complex_id,
1363
  )
1364
 
1365
- def _fold_protein_no_ttt(
1366
- self,
1367
- sequence: str,
1368
- *,
1369
- chain_id: str = "A",
1370
- msa: Any | None = None,
1371
- msa_path: str | Path | None = None,
1372
- msa_max_sequences: int | None = None,
1373
- num_loops: int = 3,
1374
- num_sampling_steps: int = 50,
1375
- num_diffusion_samples: int = 1,
1376
- seed: int | None = None,
1377
- complex_id: str = "pred",
1378
- ):
1379
- from .esmfold2_types import MSA, ProteinInput, StructurePredictionInput
1380
-
1381
- assert not (
1382
- msa is not None and msa_path is not None
1383
- ), "Pass at most one of msa or msa_path."
1384
- if msa_path is not None:
1385
- msa = MSA.from_a3m(msa_path, max_sequences=msa_max_sequences)
1386
- if msa is not None:
1387
- query = str(msa.query).replace("-", "").upper()
1388
- assert query == sequence.upper(), (
1389
- f"MSA query does not match sequence: expected {sequence.upper()!r}, got {query!r}"
1390
- )
1391
-
1392
- input = StructurePredictionInput(
1393
- sequences=[ProteinInput(id=chain_id, sequence=sequence, msa=msa)]
1394
- )
1395
- return self.fold(
1396
- input,
1397
  num_loops=num_loops,
1398
  num_sampling_steps=num_sampling_steps,
1399
  num_diffusion_samples=num_diffusion_samples,
@@ -1442,15 +1442,15 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
1442
 
1443
  def fold_protein(
1444
  self,
1445
- sequence: str,
1446
- *,
1447
- chain_id: str = "A",
1448
- msa: Any | None = None,
1449
- msa_path: str | Path | None = None,
1450
- msa_max_sequences: int | None = None,
1451
- num_loops: int = 3,
1452
- num_sampling_steps: int = 50,
1453
- num_diffusion_samples: int = 1,
1454
  seed: int | None = None,
1455
  complex_id: str = "pred",
1456
  ttt: bool = False,
@@ -1458,57 +1458,57 @@ class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
1458
  ):
1459
  if ttt:
1460
  return self.fold_protein_ttt(
1461
- sequence=sequence,
1462
- chain_id=chain_id,
1463
- msa=msa,
1464
- msa_path=msa_path,
1465
- msa_max_sequences=msa_max_sequences,
1466
- num_loops=num_loops,
1467
- num_sampling_steps=num_sampling_steps,
1468
- num_diffusion_samples=num_diffusion_samples,
1469
  seed=seed,
1470
  complex_id=complex_id,
1471
  ttt_config=ttt_config,
1472
  )
1473
  return self._fold_protein_no_ttt(
1474
- sequence=sequence,
1475
- chain_id=chain_id,
1476
- msa=msa,
1477
- msa_path=msa_path,
1478
- msa_max_sequences=msa_max_sequences,
1479
- num_loops=num_loops,
1480
- num_sampling_steps=num_sampling_steps,
1481
- num_diffusion_samples=num_diffusion_samples,
1482
  seed=seed,
1483
  complex_id=complex_id,
1484
  )
1485
 
1486
  def fold_protein_ttt(
1487
  self,
1488
- sequence: str,
1489
- *,
1490
- chain_id: str = "A",
1491
- msa: Any | None = None,
1492
- msa_path: str | Path | None = None,
1493
- msa_max_sequences: int | None = None,
1494
- num_loops: int = 3,
1495
- num_sampling_steps: int = 50,
1496
- num_diffusion_samples: int = 1,
1497
  seed: int | None = None,
1498
  complex_id: str = "pred",
1499
  ttt_config: TTTConfig | dict[str, Any] | None = None,
1500
- ):
1501
- assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
1502
- if self._esmc_fp8:
1503
- raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
1504
- fold_kwargs = {
1505
- "chain_id": chain_id,
1506
- "msa": msa,
1507
- "msa_path": msa_path,
1508
- "msa_max_sequences": msa_max_sequences,
1509
- "num_loops": num_loops,
1510
- "num_sampling_steps": num_sampling_steps,
1511
- "num_diffusion_samples": num_diffusion_samples,
1512
  "seed": seed,
1513
  "complex_id": complex_id,
1514
  }
 
59
  TriangleMultiplicativeUpdate,
60
  _categorical_mean,
61
  _compute_intra_token_idx,
62
+ compute_lm_hidden_states,
63
+ gather_rep_atom_coords,
64
+ gather_token_to_atom,
65
+ maybe_apply_msa_column_masking,
66
+ maybe_subsample_msa,
67
+ )
68
  from .esmfold2_affine3d import Affine3D as _FastPLMSESMFold2Affine3D
69
  from .esmfold2_aligner import Aligner as _FastPLMSESMFold2Aligner
70
  from .esmfold2_atom_indexer import AtomIndexer as _FastPLMSESMFold2AtomIndexer
 
699
  self.post_init()
700
  self.init_ttt({"lora_target_replace_module": "MultiHeadAttention"})
701
 
702
+ def load_esmc(self, esmc_model_path: str, precision: str = "bf16") -> None:
703
+ """Load the FastPLMs ESM++ LM used as the ESMFold2 PLM backbone.
704
+
705
+ ``precision``: ``"bf16"`` (default), ``"fp32"``, or opt-in ``"fp8"``.
706
+ """
707
+ dtype_map = {
708
+ "bf16": torch.bfloat16,
709
+ "fp32": torch.float32,
710
+ "fp8": torch.bfloat16,
711
+ }
712
+ if precision not in dtype_map:
713
+ raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}")
714
+ if precision == "fp8" and not TE_AVAILABLE:
715
+ raise RuntimeError(
716
+ "esmc_precision='fp8' requires transformer_engine.pytorch."
717
+ )
718
+ dtype = dtype_map[precision]
719
+
720
+ esmc = _load_fastplms_esmplusplus_for_esmfold2(
721
+ esmc_model_path=esmc_model_path,
722
+ attn_backend=self.config.esmc_attn_backend,
723
  device=self.device,
724
  dtype=dtype,
725
  )
 
730
  assert esmc.config.num_hidden_layers == self.config.lm_num_layers, (
731
  f"ESMFold2 expected lm_num_layers={self.config.lm_num_layers}, "
732
  f"but loaded ESM++ num_hidden_layers={esmc.config.num_hidden_layers}."
733
+ )
734
+ for p in esmc.parameters():
735
+ p.requires_grad_(False)
736
+
737
+ if precision == "fp8":
738
+ with torch.no_grad():
739
+ _convert_te_modules_to_fp8_inplace(esmc)
740
+
741
+ self._esmc_fp8 = precision == "fp8"
742
+ self._esmc = esmc
743
+ self._ttt_lm_head = None
744
+
745
+ def _ensure_ttt_lm_head(self) -> None:
746
+ assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
747
+ if self._esmc_fp8:
748
+ raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
749
+ if self._ttt_lm_head is not None:
750
+ return
751
  try:
752
  from fastplms.esm_plusplus.modeling_esm_plusplus import (
753
  ESMplusplusConfig,
 
781
  self._ttt_lm_head.requires_grad_(False)
782
  del mlm
783
 
784
+ def _ttt_get_trainable_modules(self) -> list[nn.Module]:
785
+ assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
786
+ if self._esmc_fp8:
787
+ raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
788
+ return [self._esmc]
789
 
790
  def _ttt_tokenize(
791
  self,
 
846
  **kwargs,
847
  ) -> torch.Tensor:
848
  del kwargs
849
+ assert isinstance(batch, torch.Tensor), (
850
+ "ESMFold2 TTT expects input_ids tensors."
851
+ )
852
+ assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
853
+ if self._esmc_fp8:
854
+ raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
855
+ self._ensure_ttt_lm_head()
856
  assert self._ttt_lm_head is not None
857
  attention_mask = batch.ne(SEQUENCE_PAD_TOKEN)
858
  output = self._esmc(
 
947
  if self.msa_encoder is not None:
948
  self.msa_encoder.set_chunk_size(chunk_size)
949
 
950
+ def _compute_lm_hidden_states(
951
+ self,
952
+ input_ids: Tensor,
953
+ asym_id: Tensor,
954
+ residue_index: Tensor,
955
+ mol_type: Tensor,
956
+ tok_mask: Tensor,
957
+ lm_mask_pct: float = 0.0,
958
+ ) -> Tensor:
959
+ assert self._esmc is not None
960
+ # fp8 TE kernels require prod(shape[:-1]) % 8 == 0.
961
+ pad_to = 8 if self._esmc_fp8 else None
962
+ with _lm_precision_context(self._esmc_fp8):
963
  return compute_lm_hidden_states(
964
  self._esmc,
965
  input_ids,
966
  asym_id,
967
  residue_index,
968
+ mol_type,
969
+ tok_mask,
970
+ pad_to_multiple=pad_to,
971
+ lm_mask_pct=lm_mask_pct,
972
+ mask_token_id=SEQUENCE_MASK_TOKEN,
973
+ )
974
 
975
  def _discretized_dynamics(self) -> tuple[Tensor, Tensor]:
976
  delta = F.softplus(self.parcae_log_delta)
 
985
  return state.to(dtype=ref.dtype)
986
 
987
  def _run_one_loop(
988
+ self,
989
+ z: Tensor,
990
+ z_init: Tensor,
991
+ lm_z: Tensor | None,
992
+ _msa_inputs: dict | None,
993
+ pair_mask: Tensor,
994
+ a: Tensor,
995
+ b_mat: Tensor,
996
+ tok_mask: Tensor,
997
+ total_steps: int,
998
+ ) -> Tensor:
999
  # Helper method (not inline) so per-iter locals free on return —
1000
  # otherwise leaks ~2 GB L²×c_z into distogram/sample scope.
1001
  # training=True forces dropout under eval(), matching the per-loop
 
1025
  if lm_z_i is not None and self.lm_encoder is None:
1026
  z_inject_pair = z_inject_pair + lm_z_i.to(z_inject_pair.dtype)
1027
 
1028
+ if self.msa_encoder is not None and _msa_inputs is not None:
1029
+ msa_i, mask_i, hd_i, dv_i = maybe_subsample_msa(
1030
+ _msa_inputs["msa"],
1031
+ _msa_inputs["msa_attention_mask"],
1032
+ _msa_inputs["has_deletion"],
1033
+ _msa_inputs["deletion_value"],
1034
+ max_depth=_msa_inputs["max_depth"],
1035
+ enabled=_msa_inputs["subsample_enabled"],
1036
+ )
1037
+ B_msa, M, L_msa = msa_i.shape
1038
+ msa_oh = F.one_hot(
1039
+ msa_i.permute(0, 2, 1).long(), num_classes=NUM_RES_TYPES
1040
+ ).float()
1041
+ msa_attn = (
1042
+ mask_i.permute(0, 2, 1).float()
1043
+ if mask_i is not None
1044
+ else tok_mask[:, :, None].expand(-1, -1, M).float()
1045
+ )
1046
+ # Bias-free MSAEncoder.embed requires zeroed padding.
1047
+ msa_oh = msa_oh * msa_attn.unsqueeze(-1)
1048
+ hd = (
1049
+ hd_i.permute(0, 2, 1).float()
1050
+ if hd_i is not None
1051
+ else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
1052
+ )
1053
+ dv = (
1054
+ dv_i.permute(0, 2, 1).float()
1055
+ if dv_i is not None
1056
+ else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
1057
+ )
1058
+ msa_pair = self.msa_encoder(
1059
+ x_pair=z_inject_pair,
1060
+ x_inputs=_msa_inputs["x_inputs"],
1061
+ msa_oh=msa_oh,
1062
+ has_deletion=hd,
1063
+ deletion_value=dv,
1064
+ msa_attention_mask=msa_attn,
1065
+ ).to(z_inject_pair.dtype)
1066
+ z_inject_pair = (
1067
+ msa_pair
1068
+ if self.config.msa_encoder_overwrite
1069
+ else (z_inject_pair + msa_pair)
1070
+ )
1071
 
1072
  if refined_lm_z is not None:
1073
  z_inject_pair = z_inject_pair + refined_lm_z.to(z_inject_pair.dtype)
 
1104
  deletion_value: Tensor | None = None,
1105
  msa_attention_mask: Tensor | None = None,
1106
  input_ids: Tensor | None = None,
1107
+ lm_hidden_states: Tensor | None = None,
1108
+ num_loops: int | None = None,
1109
+ num_diffusion_samples: int | None = None,
1110
+ num_sampling_steps: int | None = None,
1111
+ lm_mask_pct: float | None = None,
1112
+ msa_max_depth: int = 1024,
1113
+ msa_column_mask_rate: float = 0.1,
1114
+ msa_subsample_at_inference: bool = True,
1115
+ **kwargs,
1116
+ ) -> dict[str, Tensor]:
1117
  tok_mask = token_attention_mask
1118
  atm_mask = atom_attention_mask
1119
  disto_idx = distogram_atom_idx
 
1196
  lm_hidden_states is None
1197
  and input_ids is not None
1198
  and self._esmc is not None
1199
+ ):
1200
+ lm_hidden_states = self._compute_lm_hidden_states(
1201
+ input_ids,
1202
+ asym_id,
1203
+ residue_index,
1204
+ mol_type,
1205
+ tok_mask,
1206
+ lm_mask_pct=(
1207
+ self.config.lm_mask_pct
1208
+ if lm_mask_pct is None
1209
+ else lm_mask_pct
1210
+ ),
1211
+ )
1212
  lm_z: Tensor | None = None
1213
  if lm_hidden_states is not None:
1214
  lm_z = self.language_model(lm_hidden_states.detach())
 
1222
  a = a.view(1, 1, 1, -1).to(device=z.device, dtype=z.dtype)
1223
  b_mat = b.to(device=z.device, dtype=z.dtype)
1224
 
1225
+ _msa_inputs: dict | None = None
1226
+ if self.msa_encoder is not None and msa is not None:
1227
+ msa_attention_mask = maybe_apply_msa_column_masking(
1228
+ msa_attention_mask,
1229
+ msa_column_mask_rate,
1230
+ )
1231
+ _msa_inputs = dict(
1232
+ x_inputs=x_inputs,
1233
+ msa=msa,
1234
+ msa_attention_mask=msa_attention_mask,
1235
+ has_deletion=has_deletion,
1236
+ deletion_value=deletion_value,
1237
+ max_depth=msa_max_depth,
1238
+ subsample_enabled=msa_subsample_at_inference,
1239
+ )
1240
 
1241
  # Method call (not inline loop) frees per-iter L²×c_z locals.
1242
  z = self._run_one_loop(
1243
+ z=z,
1244
+ z_init=z_init,
1245
+ lm_z=lm_z,
1246
+ _msa_inputs=_msa_inputs,
1247
+ pair_mask=pair_mask,
1248
+ a=a,
1249
+ b_mat=b_mat,
1250
+ tok_mask=tok_mask,
1251
+ total_steps=total_steps,
1252
+ )
1253
+ del z_init, lm_z, _msa_inputs, a, b_mat
1254
 
1255
  z = self.parcae_readout(z)
1256
  z = self.parcae_coda(z, pair_attention_mask=pair_mask)
 
1362
  complex_id=complex_id,
1363
  )
1364
 
1365
+ def _fold_protein_no_ttt(
1366
+ self,
1367
+ sequence: str,
1368
+ *,
1369
+ chain_id: str = "A",
1370
+ msa: Any | None = None,
1371
+ msa_path: str | Path | None = None,
1372
+ msa_max_sequences: int | None = None,
1373
+ num_loops: int = 3,
1374
+ num_sampling_steps: int = 50,
1375
+ num_diffusion_samples: int = 1,
1376
+ seed: int | None = None,
1377
+ complex_id: str = "pred",
1378
+ ):
1379
+ from .esmfold2_types import MSA, ProteinInput, StructurePredictionInput
1380
+
1381
+ assert not (
1382
+ msa is not None and msa_path is not None
1383
+ ), "Pass at most one of msa or msa_path."
1384
+ if msa_path is not None:
1385
+ msa = MSA.from_a3m(msa_path, max_sequences=msa_max_sequences)
1386
+ if msa is not None:
1387
+ query = str(msa.query).replace("-", "").upper()
1388
+ assert query == sequence.upper(), (
1389
+ f"MSA query does not match sequence: expected {sequence.upper()!r}, got {query!r}"
1390
+ )
1391
+
1392
+ input = StructurePredictionInput(
1393
+ sequences=[ProteinInput(id=chain_id, sequence=sequence, msa=msa)]
1394
+ )
1395
+ return self.fold(
1396
+ input,
1397
  num_loops=num_loops,
1398
  num_sampling_steps=num_sampling_steps,
1399
  num_diffusion_samples=num_diffusion_samples,
 
1442
 
1443
  def fold_protein(
1444
  self,
1445
+ sequence: str,
1446
+ *,
1447
+ chain_id: str = "A",
1448
+ msa: Any | None = None,
1449
+ msa_path: str | Path | None = None,
1450
+ msa_max_sequences: int | None = None,
1451
+ num_loops: int = 3,
1452
+ num_sampling_steps: int = 50,
1453
+ num_diffusion_samples: int = 1,
1454
  seed: int | None = None,
1455
  complex_id: str = "pred",
1456
  ttt: bool = False,
 
1458
  ):
1459
  if ttt:
1460
  return self.fold_protein_ttt(
1461
+ sequence=sequence,
1462
+ chain_id=chain_id,
1463
+ msa=msa,
1464
+ msa_path=msa_path,
1465
+ msa_max_sequences=msa_max_sequences,
1466
+ num_loops=num_loops,
1467
+ num_sampling_steps=num_sampling_steps,
1468
+ num_diffusion_samples=num_diffusion_samples,
1469
  seed=seed,
1470
  complex_id=complex_id,
1471
  ttt_config=ttt_config,
1472
  )
1473
  return self._fold_protein_no_ttt(
1474
+ sequence=sequence,
1475
+ chain_id=chain_id,
1476
+ msa=msa,
1477
+ msa_path=msa_path,
1478
+ msa_max_sequences=msa_max_sequences,
1479
+ num_loops=num_loops,
1480
+ num_sampling_steps=num_sampling_steps,
1481
+ num_diffusion_samples=num_diffusion_samples,
1482
  seed=seed,
1483
  complex_id=complex_id,
1484
  )
1485
 
1486
  def fold_protein_ttt(
1487
  self,
1488
+ sequence: str,
1489
+ *,
1490
+ chain_id: str = "A",
1491
+ msa: Any | None = None,
1492
+ msa_path: str | Path | None = None,
1493
+ msa_max_sequences: int | None = None,
1494
+ num_loops: int = 3,
1495
+ num_sampling_steps: int = 50,
1496
+ num_diffusion_samples: int = 1,
1497
  seed: int | None = None,
1498
  complex_id: str = "pred",
1499
  ttt_config: TTTConfig | dict[str, Any] | None = None,
1500
+ ):
1501
+ assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
1502
+ if self._esmc_fp8:
1503
+ raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
1504
+ fold_kwargs = {
1505
+ "chain_id": chain_id,
1506
+ "msa": msa,
1507
+ "msa_path": msa_path,
1508
+ "msa_max_sequences": msa_max_sequences,
1509
+ "num_loops": num_loops,
1510
+ "num_sampling_steps": num_sampling_steps,
1511
+ "num_diffusion_samples": num_diffusion_samples,
1512
  "seed": seed,
1513
  "complex_id": complex_id,
1514
  }
modeling_esmfold2_common.py CHANGED
@@ -140,61 +140,61 @@ _EPS = 1e-5
140
  # chunk=64 leaves headroom for the largest foldbench targets). Override via
141
  # ``model.set_chunk_size(...)``; pass None to disable chunking (faster for
142
  # short L but OOM-prone past ~600).
143
- _DEFAULT_CHUNK_SIZE = 64
144
-
145
-
146
- # ===========================================================================
147
- # MSA inference-time diversity augmentations
148
- # ===========================================================================
149
-
150
-
151
- def maybe_subsample_msa(
152
- msa: Tensor,
153
- msa_attention_mask: Tensor | None,
154
- has_deletion: Tensor | None,
155
- deletion_value: Tensor | None,
156
- *,
157
- max_depth: int | None,
158
- enabled: bool,
159
- ) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None]:
160
- if not enabled or max_depth is None:
161
- return msa, msa_attention_mask, has_deletion, deletion_value
162
-
163
- depth = msa.size(1)
164
- if depth <= 1 or depth <= max_depth:
165
- return msa, msa_attention_mask, has_deletion, deletion_value
166
-
167
- indices = torch.zeros(max_depth, dtype=torch.long, device=msa.device)
168
- indices[1:] = torch.randperm(depth - 1, device=msa.device)[: max_depth - 1] + 1
169
- indices = indices.sort().values
170
-
171
- msa = msa[:, indices]
172
- if msa_attention_mask is not None:
173
- msa_attention_mask = msa_attention_mask[:, indices]
174
- if has_deletion is not None:
175
- has_deletion = has_deletion[:, indices]
176
- if deletion_value is not None:
177
- deletion_value = deletion_value[:, indices]
178
- return msa, msa_attention_mask, has_deletion, deletion_value
179
-
180
-
181
- def maybe_apply_msa_column_masking(
182
- msa_attention_mask: Tensor | None,
183
- rate: float,
184
- ) -> Tensor | None:
185
- if msa_attention_mask is None or rate <= 0.0 or msa_attention_mask.size(1) <= 1:
186
- return msa_attention_mask
187
-
188
- batch_size, _, length = msa_attention_mask.shape
189
- col_keep = torch.rand(batch_size, length, device=msa_attention_mask.device) >= rate
190
- col_keep = col_keep.unsqueeze(1).expand_as(msa_attention_mask).clone()
191
- col_keep[:, 0, :] = True
192
- return msa_attention_mask.bool() & col_keep
193
-
194
-
195
- # ===========================================================================
196
- # Atom-token utilities
197
- # ===========================================================================
198
 
199
 
200
  def gather_token_to_atom(token_features: Tensor, atom_to_token_idx: Tensor) -> Tensor:
@@ -2182,17 +2182,17 @@ def _seed_context(seed: int | None, *, cuda: bool = True):
2182
  # ===========================================================================
2183
 
2184
 
2185
- def compute_lm_hidden_states(
2186
- esmc: nn.Module,
2187
- input_ids: Tensor,
2188
- asym_id: Tensor,
2189
- residue_index: Tensor,
2190
- mol_type: Tensor,
2191
- token_mask: Tensor,
2192
- pad_to_multiple: int | None = None,
2193
- lm_mask_pct: float = 0.0,
2194
- mask_token_id: int = 32,
2195
- ) -> Tensor:
2196
  """Run ESMC with BOS/EOS wrapping, return hidden states [B, L, N, D] with N=81 layers.
2197
 
2198
  Atom-tokenized modified residues (HYP, MSE, ACE, NH2, ...) span multiple
@@ -2277,21 +2277,21 @@ def compute_lm_hidden_states(
2277
  for b in range(B):
2278
  lm_input_ids[b, : lm_lengths[b]] = lm_input_list[b]
2279
 
2280
- # sequence_id for chain-aware attention; PAD tokens get -1 (no attention).
2281
- sequence_id = (lm_input_ids == 0).cumsum(dim=1) - 1 # BOS=0
2282
- sequence_id = sequence_id.masked_fill(lm_input_ids == 1, -1) # PAD=1
2283
-
2284
- if lm_mask_pct > 0.0:
2285
- special = (lm_input_ids == 0) | (lm_input_ids == 1) | (lm_input_ids == 2)
2286
- do_mask = (
2287
- torch.rand(lm_input_ids.shape, device=device) < lm_mask_pct
2288
- ) & ~special
2289
- lm_input_ids = lm_input_ids.masked_fill(do_mask, mask_token_id)
2290
-
2291
- with torch.inference_mode():
2292
- esmc_out = esmc(
2293
- input_ids=lm_input_ids, sequence_id=sequence_id, output_hidden_states=True
2294
- )
2295
 
2296
  hs = esmc_out.hidden_states # [n_layers+1, B, max_len, D]
2297
  n_layers_plus_1, _, _, D = hs.shape
 
140
  # chunk=64 leaves headroom for the largest foldbench targets). Override via
141
  # ``model.set_chunk_size(...)``; pass None to disable chunking (faster for
142
  # short L but OOM-prone past ~600).
143
+ _DEFAULT_CHUNK_SIZE = 64
144
+
145
+
146
+ # ===========================================================================
147
+ # MSA inference-time diversity augmentations
148
+ # ===========================================================================
149
+
150
+
151
+ def maybe_subsample_msa(
152
+ msa: Tensor,
153
+ msa_attention_mask: Tensor | None,
154
+ has_deletion: Tensor | None,
155
+ deletion_value: Tensor | None,
156
+ *,
157
+ max_depth: int | None,
158
+ enabled: bool,
159
+ ) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None]:
160
+ if not enabled or max_depth is None:
161
+ return msa, msa_attention_mask, has_deletion, deletion_value
162
+
163
+ depth = msa.size(1)
164
+ if depth <= 1 or depth <= max_depth:
165
+ return msa, msa_attention_mask, has_deletion, deletion_value
166
+
167
+ indices = torch.zeros(max_depth, dtype=torch.long, device=msa.device)
168
+ indices[1:] = torch.randperm(depth - 1, device=msa.device)[: max_depth - 1] + 1
169
+ indices = indices.sort().values
170
+
171
+ msa = msa[:, indices]
172
+ if msa_attention_mask is not None:
173
+ msa_attention_mask = msa_attention_mask[:, indices]
174
+ if has_deletion is not None:
175
+ has_deletion = has_deletion[:, indices]
176
+ if deletion_value is not None:
177
+ deletion_value = deletion_value[:, indices]
178
+ return msa, msa_attention_mask, has_deletion, deletion_value
179
+
180
+
181
+ def maybe_apply_msa_column_masking(
182
+ msa_attention_mask: Tensor | None,
183
+ rate: float,
184
+ ) -> Tensor | None:
185
+ if msa_attention_mask is None or rate <= 0.0 or msa_attention_mask.size(1) <= 1:
186
+ return msa_attention_mask
187
+
188
+ batch_size, _, length = msa_attention_mask.shape
189
+ col_keep = torch.rand(batch_size, length, device=msa_attention_mask.device) >= rate
190
+ col_keep = col_keep.unsqueeze(1).expand_as(msa_attention_mask).clone()
191
+ col_keep[:, 0, :] = True
192
+ return msa_attention_mask.bool() & col_keep
193
+
194
+
195
+ # ===========================================================================
196
+ # Atom-token utilities
197
+ # ===========================================================================
198
 
199
 
200
  def gather_token_to_atom(token_features: Tensor, atom_to_token_idx: Tensor) -> Tensor:
 
2182
  # ===========================================================================
2183
 
2184
 
2185
+ def compute_lm_hidden_states(
2186
+ esmc: nn.Module,
2187
+ input_ids: Tensor,
2188
+ asym_id: Tensor,
2189
+ residue_index: Tensor,
2190
+ mol_type: Tensor,
2191
+ token_mask: Tensor,
2192
+ pad_to_multiple: int | None = None,
2193
+ lm_mask_pct: float = 0.0,
2194
+ mask_token_id: int = 32,
2195
+ ) -> Tensor:
2196
  """Run ESMC with BOS/EOS wrapping, return hidden states [B, L, N, D] with N=81 layers.
2197
 
2198
  Atom-tokenized modified residues (HYP, MSE, ACE, NH2, ...) span multiple
 
2277
  for b in range(B):
2278
  lm_input_ids[b, : lm_lengths[b]] = lm_input_list[b]
2279
 
2280
+ # sequence_id for chain-aware attention; PAD tokens get -1 (no attention).
2281
+ sequence_id = (lm_input_ids == 0).cumsum(dim=1) - 1 # BOS=0
2282
+ sequence_id = sequence_id.masked_fill(lm_input_ids == 1, -1) # PAD=1
2283
+
2284
+ if lm_mask_pct > 0.0:
2285
+ special = (lm_input_ids == 0) | (lm_input_ids == 1) | (lm_input_ids == 2)
2286
+ do_mask = (
2287
+ torch.rand(lm_input_ids.shape, device=device) < lm_mask_pct
2288
+ ) & ~special
2289
+ lm_input_ids = lm_input_ids.masked_fill(do_mask, mask_token_id)
2290
+
2291
+ with torch.inference_mode():
2292
+ esmc_out = esmc(
2293
+ input_ids=lm_input_ids, sequence_id=sequence_id, output_hidden_states=True
2294
+ )
2295
 
2296
  hs = esmc_out.hidden_states # [n_layers+1, B, max_len, D]
2297
  n_layers_plus_1, _, _, D = hs.shape
modeling_esmfold2_experimental.py CHANGED
@@ -521,14 +521,14 @@ class ESMFold2ExperimentalModel(PreTrainedModel):
521
  "bf16": torch.bfloat16,
522
  "fp32": torch.float32,
523
  }
524
- if precision not in dtype_map:
525
- if precision == "fp8":
526
- raise RuntimeError(
527
- "esmc_precision='fp8' is supported only by the standard "
528
- "released ESMFold2 model. The experimental binder-design "
529
- "model keeps the FastPLMs ESM++ backbone in bf16 or fp32."
530
- )
531
- raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}")
532
  esmc = _load_fastplms_esmplusplus_for_esmfold2(
533
  esmc_model_path=esmc_model_path,
534
  attn_backend=self.config.esmc_attn_backend,
@@ -852,13 +852,29 @@ class ESMFold2ExperimentalModel(PreTrainedModel):
852
  return_atom_repr=False,
853
  denoising_early_exit_rmsd=(0.10 if early_exit else None),
854
  )
855
- sample_coords = structure_output["sample_atom_coords"]
856
- assert sample_coords is not None
857
-
858
- output: dict[str, Tensor] = {
859
- "distogram_logits": distogram_logits,
860
- "sample_atom_coords": sample_coords,
861
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
862
  if calculate_confidence and self.confidence_head is not None:
863
  confidence_output = self.confidence_head(
864
  s_inputs=x_inputs.detach(),
 
521
  "bf16": torch.bfloat16,
522
  "fp32": torch.float32,
523
  }
524
+ if precision not in dtype_map:
525
+ if precision == "fp8":
526
+ raise RuntimeError(
527
+ "esmc_precision='fp8' is supported only by the standard "
528
+ "released ESMFold2 model. The experimental binder-design "
529
+ "model keeps the FastPLMs ESM++ backbone in bf16 or fp32."
530
+ )
531
+ raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}")
532
  esmc = _load_fastplms_esmplusplus_for_esmfold2(
533
  esmc_model_path=esmc_model_path,
534
  attn_backend=self.config.esmc_attn_backend,
 
852
  return_atom_repr=False,
853
  denoising_early_exit_rmsd=(0.10 if early_exit else None),
854
  )
855
+ sample_coords = structure_output["sample_atom_coords"]
856
+ assert sample_coords is not None
857
+ if sample_coords.ndim == 4:
858
+ batch, sample_count, atom_count, coord_dim = sample_coords.shape
859
+ sample_coords_for_gather = sample_coords.reshape(
860
+ batch * sample_count,
861
+ atom_count,
862
+ coord_dim,
863
+ )
864
+ rep_idx = distogram_atom_idx.repeat_interleave(sample_count, 0).long()
865
+ else:
866
+ sample_coords_for_gather = sample_coords
867
+ rep_idx = distogram_atom_idx.long()
868
+ representative_atom_coords = gather_rep_atom_coords(
869
+ sample_coords_for_gather,
870
+ rep_idx,
871
+ )
872
+
873
+ output: dict[str, Tensor] = {
874
+ "distogram_logits": distogram_logits,
875
+ "sample_atom_coords": sample_coords,
876
+ "representative_atom_coords": representative_atom_coords,
877
+ }
878
  if calculate_confidence and self.confidence_head is not None:
879
  confidence_output = self.confidence_head(
880
  s_inputs=x_inputs.detach(),