# Source: ESMFold2-Experimental-Fast / modeling_esmfold2_experimental.py
# Uploaded by lhallee ("Upload folder using huggingface_hub"),
# commit ad52206 (verified), file size 39.8 kB.
"""FastPLMs ESMFold2 experimental architecture.
This module supports Biohub's experimental binder-design checkpoints. The
released ESMFold2 architecture in ``modeling_esmfold2.py`` intentionally
rejects those configs because the experimental trunk uses explicit pair-loop
re-injection and a different confidence/MSA stack.
"""
from __future__ import annotations
import math
from pathlib import Path
from typing import Any, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers.modeling_utils import PreTrainedModel
from .configuration_esmfold2 import ESMFold2Config
from .modeling_esmfold2 import (
_load_fastplms_esmplusplus_for_esmfold2,
_lm_precision_context,
)
from .modeling_esmfold2_common import (
CHAR_VOCAB_SIZE,
MAX_ATOMIC_NUMBER,
NUM_RES_TYPES,
DiffusionModule,
DiffusionStructureHead,
DiffusionTransformer,
FoldingTrunk,
InputsEmbedder,
LanguageModelShim,
MSAPairWeightedAveraging,
OuterProductMean,
PairUpdateBlock,
ResIdxAsymIdSymIdEntityIdEncoding,
RowAttentionPooling,
SwiGLUMLP,
TriangleMultiplicativeUpdate,
_categorical_mean,
_compute_intra_token_idx,
_seed_context,
compute_lm_hidden_states,
gather_rep_atom_coords,
gather_token_to_atom,
)
# Small constant added to the denominators of the masked averages in
# ConfidenceHead.forward so that an empty mask does not divide by zero.
_EPS = 1e-5
# mol_type value that ConfidenceHead.forward compares against to flag a token
# as a ligand (non-polymer) token.
_NONPOLYMER_ID = 3
class ConfidenceHead(nn.Module):
    """Experimental confidence head predicting pLDDT, PAE, pTM, and ipTM.

    The head rebuilds a pair representation from the trunk pair features and
    the per-token inputs, adds an embedding of the binned distances between
    the predicted representative atoms, refines the result with its own
    ``FoldingTrunk`` and reads every confidence estimate off that refined
    pair representation.
    """

    boundaries: Tensor

    def __init__(self, config: ESMFold2Config) -> None:
        """Create the projections, the refinement trunk and the output heads.

        Args:
            config: Model configuration. Reads ``d_single``, ``d_pair``,
                ``inputs.d_inputs`` and the ``confidence_head`` sub-config
                (``min_dist``, ``max_dist``, ``distogram_bins``,
                ``folding_trunk.n_layers``, ``num_plddt_bins``,
                ``num_pae_bins``).
        """
        super().__init__()
        ch = config.confidence_head
        d_single = config.d_single
        d_pair = config.d_pair
        d_inputs = config.inputs.d_inputs

        # distogram_bins - 1 boundaries split the distance axis into
        # distogram_bins buckets (see the bucketing in forward()).
        boundaries = torch.linspace(ch.min_dist, ch.max_dist, ch.distogram_bins - 1)
        self.register_buffer("boundaries", boundaries)
        self.dist_bin_pairwise_embed = nn.Embedding(ch.distogram_bins, d_pair)

        # NOTE(review): s_norm, s_inputs_to_single and s_input_to_s are never
        # read by forward(); presumably kept so checkpoint keys still match.
        self.s_norm = nn.LayerNorm(d_single)
        self.s_inputs_to_single = nn.Linear(d_inputs, d_single, bias=False)
        self.s_to_z = nn.Linear(d_inputs, d_pair, bias=False)
        self.s_to_z_transpose = nn.Linear(d_inputs, d_pair, bias=False)
        self.s_to_z_prod_in1 = nn.Linear(d_inputs, d_pair, bias=False)
        self.s_to_z_prod_in2 = nn.Linear(d_inputs, d_pair, bias=False)
        self.s_to_z_prod_out = nn.Linear(d_pair, d_pair, bias=False)
        self.s_input_to_s = nn.Linear(d_inputs, d_single, bias=False)
        self.s_inputs_norm = nn.LayerNorm(d_inputs)
        self.z_norm = nn.LayerNorm(d_pair)
        self.row_attention_pooling = RowAttentionPooling(
            d_pair=d_pair, d_single=d_single
        )

        pf = ch.folding_trunk
        self.folding_trunk = FoldingTrunk(
            n_layers=pf.n_layers, d_pair=d_pair, expansion_ratio=4
        )

        self.plddt_ln = nn.LayerNorm(d_single)
        # One pLDDT projection per intra-token atom slot; forward() clamps
        # larger slot indices onto the last entry.
        max_atoms_per_token = 23
        self.plddt_weight = nn.Parameter(
            torch.zeros(max_atoms_per_token, d_single, ch.num_plddt_bins)
        )
        self.pae_head = nn.Linear(d_pair, ch.num_pae_bins, bias=False)

    def set_kernel_backend(self, backend: str | None) -> None:
        """Forward the kernel backend choice to the refinement trunk."""
        self.folding_trunk.set_kernel_backend(backend)

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Forward the chunk size to the refinement trunk."""
        self.folding_trunk.set_chunk_size(chunk_size)

    @staticmethod
    def _repeat_batch(x: Tensor, num_diffusion_samples: int) -> Tensor:
        """Repeat each batch entry ``num_diffusion_samples`` times along dim 0.

        Returns ``x`` itself (no copy) when only one sample is requested.
        """
        if num_diffusion_samples == 1:
            return x
        return x.repeat_interleave(num_diffusion_samples, 0)

    @staticmethod
    def _flatten_sample_axis(x: Tensor) -> Tensor:
        """Merge the first two axes of a 4-D tensor; return other ranks as is."""
        if x.ndim == 4:
            b, mult, n, c = x.shape
            return x.reshape(b * mult, n, c)
        return x

    def forward(
        self,
        s_inputs: Tensor,
        z: Tensor,
        x_pred: Tensor,
        distogram_atom_idx: Tensor,
        token_attention_mask: Tensor,
        atom_to_token: Tensor,
        atom_attention_mask: Tensor,
        asym_id: Tensor,
        mol_type: Tensor,
        num_diffusion_samples: int = 1,
        relative_position_encoding: Tensor | None = None,
        token_bonds_encoding: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Score predicted coordinates.

        Per-batch inputs are repeated ``num_diffusion_samples`` times along
        dim 0 so that they line up with the flattened samples of ``x_pred``.

        Args:
            s_inputs: Per-token input features (last dim ``d_inputs``).
            z: Trunk pair representation (last dim ``d_pair``).
            x_pred: Predicted atom coordinates; a 4-D tensor has its first two
                axes merged, any other rank is used unchanged.
            distogram_atom_idx: Index of each token's representative atom.
            token_attention_mask: Token validity mask.
            atom_to_token: Token index of every atom; used as a
                ``scatter_add_`` index, so it has to be an int64 tensor.
            atom_attention_mask: Atom validity mask.
            asym_id: Chain id per token.
            mol_type: Molecule type per token; ``_NONPOLYMER_ID`` marks
                ligand tokens.
            num_diffusion_samples: Number of samples per batch entry.
            relative_position_encoding: Optional additive pair feature.
            token_bonds_encoding: Optional additive pair feature.

        Returns:
            Dict with ``plddt_logits`` and ``pae_logits`` (attached to the
            graph) and the detached summaries ``plddt``, ``plddt_per_atom``,
            ``plddt_ca``, ``complex_plddt``, ``complex_iplddt``, ``pae``,
            ``ptm``, ``iptm`` and ``pair_chains_iptm``.
        """
        # --- pair representation seeded from trunk pair + token inputs -----
        s_inputs_normed = self.s_inputs_norm(s_inputs)
        z_base = self.z_norm(z)
        if relative_position_encoding is not None:
            z_base = z_base + relative_position_encoding
        if token_bonds_encoding is not None:
            z_base = z_base + token_bonds_encoding
        z_base = z_base + self.s_to_z(s_inputs_normed).unsqueeze(2)
        z_base = z_base + self.s_to_z_transpose(s_inputs_normed).unsqueeze(1)
        z_base = z_base + self.s_to_z_prod_out(
            self.s_to_z_prod_in1(s_inputs_normed)[:, :, None, :]
            * self.s_to_z_prod_in2(s_inputs_normed)[:, None, :, :]
        )

        # --- expand per-batch tensors to one entry per diffusion sample ----
        pair = self._repeat_batch(z_base, num_diffusion_samples)
        x_pred_flat = self._flatten_sample_axis(x_pred)
        atom_to_token_m = self._repeat_batch(atom_to_token, num_diffusion_samples)
        atom_mask_m = self._repeat_batch(atom_attention_mask, num_diffusion_samples)
        rep_idx_m = self._repeat_batch(distogram_atom_idx, num_diffusion_samples).long()
        mask = self._repeat_batch(token_attention_mask, num_diffusion_samples)
        batch_mult = pair.shape[0]

        # --- distance features of the predicted structure ------------------
        rep_coords = gather_rep_atom_coords(x_pred_flat, rep_idx_m)
        rep_distances = torch.cdist(
            rep_coords, rep_coords, compute_mode="donot_use_mm_for_euclid_dist"
        )
        # Bucket index = number of boundaries the distance exceeds.
        distogram_bins = (
            (rep_distances.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
        )
        pair = pair + self.dist_bin_pairwise_embed(distogram_bins)

        mask_f = mask.float()
        pair_mask = mask_f[:, :, None] * mask_f[:, None, :]
        pair = pair + self.folding_trunk(pair, pair_attention_mask=pair_mask)
        single = self.row_attention_pooling(pair, mask)

        # --- pLDDT ---------------------------------------------------------
        atom_mask_f = atom_mask_m.float()
        s_at_atoms = gather_token_to_atom(single, atom_to_token_m)
        s_at_atoms = self.plddt_ln(s_at_atoms)
        intra_idx = _compute_intra_token_idx(atom_to_token_m)
        intra_idx = intra_idx.clamp(max=self.plddt_weight.shape[0] - 1)
        plddt_weight = self.plddt_weight[intra_idx]
        plddt_logits = torch.einsum("...c,...cb->...b", s_at_atoms, plddt_weight)
        plddt_per_atom = _categorical_mean(plddt_logits, start=0.0, end=1.0)

        # Per-token pLDDT: masked mean of the per-atom values of each token.
        length = single.shape[1]
        plddt_sum = torch.zeros(
            batch_mult, length, device=single.device, dtype=plddt_per_atom.dtype
        )
        atom_count = torch.zeros(
            batch_mult, length, device=single.device, dtype=plddt_per_atom.dtype
        )
        atom_mask_t = atom_mask_f.to(plddt_per_atom.dtype)
        plddt_sum.scatter_add_(1, atom_to_token_m, plddt_per_atom * atom_mask_t)
        atom_count.scatter_add_(1, atom_to_token_m, atom_mask_t)
        plddt = plddt_sum / atom_count.clamp(min=1e-6)
        complex_plddt = (plddt_per_atom * atom_mask_f).sum(dim=-1) / (
            atom_mask_f.sum(dim=-1) + _EPS
        )

        # --- interface pLDDT -----------------------------------------------
        expanded_type = self._repeat_batch(mol_type, num_diffusion_samples)
        expanded_asym = self._repeat_batch(asym_id, num_diffusion_samples)
        is_ligand = (expanded_type == _NONPOLYMER_ID).float()
        inter_chain = (
            expanded_asym.unsqueeze(-1) != expanded_asym.unsqueeze(-2)
        ).float()
        # Padded tokens still receive gathered coordinates, so without the
        # pair mask they could mark a real token as "in contact with another
        # chain". Only valid-valid token pairs may count as contacts.
        near_contact = (rep_distances < 8).float() * pair_mask
        interface_per_token = (
            near_contact * inter_chain * (1.0 - is_ligand).unsqueeze(-1)
        ).amax(dim=-1)
        # Ligand tokens always contribute, with weight 2; polymer tokens
        # contribute (weight 1) only when they touch another chain.
        iplddt_weight = torch.where(
            is_ligand.bool(),
            torch.full_like(interface_per_token, 2.0),
            interface_per_token,
        )
        iplddt_weight_atoms = gather_token_to_atom(
            iplddt_weight.unsqueeze(-1), atom_to_token_m
        ).squeeze(-1)
        atom_iplddt_w = atom_mask_f * iplddt_weight_atoms
        complex_iplddt = (plddt_per_atom * atom_iplddt_w).sum(dim=-1) / (
            atom_iplddt_w.sum(dim=-1) + _EPS
        )
        plddt_ca = plddt_per_atom.gather(1, rep_idx_m)

        # --- PAE and TM-style scores ---------------------------------------
        pae_logits = self.pae_head(pair)
        pae = _categorical_mean(pae_logits, start=0.0, end=32.0).detach()
        n_bins = pae_logits.shape[-1]
        bin_width = 32.0 / n_bins
        bin_centers = torch.arange(
            0.5 * bin_width, 32.0, bin_width, device=pae_logits.device
        )
        n_res = mask_f.sum(dim=-1, keepdim=True)
        # TM-score normalisation length; the clamp keeps the cube root's
        # argument positive for very short inputs.
        d0 = 1.24 * (n_res.clamp(min=19) - 15) ** (1 / 3) - 1.8
        tm_per_bin = 1 / (1 + (bin_centers / d0) ** 2)
        pae_probs = F.softmax(pae_logits, dim=-1)
        tm_expected = (pae_probs * tm_per_bin[:, None, None, :]).sum(dim=-1)

        # pTM / ipTM: best row of the masked row-wise mean.
        ptm_per_row = (tm_expected * pair_mask).sum(dim=-1) / (
            pair_mask.sum(dim=-1) + _EPS
        )
        ptm = ptm_per_row.max(dim=-1).values
        inter_chain_mask = inter_chain * pair_mask
        iptm_per_row = (tm_expected * inter_chain_mask).sum(dim=-1) / (
            inter_chain_mask.sum(dim=-1) + _EPS
        )
        iptm = iptm_per_row.max(dim=-1).values

        # --- chain-by-chain mean of the expected TM term -------------------
        max_chain_id = int(expanded_asym.max().item()) if batch_mult > 0 else 0
        n_chains = max_chain_id + 1
        pair_chains_iptm = torch.zeros(
            batch_mult,
            n_chains,
            n_chains,
            device=tm_expected.device,
            dtype=tm_expected.dtype,
        )
        # Chain membership masks do not depend on the partner chain, so build
        # them once instead of once per (c1, c2) combination.
        chain_masks = [
            (expanded_asym == chain_id).float() * mask_f
            for chain_id in range(n_chains)
        ]
        for c1, chain_c1 in enumerate(chain_masks):
            if chain_c1.sum() == 0:
                # Unused chain id: leave its row at zero.
                continue
            for c2, chain_c2 in enumerate(chain_masks):
                pair_m = chain_c1.unsqueeze(-1) * chain_c2.unsqueeze(-2)
                denom = pair_m.sum(dim=(-1, -2)) + _EPS
                pair_chains_iptm[:, c1, c2] = (tm_expected * pair_m).sum(
                    dim=(-1, -2)
                ) / denom

        return {
            "plddt_logits": plddt_logits,
            "plddt": plddt.detach(),
            "plddt_per_atom": plddt_per_atom.detach(),
            "plddt_ca": plddt_ca.detach(),
            "complex_plddt": complex_plddt.detach(),
            "complex_iplddt": complex_iplddt.detach(),
            "pae_logits": pae_logits,
            "pae": pae,
            "ptm": ptm.detach(),
            "iptm": iptm.detach(),
            "pair_chains_iptm": pair_chains_iptm.detach(),
        }
class _TransitionFFN(nn.Module):
    """LayerNorm followed by a bias-free SwiGLU MLP.

    Only the update is returned; adding it back onto the input is left to the
    caller.
    """

    def __init__(self, d_model: int, expansion_ratio: int = 4) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ffn = SwiGLUMLP(d_model, expansion_ratio=expansion_ratio, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``ffn(norm(x))``."""
        normalized = self.norm(x)
        return self.ffn(normalized)
class MSAEncoderBlock(nn.Module):
"""One experimental MSA update block."""
def __init__(
self,
d_msa: int,
d_pair: int,
d_hidden: int = 32,
n_heads_msa: int = 8,
msa_head_width: int = 32,
) -> None:
super().__init__()
self.outer_product_mean = OuterProductMean(
d_msa, d_hidden, d_pair, divide_outer_before_proj=True
)
self.msa_pair_weighted_averaging = MSAPairWeightedAveraging(
d_msa, d_pair, n_heads_msa, msa_head_width
)
self.msa_transition = _TransitionFFN(d_msa, expansion_ratio=4)
self.tri_mul_out = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=True)
self.tri_mul_in = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=False)
self.pair_transition = _TransitionFFN(d_pair, expansion_ratio=4)
def set_chunk_size(self, chunk_size: int | None) -> None:
self.outer_product_mean.set_chunk_size(chunk_size)
self.tri_mul_out.set_chunk_size(chunk_size)
self.tri_mul_in.set_chunk_size(chunk_size)
def forward(
self,
msa_repr: Tensor,
pair_repr: Tensor,
msa_attention_mask: Tensor,
pair_attention_mask: Tensor,
msa_track_mask: Tensor | None = None,
) -> tuple[Tensor, Tensor]:
mask4d = (
msa_track_mask[:, None, None, None].to(dtype=msa_repr.dtype)
if msa_track_mask is not None
else None
)
pair_mask4d = mask4d[:, :, :1] if mask4d is not None else None
msa_update = self.msa_pair_weighted_averaging(
msa_repr, pair_repr, pair_attention_mask
)
if mask4d is not None:
msa_update = msa_update * mask4d
msa_repr = msa_repr + msa_update
msa_transition = self.msa_transition(msa_repr)
if mask4d is not None:
msa_transition = msa_transition * mask4d
msa_repr = msa_repr + msa_transition
pair_opm = self.outer_product_mean(msa_repr, msa_attention_mask)
if pair_mask4d is not None:
pair_opm = pair_opm * pair_mask4d
pair_repr = pair_repr + pair_opm
pair_out = self.tri_mul_out(pair_repr, mask=pair_attention_mask)
if pair_mask4d is not None:
pair_out = pair_out * pair_mask4d
pair_repr = pair_repr + pair_out
pair_in = self.tri_mul_in(pair_repr, mask=pair_attention_mask)
if pair_mask4d is not None:
pair_in = pair_in * pair_mask4d
pair_repr = pair_repr + pair_in
pair_transition = self.pair_transition(pair_repr)
if pair_mask4d is not None:
pair_transition = pair_transition * pair_mask4d
pair_repr = pair_repr + pair_transition
return msa_repr, pair_repr
class MSAEncoder(nn.Module):
    """Stack of ``MSAEncoderBlock`` layers producing a pair-representation update."""

    def __init__(
        self,
        d_msa: int,
        d_pair: int,
        d_inputs: int,
        d_hidden: int = 32,
        n_layers: int = 4,
        n_heads_msa: int = 8,
        msa_head_width: int = 32,
    ) -> None:
        super().__init__()
        # 35 is the width of the feature vector built in forward(): the MSA
        # encoding plus the has_deletion and deletion_value channels.
        self.embed = nn.Linear(35, d_msa, bias=False)
        self.project_inputs = nn.Linear(d_inputs, d_msa, bias=False)
        block_kwargs = {
            "d_msa": d_msa,
            "d_pair": d_pair,
            "d_hidden": d_hidden,
            "n_heads_msa": n_heads_msa,
            "msa_head_width": msa_head_width,
        }
        self.blocks = nn.ModuleList(
            [MSAEncoderBlock(**block_kwargs) for _ in range(n_layers)]
        )

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Pass the chunk size on to every block."""
        for layer in self.blocks:
            cast(MSAEncoderBlock, layer).set_chunk_size(chunk_size)

    def forward(
        self,
        x_pair: Tensor,
        x_inputs: Tensor,
        msa_oh: Tensor,
        has_deletion: Tensor,
        deletion_value: Tensor,
        msa_attention_mask: Tensor,
    ) -> Tensor:
        """Run the MSA stack and return the resulting pair representation.

        Args:
            x_pair: Pair representation to refine.
            x_inputs: Per-token input features, broadcast over the MSA depth.
            msa_oh: MSA encoding; concatenated with the two deletion channels
                along the last axis.
            has_deletion: Deletion flag per MSA cell.
            deletion_value: Deletion value per MSA cell.
            msa_attention_mask: 3-D mask unpacked as (batch, token, depth).

        Returns:
            The pair representation after all blocks, multiplied by zero for
            every batch entry that has no valid MSA cell beyond depth index 0
            (always the case when the depth is 1).
        """
        batch_size, _, depth = msa_attention_mask.shape
        features = torch.cat(
            (msa_oh, has_deletion[..., None], deletion_value[..., None]),
            dim=-1,
        )
        msa_repr = self.embed(features) + self.project_inputs(x_inputs).unsqueeze(2)

        # A batch entry only has a usable MSA track when some row other than
        # the first one contains a valid cell.
        if depth > 1:
            msa_track_mask = msa_attention_mask[:, :, 1:].any(dim=(1, 2))
        else:
            msa_track_mask = torch.zeros(
                batch_size, dtype=torch.bool, device=x_pair.device
            )

        first_row_mask = msa_attention_mask[:, :, 0]
        pair_attention_mask = first_row_mask.unsqueeze(2) * first_row_mask.unsqueeze(1)

        for layer in self.blocks:
            msa_repr, x_pair = cast(MSAEncoderBlock, layer)(
                msa_repr,
                x_pair,
                msa_attention_mask,
                pair_attention_mask,
                msa_track_mask,
            )

        track_gate = msa_track_mask[:, None, None, None].to(dtype=x_pair.dtype)
        return x_pair * track_gate
class ESMFold2ExperimentalModel(PreTrainedModel):
"""Experimental ESMFold2 architecture used by binder-design checkpoints."""
# Configuration class paired with this model (from_pretrained below loads it
# through ESMFold2Config.from_pretrained when no config is passed in).
config_class = ESMFold2Config
# Regex list of checkpoint keys tolerated as "unexpected" when loading: any
# key that ends in "._extra_state".
_keys_to_ignore_on_load_unexpected = [r"\._extra_state$"]
def __init__(self, config: ESMFold2Config) -> None:
    """Build every sub-module described by ``config``.

    The confidence head and the MSA encoder are created only when enabled in
    their sub-configs; otherwise the attributes are ``None``. The ESM++
    backbone is not created here, it is attached later by ``load_esmc``.
    """
    super().__init__(config)
    d_inputs = config.inputs.d_inputs
    d_pair = config.d_pair

    # Input embedding and the terms that make up the initial pair tensor.
    self.inputs_embedder = InputsEmbedder(config)
    self.z_init_1 = nn.Linear(d_inputs, d_pair, bias=False)
    self.z_init_2 = nn.Linear(d_inputs, d_pair, bias=False)
    self.rel_pos = ResIdxAsymIdSymIdEntityIdEncoding(
        n_relative_residx_bins=config.n_relative_residx_bins,
        n_relative_chain_bins=config.n_relative_chain_bins,
        d_pair=d_pair,
    )
    self.token_bonds = nn.Linear(1, d_pair, bias=False)
    self.language_model = LanguageModelShim(
        d_z=d_pair, d_model=config.lm_d_model, num_layers=config.lm_num_layers
    )

    # Lazily populated state: backbone, its precision flag, input builder.
    self._esmc: nn.Module | None = None
    self._esmc_fp8 = False
    self._esmfold2_input_builder: Any | None = None

    trunk_cfg = config.folding_trunk
    self.folding_trunk = FoldingTrunk(
        n_layers=trunk_cfg.n_layers, d_pair=d_pair, expansion_ratio=4
    )

    # Pair re-injection between loops; the linear weight starts at zero.
    loop_norm = nn.LayerNorm(d_pair)
    loop_linear = nn.Linear(d_pair, d_pair, bias=False)
    nn.init.zeros_(loop_linear.weight)
    self.pair_loop_proj = nn.Sequential(loop_norm, loop_linear)

    self.structure_head = DiffusionStructureHead(config)
    self.distogram_head = nn.Linear(
        d_pair, config.structure_head.distogram_bins, bias=True
    )

    self.confidence_head: ConfidenceHead | None = None
    if config.confidence_head.enabled:
        self.confidence_head = ConfidenceHead(config)

    msa_cfg = config.msa_encoder
    self.msa_encoder: MSAEncoder | None = (
        MSAEncoder(
            d_msa=msa_cfg.d_msa,
            d_pair=d_pair,
            d_inputs=d_inputs,
            d_hidden=msa_cfg.d_hidden,
            n_layers=msa_cfg.n_layers,
            n_heads_msa=msa_cfg.n_heads_msa,
            msa_head_width=msa_cfg.msa_head_width,
        )
        if msa_cfg.enabled
        else None
    )

    self.post_init()
@property
def device(self) -> torch.device:
    """Device of the first parameter yielded by ``self.parameters()``."""
    first_parameter = next(self.parameters())
    return first_parameter.device
def set_kernel_backend(self, backend: str | None) -> None:
self.folding_trunk.set_kernel_backend(backend)
if self.confidence_head is not None:
self.confidence_head.set_kernel_backend(backend)
self.structure_head.set_kernel_backend(backend)
def set_chunk_size(self, chunk_size: int | None) -> None:
self.folding_trunk.set_chunk_size(chunk_size)
if self.confidence_head is not None:
self.confidence_head.set_chunk_size(chunk_size)
if self.msa_encoder is not None:
self.msa_encoder.set_chunk_size(chunk_size)
def configure_lm_dropout(
    self,
    lm_dropout: float,
    *,
    force_lm_dropout_during_inference: bool = True,
) -> None:
    """Store the language-model dropout settings on ``self.config``.

    Args:
        lm_dropout: Value written to ``config.lm_dropout``.
        force_lm_dropout_during_inference: Value written to
            ``config.force_lm_dropout_during_inference``.
    """
    cfg = self.config
    cfg.lm_dropout = lm_dropout
    cfg.force_lm_dropout_during_inference = force_lm_dropout_during_inference
def load_esmc(self, esmc_model_path: str, precision: str = "bf16") -> None:
    """Load the frozen ESM++ backbone and attach it as ``self._esmc``.

    Args:
        esmc_model_path: Path or id handed to the ESM++ loader.
        precision: ``"bf16"`` or ``"fp32"``.

    Raises:
        RuntimeError: If ``precision`` is ``"fp8"``.
        ValueError: If ``precision`` is any other unsupported value.
        AssertionError: If the loaded backbone's hidden size or layer count
            does not match ``lm_d_model`` / ``lm_num_layers`` of the config.
    """
    supported = {
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    dtype = supported.get(precision)
    if dtype is None:
        if precision == "fp8":
            raise RuntimeError(
                "esmc_precision='fp8' is supported only by the standard "
                "released ESMFold2 model. The experimental binder-design "
                "model keeps the FastPLMs ESM++ backbone in bf16 or fp32."
            )
        raise ValueError(f"precision must be one of {list(supported)}, got {precision!r}")

    backbone = _load_fastplms_esmplusplus_for_esmfold2(
        esmc_model_path=esmc_model_path,
        attn_backend=self.config.esmc_attn_backend,
        device=self.device,
        dtype=dtype,
    )

    # The shim inside this model was sized from the config, so the backbone
    # has to agree with it.
    assert backbone.config.hidden_size == self.config.lm_d_model, (
        f"ESMFold2 expected lm_d_model={self.config.lm_d_model}, "
        f"but loaded ESM++ hidden_size={backbone.config.hidden_size}."
    )
    assert backbone.config.num_hidden_layers == self.config.lm_num_layers, (
        f"ESMFold2 expected lm_num_layers={self.config.lm_num_layers}, "
        f"but loaded ESM++ num_hidden_layers={backbone.config.num_hidden_layers}."
    )

    # The backbone is used frozen.
    for weight in backbone.parameters():
        weight.requires_grad_(False)

    self._esmc_fp8 = False
    self._esmc = backbone
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path,
    *model_args,
    load_esmc: bool = True,
    **kwargs,
):
    """Load the model and, unless disabled, its ESM++ backbone.

    Args:
        pretrained_model_name_or_path: Checkpoint location, passed through.
        *model_args: Passed through to the parent ``from_pretrained``.
        load_esmc: When true, ``load_esmc`` is called with
            ``model.config.esmc_id`` after the weights are loaded.
        **kwargs: Passed through. ``esmc_precision`` (default ``"bf16"``) is
            removed first and used as the backbone precision. When no
            ``config`` is supplied, one is loaded with ``ESMFold2Config``.

    Returns:
        The loaded model.
    """
    # NOTE(review): the config is loaded before ``esmc_precision`` is popped,
    # so that key is also forwarded to ESMFold2Config.from_pretrained.
    config_missing = "config" not in kwargs
    if config_missing:
        kwargs["config"] = ESMFold2Config.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
    esmc_precision = kwargs.pop("esmc_precision", "bf16")

    model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
    if not load_esmc:
        return model
    model.load_esmc(model.config.esmc_id, precision=esmc_precision)
    return model
def apply_torch_compile(
    self, mode: str = "fixed_seqlen", dynamic: bool | None = None
) -> None:
    """Wrap the ``forward`` of selected sub-modules with ``torch.compile``.

    Args:
        mode: Only consulted when ``dynamic`` is ``None``; the value
            ``"dynamic_seqlen"`` turns dynamic shapes on, any other value
            leaves them off.
        dynamic: Explicit value for ``torch.compile(dynamic=...)``.
    """
    import torch._dynamo

    # Process-wide dynamo settings; they stay in effect after this call.
    dynamo_config = torch._dynamo.config
    dynamo_config.cache_size_limit = 512
    dynamo_config.accumulated_cache_size_limit = 512
    dynamo_config.capture_scalar_outputs = True

    use_dynamic = (mode == "dynamic_seqlen") if dynamic is None else dynamic
    compiled_types = (
        PairUpdateBlock,
        DiffusionTransformer,
        DiffusionModule,
        MSAEncoderBlock,
    )

    def _compile_forward(module: nn.Module) -> None:
        if not isinstance(module, compiled_types):
            return
        module.forward = torch.compile(module.forward, dynamic=use_dynamic)

    self.apply(_compile_forward)
def _compute_lm_hidden_states(
    self,
    input_ids: Tensor,
    asym_id: Tensor,
    residue_index: Tensor,
    mol_type: Tensor,
    tok_mask: Tensor,
) -> Tensor:
    """Run the loaded ESM++ backbone and return its hidden states.

    Requires ``load_esmc`` to have been called. When the fp8 flag is set the
    inputs are padded to a multiple of 8; otherwise no padding is requested.
    """
    assert self._esmc is not None
    use_fp8 = self._esmc_fp8
    pad_multiple = 8 if use_fp8 else None
    with _lm_precision_context(use_fp8):
        hidden_states = compute_lm_hidden_states(
            self._esmc,
            input_ids,
            asym_id,
            residue_index,
            mol_type,
            tok_mask,
            pad_to_multiple=pad_multiple,
        )
        return hidden_states
def forward(
self,
token_index: Tensor,
residue_index: Tensor,
asym_id: Tensor,
sym_id: Tensor,
entity_id: Tensor,
mol_type: Tensor,
res_type: Tensor,
token_bonds: Tensor,
token_attention_mask: Tensor,
ref_pos: Tensor,
ref_element: Tensor,
ref_charge: Tensor,
ref_atom_name_chars: Tensor,
ref_space_uid: Tensor,
atom_attention_mask: Tensor,
atom_to_token: Tensor,
distogram_atom_idx: Tensor,
deletion_mean: Tensor | None = None,
msa: Tensor | None = None,
has_deletion: Tensor | None = None,
deletion_value: Tensor | None = None,
msa_attention_mask: Tensor | None = None,
input_ids: Tensor | None = None,
lm_hidden_states: Tensor | None = None,
res_type_soft: Tensor | None = None,
num_loops: int | None = None,
num_diffusion_samples: int | None = None,
num_sampling_steps: int | None = None,
early_exit: bool = False,
seed: int | None = None,
calculate_confidence: bool = True,
provide_soft_sequence_to_msa_and_profile: bool = True,
noise_scale: float | None = None,
step_scale: float | None = None,
max_inference_sigma: int | None = None,
) -> dict[str, Tensor]:
# Full prediction pass: embed the inputs, build the initial pair tensor,
# iterate the pair trunk for up to n_loops + 1 passes, sample coordinates
# with the diffusion structure head and, when requested and available, score
# them with the confidence head.
#
# Returns a dict with "distogram_logits", "sample_atom_coords",
# "representative_atom_coords", the confidence head outputs (if computed),
# "atom_pad_mask", "residue_index" and "entity_id".
#
# NOTE(review): this copy of the file has lost its indentation, so the extent
# of the `with (set_grad_enabled, autocast)` block below cannot be read off
# the text. Confirm against the upstream file which statements it encloses
# before re-indenting.
#
# noise_scale, step_scale and max_inference_sigma are accepted for signature
# compatibility but discarded immediately.
del noise_scale, step_scale, max_inference_sigma
tok_mask = token_attention_mask
atm_mask = atom_attention_mask
# Per-call overrides fall back to the values stored in the config.
n_loops = num_loops if num_loops is not None else self.config.num_loops
n_samples = (
num_diffusion_samples
if num_diffusion_samples is not None
else self.config.num_diffusion_samples
)
# A 2-D res_type holds class indices and is one-hot encoded (padded tokens
# zeroed); any other rank is taken as an already-encoded float tensor.
if res_type.dim() == 2:
res_type_oh = F.one_hot(res_type.long(), num_classes=NUM_RES_TYPES).float()
res_type_oh = res_type_oh * tok_mask.unsqueeze(-1).float()
else:
res_type_oh = res_type.float()
# Profile: mean of the one-hot MSA over dim 1 (masked when a mask is given),
# or the sequence encoding itself when there is no MSA.
if msa is not None:
msa_oh_profile = F.one_hot(msa.long(), num_classes=NUM_RES_TYPES).float()
if msa_attention_mask is not None:
mask_f = msa_attention_mask.float().unsqueeze(-1)
msa_oh_profile = msa_oh_profile * mask_f
valid_seq_count = msa_attention_mask.float().sum(dim=1).clamp(min=1)
profile = msa_oh_profile.sum(dim=1) / valid_seq_count.unsqueeze(-1)
else:
profile = msa_oh_profile.mean(dim=1)
else:
profile = res_type_oh
# A soft sequence replaces the sequence encoding and, unless MSA features are
# disabled or the caller opts out, also replaces profile and MSA with a
# depth-1 MSA made of the soft sequence.
if res_type_soft is not None:
res_type_oh = res_type_soft.float()
if (
not self.config.disable_msa_features
and provide_soft_sequence_to_msa_and_profile
):
profile = res_type_oh
msa = res_type_oh.unsqueeze(1)
msa_attention_mask = tok_mask.unsqueeze(1)
if deletion_mean is None:
deletion_mean = torch.zeros(
res_type.shape[0], res_type.shape[1], device=res_type.device
)
if self.config.disable_msa_features:
profile = torch.zeros_like(profile)
deletion_mean = torch.zeros_like(deletion_mean)
# One-hot atom features, zeroed for masked atoms.
ref_element_oh = F.one_hot(
ref_element.long(), num_classes=MAX_ATOMIC_NUMBER
).float()
ref_atom_name_chars_oh = F.one_hot(
ref_atom_name_chars.long(), num_classes=CHAR_VOCAB_SIZE
).float()
atm_mask_f = atm_mask.float()
ref_element_oh = ref_element_oh * atm_mask_f.unsqueeze(-1)
ref_atom_name_chars_oh = ref_atom_name_chars_oh * atm_mask_f.unsqueeze(
-1
).unsqueeze(-1)
# Masked atoms are pointed at token 0 so the index stays in range.
atom_to_token = atom_to_token * atm_mask.long()
# bf16 autocast is enabled only when ref_pos lives on a CUDA device;
# gradients are enabled only when a soft sequence was supplied.
use_amp = ref_pos.device.type == "cuda"
with (
torch.set_grad_enabled(res_type_soft is not None),
torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.bfloat16),
):
x_inputs = self.inputs_embedder(
aatype=res_type_oh,
profile=profile.float(),
deletion_mean=deletion_mean.float(),
ref_pos=ref_pos,
atom_attention_mask=atm_mask,
ref_space_uid=ref_space_uid,
ref_charge=ref_charge,
ref_element=ref_element_oh,
ref_atom_name_chars=ref_atom_name_chars_oh,
atom_to_token=atom_to_token,
)
# Initial pair tensor: outer sum of two projections of the token inputs,
# plus relative-position and token-bond encodings.
z_init = self.z_init_1(x_inputs).unsqueeze(2) + self.z_init_2(
x_inputs
).unsqueeze(1)
relative_position_encoding = self.rel_pos(
residue_index=residue_index,
asym_id=asym_id,
sym_id=sym_id,
entity_id=entity_id,
token_index=token_index,
)
token_bonds_encoding = self.token_bonds(token_bonds.float())
z_init = z_init + relative_position_encoding + token_bonds_encoding
# Language-model hidden states are computed here only when none were passed
# in, input_ids are available and a backbone has been loaded.
if (
lm_hidden_states is None
and input_ids is not None
and self._esmc is not None
):
lm_hidden_states = self._compute_lm_hidden_states(
input_ids, asym_id, residue_index, mol_type, tok_mask
)
if lm_hidden_states is not None:
# Dropout applies in training, or at inference when the config forces it.
lm_dropout = (
self.config.lm_dropout
if self.config.force_lm_dropout_during_inference or self.training
else 0.0
)
# The hidden states are detached: no gradient reaches the backbone.
lm_z = self.language_model(
lm_hidden_states.detach(), lm_dropout=lm_dropout
)
z_init = z_init + lm_z.to(z_init.dtype)
# MSA encoder inputs are prepared once and reused in every loop. MSA tensors
# are permuted from (batch, depth, length) to (batch, length, depth) order.
msa_kwargs: dict[str, Tensor] | None = None
if self.msa_encoder is not None and msa is not None:
if msa.dim() == 4:
batch_msa, depth, length_msa, _ = msa.shape
msa_oh = msa.permute(0, 2, 1, 3).float()
else:
batch_msa, depth, length_msa = msa.shape
msa_oh = F.one_hot(
msa.permute(0, 2, 1).long(), num_classes=NUM_RES_TYPES
).float()
msa_attn = (
msa_attention_mask.permute(0, 2, 1).float()
if msa_attention_mask is not None
else tok_mask[:, :, None].expand(-1, -1, depth).float()
)
msa_oh = msa_oh * msa_attn.unsqueeze(-1)
hd = (
has_deletion.permute(0, 2, 1).float()
if has_deletion is not None
else torch.zeros(batch_msa, length_msa, depth, device=msa.device)
)
dv = (
deletion_value.permute(0, 2, 1).float()
if deletion_value is not None
else torch.zeros(batch_msa, length_msa, depth, device=msa.device)
)
msa_kwargs = {
"x_inputs": x_inputs,
"msa_oh": msa_oh,
"has_deletion": hd,
"deletion_value": dv,
"msa_attention_mask": msa_attn,
}
pair_mask = tok_mask[:, :, None].float() * tok_mask[:, None, :].float()
# Pair loop: every pass restarts from z_init plus a projection of the
# previous pass's pair tensor (zeros on the first pass).
z = torch.zeros_like(z_init)
prev_pair: Tensor | None = None
prev_disto_probs: Tensor | None = None
for loop_num in range(n_loops + 1):
z = z_init + self.pair_loop_proj(z)
if msa_kwargs is not None and self.msa_encoder is not None:
z = z + self.msa_encoder(x_pair=z, **msa_kwargs).to(z.dtype)
z = self.folding_trunk(z, pair_attention_mask=pair_mask)
# Early exit (never evaluated on the final pass): stop once the relative L2
# change of the pair tensor is below 0.25 or the symmetrised KL divergence
# between consecutive distogram distributions is below 0.05. Both tests call
# .item(), which forces a host synchronisation.
if early_exit and loop_num < n_loops:
l2_converged = False
if prev_pair is not None and loop_num > 0:
rel_l2 = (z.float() - prev_pair.float()).norm() / prev_pair.float().norm().clamp(
min=1e-8
)
l2_converged = rel_l2.item() < 0.25
prev_pair = z.detach().clone()
sym_z = z.float() + z.float().transpose(-2, -3)
cur_probs = F.softmax(self.distogram_head(sym_z).float(), dim=-1)
if prev_disto_probs is not None and loop_num > 0:
kl_per_pair = (
cur_probs
* (
cur_probs.clamp(min=1e-8)
/ prev_disto_probs.clamp(min=1e-8)
).log()
).sum(-1)
kl = (kl_per_pair + kl_per_pair.transpose(-1, -2)).mean() / 2
if l2_converged or kl.item() < 0.05:
break
prev_disto_probs = cur_probs.detach()
# Distogram logits come from the symmetrised pair tensor.
distogram_logits = self.distogram_head(z + z.transpose(-2, -3))
# Coordinates are sampled without gradients, inside the seed context.
with torch.no_grad(), _seed_context(seed):
structure_output = self.structure_head.sample(
z_trunk=z.float(),
s_inputs=x_inputs,
s_trunk=None,
relative_position_encoding=relative_position_encoding,
ref_pos=ref_pos,
ref_charge=ref_charge,
ref_mask=atm_mask,
ref_element=ref_element_oh,
ref_atom_name_chars=ref_atom_name_chars_oh,
ref_space_uid=ref_space_uid,
tok_idx=atom_to_token,
asym_id=asym_id,
residue_index=residue_index,
entity_id=entity_id,
token_index=token_index,
sym_id=sym_id,
token_attention_mask=tok_mask,
num_diffusion_samples=n_samples,
num_sampling_steps=num_sampling_steps,
return_atom_repr=False,
denoising_early_exit_rmsd=(0.10 if early_exit else None),
)
sample_coords = structure_output["sample_atom_coords"]
assert sample_coords is not None
# 4-D coordinates carry a separate sample axis; flatten it into the batch axis
# (and repeat the representative-atom indices to match) before gathering.
if sample_coords.ndim == 4:
batch, sample_count, atom_count, coord_dim = sample_coords.shape
sample_coords_for_gather = sample_coords.reshape(
batch * sample_count,
atom_count,
coord_dim,
)
rep_idx = distogram_atom_idx.repeat_interleave(sample_count, 0).long()
else:
sample_coords_for_gather = sample_coords
rep_idx = distogram_atom_idx.long()
representative_atom_coords = gather_rep_atom_coords(
sample_coords_for_gather,
rep_idx,
)
output: dict[str, Tensor] = {
"distogram_logits": distogram_logits,
"sample_atom_coords": sample_coords,
"representative_atom_coords": representative_atom_coords,
}
# The confidence head only sees detached inputs, so nothing it computes
# propagates gradients back into the trunk or the sampled coordinates.
if calculate_confidence and self.confidence_head is not None:
confidence_output = self.confidence_head(
s_inputs=x_inputs.detach(),
z=z.detach().float(),
x_pred=sample_coords.detach(),
distogram_atom_idx=distogram_atom_idx,
token_attention_mask=tok_mask,
atom_to_token=atom_to_token,
atom_attention_mask=atm_mask,
asym_id=asym_id,
mol_type=mol_type,
num_diffusion_samples=n_samples,
relative_position_encoding=relative_position_encoding.detach(),
token_bonds_encoding=token_bonds_encoding.detach(),
)
output.update(confidence_output)
# A 1-D atom mask gets a leading batch axis; other ranks are returned as is.
output["atom_pad_mask"] = (
atm_mask.unsqueeze(0) if atm_mask.dim() == 1 else atm_mask
)
output["residue_index"] = residue_index
output["entity_id"] = entity_id
return output
@property
def input_builder(self):
    """Input builder instance, created on first access and then reused."""
    builder = self._esmfold2_input_builder
    if builder is None:
        # Imported here so the processor module is only loaded when needed.
        from .esmfold2_processor import ESMFold2InputBuilder

        builder = ESMFold2InputBuilder()
        self._esmfold2_input_builder = builder
    return builder
@property
def input_types(self):
    """The ``esmfold2_types`` module of this package, imported on access."""
    from . import esmfold2_types as types_module

    return types_module
def prepare_structure_input(self, input, seed: int | None = None):
return self.input_builder.prepare_input(input, seed=seed, device=self.device)
@torch.no_grad()
def infer_protein(self, seq: str, **forward_kwargs) -> dict[str, Tensor]:
    """Featurize one protein sequence, run the model and return its outputs.

    Args:
        seq: Protein sequence.
        **forward_kwargs: Extra keyword arguments for the forward pass.

    Returns:
        The forward output dict, with six of the input feature tensors
        (``res_type``, ``atom_to_token``, ``ref_atom_name_chars``,
        ``atom_attention_mask``, ``token_attention_mask``, ``residue_index``)
        copied into it.
    """
    from .protein_utils import prepare_protein_features

    features = {}
    for name, tensor in prepare_protein_features(seq).items():
        features[name] = tensor.to(self.device)

    output = self(**features, **forward_kwargs)

    passthrough = (
        "res_type",
        "atom_to_token",
        "ref_atom_name_chars",
        "atom_attention_mask",
        "token_attention_mask",
        "residue_index",
    )
    output.update({name: features[name] for name in passthrough})
    return output
def fold(
self,
input,
*,
num_loops: int = 3,
num_sampling_steps: int = 50,
num_diffusion_samples: int = 1,
seed: int | None = None,
noise_scale: float | None = None,
step_scale: float | None = None,
max_inference_sigma: int | None = None,
early_exit: bool = False,
complex_id: str = "pred",
):
return self.input_builder.fold(
self,
input,
num_loops=num_loops,
num_sampling_steps=num_sampling_steps,
num_diffusion_samples=num_diffusion_samples,
seed=seed,
noise_scale=noise_scale,
step_scale=step_scale,
max_inference_sigma=max_inference_sigma,
early_exit=early_exit,
complex_id=complex_id,
)
def fold_protein(
    self,
    sequence: str,
    *,
    chain_id: str = "A",
    num_loops: int = 3,
    num_sampling_steps: int = 50,
    num_diffusion_samples: int = 1,
    seed: int | None = None,
    complex_id: str = "pred",
):
    """Fold a single protein chain.

    Wraps ``sequence`` in a one-chain ``StructurePredictionInput`` with the
    given ``chain_id`` and forwards the remaining options to ``fold``.

    Returns:
        Whatever ``fold`` returns.
    """
    from .esmfold2_types import ProteinInput, StructurePredictionInput

    chain = ProteinInput(id=chain_id, sequence=sequence)
    single_chain_input = StructurePredictionInput(sequences=[chain])
    return self.fold(
        single_chain_input,
        num_loops=num_loops,
        num_sampling_steps=num_sampling_steps,
        num_diffusion_samples=num_diffusion_samples,
        seed=seed,
        complex_id=complex_id,
    )
@staticmethod
def result_to_cif(result) -> str:
assert not isinstance(result, list), "Pass one MolecularComplexResult at a time."
return result.complex.to_mmcif()
@staticmethod
def result_to_pdb(result) -> str:
assert not isinstance(result, list), "Pass one MolecularComplexResult at a time."
return result.complex.to_protein_complex().to_pdb_string()
def save_as_cif(self, result, output_path: str | Path) -> None:
Path(output_path).write_text(self.result_to_cif(result))
def save_as_pdb(self, result, output_path: str | Path) -> None:
Path(output_path).write_text(self.result_to_pdb(result))
def infer_protein_as_cif(self, seq: str, **forward_kwargs) -> str:
    """Fold ``seq`` with ``fold_protein`` and return the result as mmCIF text.

    ``forward_kwargs`` are handed to ``fold_protein``, so only the keyword
    arguments that method accepts are valid here.
    """
    result = self.fold_protein(seq, **forward_kwargs)
    return self.result_to_cif(result)
def infer_protein_as_pdb(self, seq: str, **forward_kwargs) -> str:
    """Fold ``seq`` with ``fold_protein`` and return the result as PDB text.

    ``forward_kwargs`` are handed to ``fold_protein``, so only the keyword
    arguments that method accepts are valid here.
    """
    result = self.fold_protein(seq, **forward_kwargs)
    return self.result_to_pdb(result)
# Names exported by a star import of this module: the confidence head, the two
# MSA encoder classes and the model itself.
__all__ = [
"ConfidenceHead",
"MSAEncoder",
"MSAEncoderBlock",
"ESMFold2ExperimentalModel",
]