Feature Extraction
Transformers
Safetensors
esmfold2
biology
protein-structure
multimodal-protein-model
custom_code
Instructions to use Synthyra/ESMFold2-Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/ESMFold2-Fast with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/ESMFold2-Fast", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/ESMFold2-Fast", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import typing as T | |
| from abc import ABC | |
| from dataclasses import dataclass | |
| import torch | |
| from torch.nn import functional as F | |
| from typing_extensions import Self | |
| from .esmfold2_misc import fp32_autocast_context | |
class Rotation(ABC):
    """Interface shared by the batched rotation types in this module.

    A rotation object wraps a single backing tensor. Its leading dimensions are
    batch dimensions (reported by ``shape``); its trailing dimension(s) hold the
    rotation parameters. Subclasses implement the storage-specific methods that
    are declared here as ``...`` stubs.
    """

    @classmethod
    def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...

    @classmethod
    def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...

    def __getitem__(self, idx: T.Any) -> Self: ...

    @property
    def tensor(self) -> torch.Tensor:
        # We claim that this should be zero-cost abstraction that returns the raw tensor backing this
        # object. The raw tensor should always have exactly 1 more dim than self.shape, which should be
        # implemented using reshaping
        ...

    @property
    def shape(self) -> torch.Size:
        # The "shape" of the rotation, as if it was a torch.tensor object
        # This means that 1x4 quaternions are treated as size (1,) for example
        ...

    def as_matrix(self) -> RotationMatrix: ...

    def as_quat(self, normalize: bool = False) -> RotationQuat: ...

    def compose(self, other: Self) -> Self:
        # To be safe, we force users to explicitly convert between rotation types.
        ...

    def convert_compose(self, other: Self) -> Self:
        # This function will automatically convert between types of rotations
        ...

    def apply(self, p: torch.Tensor) -> torch.Tensor:
        # rotates points by this rotation object
        ...

    def invert(self) -> Self: ...

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the backing tensor."""
        return self.tensor.dtype

    @property
    def device(self) -> torch.device:
        """Device of the backing tensor."""
        return self.tensor.device

    @property
    def requires_grad(self) -> bool:
        """Whether the backing tensor requires gradients."""
        return self.tensor.requires_grad

    @classmethod
    def _from_tensor(cls, t: torch.Tensor) -> Self:
        # This function exists to simplify the below functions, esp type signatures
        # Its implementation is different from Affine3D.from_tensor and does not
        # autodetect rotation types.
        return cls(t)  # type: ignore

    def to(self, **kwargs) -> Self:
        """Return a new rotation of the same type backed by ``self.tensor.to(**kwargs)``.

        The result is rebuilt through the subclass constructor, so any dtype the
        constructor enforces is re-applied.
        """
        return self._from_tensor(self.tensor.to(**kwargs))

    def detach(self, *args, **kwargs) -> Self:
        """Return a new rotation of the same type backed by ``self.tensor.detach()``.

        Positional ``args`` are ignored; ``kwargs`` are forwarded unchanged to
        ``torch.Tensor.detach``.
        """
        return self._from_tensor(self.tensor.detach(**kwargs))

    def tensor_apply(self, func) -> Self:
        """Apply ``func`` to every slice of the backing tensor along its last dim.

        The slices are re-stacked along the last dim and wrapped in the same
        rotation type.
        """
        return self._from_tensor(
            torch.stack([func(x) for x in self.tensor.unbind(dim=-1)], dim=-1)
        )
class RotationMatrix(Rotation):
    """Batched rotations stored as 3x3 matrices.

    The backing tensor ``_rots`` has shape ``(*shape, 3, 3)`` and is always
    float32; the ``tensor`` property exposes it flattened to ``(*shape, 9)``.
    """

    def __init__(self, rots: torch.Tensor):
        # Accept either a flattened (..., 9) or a square (..., 3, 3) input.
        if rots.shape[-1] == 9:
            rots = rots.unflatten(-1, (3, 3))
        assert rots.shape[-1] == 3
        assert rots.shape[-2] == 3
        # Force full precision
        rots = rots.to(torch.float32)
        self._rots = rots

    @classmethod
    def identity(cls, shape, **tensor_kwargs):
        """Return identity rotations with batch shape ``shape``.

        ``tensor_kwargs`` are forwarded to ``torch.eye``; the single 3x3
        identity is then expanded over ``shape``.
        """
        rots = torch.eye(3, **tensor_kwargs)
        rots = rots.view(*[1 for _ in range(len(shape))], 3, 3)
        rots = rots.expand(*shape, -1, -1)
        return cls(rots)

    @classmethod
    def random(cls, shape, **tensor_kwargs):
        """Return random rotations by converting ``RotationQuat.random`` samples."""
        return RotationQuat.random(shape, **tensor_kwargs).as_matrix()

    def __getitem__(self, idx: T.Any) -> RotationMatrix:
        """Index the batch dimensions; the trailing 3x3 dims are always kept.

        Tuples and lists are treated as one index per batch dimension. Any other
        index (int, None, slice, Ellipsis, tensor) is used as a single index.
        """
        # Only unpack genuine index sequences: a slice/Ellipsis is not iterable,
        # and an index tensor must not be split element-wise.
        indices = tuple(idx) if isinstance(idx, (tuple, list)) else (idx,)
        return RotationMatrix(self._rots[indices + (slice(None), slice(None))])

    @property
    def shape(self) -> torch.Size:
        """Batch shape: the backing tensor's shape without the trailing 3x3."""
        return self._rots.shape[:-2]

    def as_matrix(self) -> RotationMatrix:
        return self

    def as_quat(self, normalize: bool = False) -> RotationQuat:
        """Convert to quaternions stored real part first.

        Four candidate quaternions are formed per rotation, each by dividing by
        one of the (floored) magnitudes in ``q_abs``; the candidate with the
        largest denominator is kept because it is the best conditioned.

        Args:
            normalize: if True, the returned quaternions are re-normalized to
                unit length with a non-negative real part.
        """
        m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
            self._rots.flatten(-2), dim=-1
        )
        q_abs = _sqrt_subgradient(
            torch.stack(
                [
                    1.0 + m00 + m11 + m22,
                    1.0 + m00 - m11 - m22,
                    1.0 - m00 + m11 - m22,
                    1.0 - m00 - m11 + m22,
                ],
                dim=-1,
            )
        )
        # we produce the desired quaternion multiplied by each of r, i, j, k
        quat_by_rijk = torch.stack(
            [
                x
                for lst in [
                    [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01],
                    [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20],
                    [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21],
                    [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2],
                ]
                for x in lst
            ],
            dim=-1,
        ).unflatten(-1, (4, 4))
        # We floor here at 0.1 but the exact level is not important; if q_abs is small,
        # the candidate won't be picked.
        flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
        quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
        # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
        # forall i; we pick the best-conditioned one (with the largest denominator)
        # We manually implement one_hot so torch.compile works
        one_hot = torch.zeros_like(q_abs, dtype=torch.bool)
        one_hot.scatter_(-1, q_abs.argmax(dim=-1, keepdim=True), True)
        quat = quat_candidates[one_hot, :].reshape(q_abs.shape)
        # Honour ``normalize`` (it used to be silently ignored).
        return RotationQuat(quat, normalized=normalize)

    def compose(self, other: RotationMatrix) -> RotationMatrix:
        """Return the matrix product ``self @ other`` (``other`` is applied first)."""
        with fp32_autocast_context(self._rots.device.type):
            return RotationMatrix(self._rots @ other._rots)

    def convert_compose(self, other: Rotation):
        """Like ``compose`` but first converts ``other`` to a matrix rotation."""
        return self.compose(other.as_matrix())

    def apply(self, p: torch.Tensor) -> torch.Tensor:
        """Rotate points ``p`` (last dim of size 3) by these rotations."""
        with fp32_autocast_context(self._rots.device.type):
            # ``ndim >= 3`` guard: an unbatched (3, 3) rotation has no dim -3.
            if self._rots.ndim >= 3 and self._rots.shape[-3] == 1:
                # This is a slight speedup over einsum for batched rotations
                return p @ self._rots.transpose(-1, -2).squeeze(-3)
            # einsum way faster than bmm!
            return torch.einsum("...ij,...j", self._rots, p)

    def invert(self) -> RotationMatrix:
        """Return the inverse rotations (the transpose of each orthonormal matrix)."""
        return RotationMatrix(self._rots.transpose(-1, -2))

    @property
    def tensor(self) -> torch.Tensor:
        """Backing matrices flattened to shape ``(*shape, 9)``."""
        return self._rots.flatten(-2)

    def to_3x3(self) -> torch.Tensor:
        """Backing matrices with shape ``(*shape, 3, 3)``."""
        return self._rots

    @staticmethod
    def from_graham_schmidt(
        x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12
    ) -> RotationMatrix:
        """Build rotations by Gram-Schmidt orthonormalization.

        The matrix columns are the normalized ``x_axis``, the normalized part of
        ``xy_plane`` orthogonal to it, and their cross product.
        """
        # A low eps here is necessary for good stability!
        return RotationMatrix(_graham_schmidt(x_axis, xy_plane, eps))
class RotationQuat(Rotation):
    """Batched rotations stored as quaternions ``(r, i, j, k)``, real part first.

    The backing tensor ``_quats`` has shape ``(*shape, 4)`` and is always float32.
    Quaternions need not be unit length; ``normalized()`` returns a unit-length
    version with a non-negative real part.
    """

    def __init__(self, quats: torch.Tensor, normalized=False):
        assert quats.shape[-1] == 4
        self._normalized = normalized
        # Force float32 as well
        if normalized:
            self._quats = F.normalize(quats.to(torch.float32), dim=-1)
            # q and -q encode the same rotation; canonicalize to r >= 0.
            self._quats = self._quats.where(self._quats[..., :1] >= 0, -self._quats)
        else:
            self._quats = quats.to(torch.float32)

    @classmethod
    def identity(cls, shape, **tensor_kwargs):
        """Return identity quaternions ``(1, 0, 0, 0)`` with batch shape ``shape``."""
        q = torch.ones((*shape, 4), **tensor_kwargs)
        mult = torch.tensor([1, 0, 0, 0], device=q.device)
        return RotationQuat(q * mult)

    @classmethod
    def random(cls, shape, **tensor_kwargs):
        """Return random rotations: Gaussian 4-vectors, normalized."""
        quat = torch.randn((*shape, 4), **tensor_kwargs)
        return RotationQuat(quat, normalized=True)

    def __getitem__(self, idx: T.Any) -> RotationQuat:
        """Index the batch dimensions; the trailing quaternion dim is always kept.

        Tuples and lists are treated as one index per batch dimension. Any other
        index (int, None, slice, Ellipsis, tensor) is used as a single index.
        """
        # Only unpack genuine index sequences: a slice/Ellipsis is not iterable,
        # and an index tensor must not be split element-wise.
        indices = tuple(idx) if isinstance(idx, (tuple, list)) else (idx,)
        return RotationQuat(self._quats[indices + (slice(None),)])

    @property
    def shape(self) -> torch.Size:
        """Batch shape: the backing tensor's shape without the trailing 4."""
        return self._quats.shape[:-1]

    def compose(self, other: RotationQuat) -> RotationQuat:
        """Return the Hamilton product ``self * other`` (``other`` is applied first)."""
        with fp32_autocast_context(self._quats.device.type):
            return RotationQuat(_quat_mult(self._quats, other._quats))

    def convert_compose(self, other: Rotation):
        """Like ``compose`` but first converts ``other`` to a quaternion rotation."""
        return self.compose(other.as_quat())

    def as_matrix(self) -> RotationMatrix:
        """Convert to 3x3 rotation matrices (computed from the normalized quaternions)."""
        q = self.normalized().tensor
        r, i, j, k = torch.unbind(q, -1)
        # The general formula scales by 2 / |q|^2; q is unit length here, so this
        # is ~2, but dividing by the squared norm keeps the formula exact.
        two_s = 2.0 / (q * q).sum(-1)
        o = torch.stack(
            (
                1 - two_s * (j * j + k * k),
                two_s * (i * j - k * r),
                two_s * (i * k + j * r),
                two_s * (i * j + k * r),
                1 - two_s * (i * i + k * k),
                two_s * (j * k - i * r),
                two_s * (i * k - j * r),
                two_s * (j * k + i * r),
                1 - two_s * (i * i + j * j),
            ),
            -1,
        )
        return RotationMatrix(o.reshape(q.shape[:-1] + (3, 3)))

    def as_quat(self, normalize: bool = False) -> RotationQuat:
        """Return ``self``, or its normalized version when ``normalize`` is True."""
        return self.normalized() if normalize else self

    def apply(self, p: torch.Tensor) -> torch.Tensor:
        """Rotate points ``p`` (last dim of size 3) by the normalized quaternions."""
        return _quat_rotation(self.normalized()._quats, p)

    def invert(self) -> RotationQuat:
        """Return the conjugate quaternions, which represent the inverse rotation."""
        return RotationQuat(_quat_invert(self._quats))

    @property
    def tensor(self) -> torch.Tensor:
        """Backing quaternions with shape ``(*shape, 4)``."""
        return self._quats

    def normalized(self) -> RotationQuat:
        """Return a unit-length, non-negative-real-part version (``self`` if already so)."""
        return self if self._normalized else RotationQuat(self._quats, normalized=True)
| class Affine3D: | |
| trans: torch.Tensor | |
| rot: Rotation | |
| def __post_init__(self): | |
| assert self.trans.shape[:-1] == self.rot.shape | |
| def identity( | |
| shape_or_affine: T.Union[tuple[int, ...], "Affine3D"], | |
| rotation_type: T.Type[Rotation] = RotationMatrix, | |
| **tensor_kwargs, | |
| ): | |
| # Creates a new identity Affine3D object with a specified shape | |
| # or the same shape as another Affine3D object. | |
| if isinstance(shape_or_affine, Affine3D): | |
| kwargs = {"dtype": shape_or_affine.dtype, "device": shape_or_affine.device} | |
| kwargs.update(tensor_kwargs) | |
| shape = shape_or_affine.shape | |
| rotation_type = type(shape_or_affine.rot) | |
| else: | |
| kwargs = tensor_kwargs | |
| shape = shape_or_affine | |
| return Affine3D( | |
| torch.zeros((*shape, 3), **kwargs), rotation_type.identity(shape, **kwargs) | |
| ) | |
| def random( | |
| shape: tuple[int, ...], | |
| std: float = 1, | |
| rotation_type: T.Type[Rotation] = RotationMatrix, | |
| **tensor_kwargs, | |
| ) -> "Affine3D": | |
| return Affine3D( | |
| trans=torch.randn((*shape, 3), **tensor_kwargs).mul(std), | |
| rot=rotation_type.random(shape, **tensor_kwargs), | |
| ) | |
| def __getitem__(self, idx: T.Any) -> "Affine3D": | |
| indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx) | |
| return Affine3D(trans=self.trans[indices + (slice(None),)], rot=self.rot[idx]) | |
| def shape(self) -> torch.Size: | |
| return self.trans.shape[:-1] | |
| def dtype(self) -> torch.dtype: | |
| return self.trans.dtype | |
| def device(self) -> torch.device: | |
| return self.trans.device | |
| def requires_grad(self) -> bool: | |
| return self.trans.requires_grad | |
| def to(self, **kwargs) -> "Affine3D": | |
| return Affine3D(self.trans.to(**kwargs), self.rot.to(**kwargs)) | |
| def detach(self, *args, **kwargs) -> "Affine3D": | |
| return Affine3D(self.trans.detach(**kwargs), self.rot.detach(**kwargs)) | |
| def tensor_apply(self, func) -> "Affine3D": | |
| # Applys a function to the underlying tensor | |
| return self.from_tensor( | |
| torch.stack([func(x) for x in self.tensor.unbind(dim=-1)], dim=-1) | |
| ) | |
| def as_matrix(self): | |
| return Affine3D(trans=self.trans, rot=self.rot.as_matrix()) | |
| def as_quat(self, normalize: bool = False): | |
| return Affine3D(trans=self.trans, rot=self.rot.as_quat(normalize)) | |
| def compose(self, other: "Affine3D", autoconvert: bool = False): | |
| rot = self.rot | |
| new_rot = (rot.convert_compose if autoconvert else rot.compose)(other.rot) | |
| new_trans = rot.apply(other.trans) + self.trans | |
| return Affine3D(trans=new_trans, rot=new_rot) | |
| def compose_rotation(self, other: Rotation, autoconvert: bool = False): | |
| return Affine3D( | |
| trans=self.trans, | |
| rot=(self.rot.convert_compose if autoconvert else self.rot.compose)(other), | |
| ) | |
| def scale(self, v: torch.Tensor | float): | |
| return Affine3D(self.trans * v, self.rot) | |
| def mask(self, mask: torch.Tensor, with_zero=False): | |
| # Returns a transform where True positions in mask is identity | |
| if with_zero: | |
| tensor = self.tensor | |
| return Affine3D.from_tensor( | |
| torch.zeros_like(tensor).where(mask[..., None], tensor) | |
| ) | |
| else: | |
| identity = self.identity( | |
| self.shape, | |
| rotation_type=type(self.rot), | |
| device=self.device, | |
| dtype=self.dtype, | |
| ).tensor | |
| return Affine3D.from_tensor(identity.where(mask[..., None], self.tensor)) | |
| def apply(self, p: torch.Tensor) -> torch.Tensor: | |
| return self.rot.apply(p) + self.trans | |
| def invert(self): | |
| inv_rot = self.rot.invert() | |
| return Affine3D(trans=-inv_rot.apply(self.trans), rot=inv_rot) | |
| def tensor(self) -> torch.Tensor: | |
| return torch.cat([self.rot.tensor, self.trans], dim=-1) | |
| def from_tensor(t: torch.Tensor) -> "Affine3D": | |
| match t.shape[-1]: | |
| case 4: | |
| # Assume tensor 4x4 for backward compat with alphafold | |
| trans = t[..., :3, 3] | |
| rot = RotationMatrix(t[..., :3, :3]) | |
| case 6: | |
| # Assume quaternion representation with real part = 1 | |
| trans = t[..., -3:] | |
| rot = RotationQuat(F.pad(t[..., :3], (1, 0), value=1)) | |
| case 7: | |
| trans = t[..., -3:] | |
| rot = RotationQuat(t[..., :4]) | |
| case 12: | |
| trans = t[..., -3:] | |
| rot = RotationMatrix(t[..., :-3].unflatten(-1, (3, 3))) | |
| case _: | |
| raise RuntimeError( | |
| f"Cannot detect rotation fromat from {t.shape[-1] -3}-d flat vector" | |
| ) | |
| return Affine3D(trans, rot) | |
| def from_tensor_pair(t: torch.Tensor, r: torch.Tensor) -> "Affine3D": | |
| return Affine3D(t, RotationMatrix(r)) | |
| def from_graham_schmidt( | |
| neg_x_axis: torch.Tensor, | |
| origin: torch.Tensor, | |
| xy_plane: torch.Tensor, | |
| eps: float = 1e-10, | |
| ): | |
| # The arguments of this function is for parity with AlphaFold | |
| x_axis = origin - neg_x_axis | |
| xy_plane = xy_plane - origin | |
| return Affine3D( | |
| trans=origin, rot=RotationMatrix.from_graham_schmidt(x_axis, xy_plane, eps) | |
| ) | |
| def cat(affines: list["Affine3D"], dim: int = 0): | |
| if dim < 0: | |
| dim = len(affines[0].shape) + dim | |
| return Affine3D.from_tensor(torch.cat([x.tensor for x in affines], dim=dim)) | |
| def _quat_mult(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Multiply two quaternions. | |
| Usual torch rules for broadcasting apply. | |
| Args: | |
| a: Quaternions as tensor of shape (..., 4), real part first. | |
| b: Quaternions as tensor of shape (..., 4), real part first. | |
| Returns: | |
| The product of a and b, a tensor of quaternions shape (..., 4). | |
| """ | |
| aw, ax, ay, az = torch.unbind(a, -1) | |
| bw, bx, by, bz = torch.unbind(b, -1) | |
| ow = aw * bw - ax * bx - ay * by - az * bz | |
| ox = aw * bx + ax * bw + ay * bz - az * by | |
| oy = aw * by - ax * bz + ay * bw + az * bx | |
| oz = aw * bz + ax * by - ay * bx + az * bw | |
| return torch.stack((ow, ox, oy, oz), -1) | |
def _quat_rotation(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Rotate points ``p`` by quaternions ``q`` via ``q * (0, p) * conj(q)``.

    Usual torch broadcasting rules apply between ``q`` and ``p``.

    Args:
        q: quaternions of shape ``(..., 4)``, real part first. They should be
            unit length, because the inverse is taken as the conjugate
            (``_quat_invert``).
        p: points of shape ``(..., 3)``.

    Returns:
        The rotated points, shape ``(..., 3)``.
    """
    qw, qx, qy, qz = q.unbind(-1)
    px, py, pz = p.unbind(-1)
    # Hamilton product q * (0, p), with the zero real part of the pure
    # quaternion (0, p) already folded into each component.
    # fmt: off
    q_times_p = torch.stack(
        (
            - qx * px - qy * py - qz * pz,
            qw * px + qy * pz - qz * py,
            qw * py - qx * pz + qz * px,
            qw * pz + qx * py - qy * px,
        ),
        -1,
    )
    # fmt: on
    rotated = _quat_mult(q_times_p, _quat_invert(q))
    # The real part of the result is (numerically) zero; keep the vector part.
    return rotated[..., 1:]
| def _quat_invert(q: torch.Tensor): | |
| return q * torch.tensor([1, -1, -1, -1], device=q.device) | |
| def _sqrt_subgradient(x: torch.Tensor) -> torch.Tensor: | |
| # Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. | |
| ret = torch.zeros_like(x) | |
| positive_mask = x > 0 | |
| ret[positive_mask] = torch.sqrt(x[positive_mask]) | |
| return ret | |
def _graham_schmidt(x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12):
    """Orthonormalize two direction vectors into rotation matrices (Gram-Schmidt).

    Args:
        x_axis: direction of the first basis vector, last dim of size 3.
        xy_plane: a direction lying in the plane of the first two basis vectors.
        eps: added under each square root of a norm to avoid division by zero.

    Returns:
        Matrices of shape ``(..., 3, 3)`` whose columns are the normalized
        ``x_axis``, the normalized part of ``xy_plane`` orthogonal to it, and
        the cross product of those two.
    """

    def _unit(v: torch.Tensor) -> torch.Tensor:
        # Scale v to (approximately) unit length along the last dim.
        return v / torch.sqrt(v.pow(2).sum(dim=-1, keepdim=True) + eps)

    # A low eps here is necessary for good stability!
    # fp32_autocast_context presumably keeps this math out of reduced-precision
    # autocast on the given device type; confirm against its definition.
    with fp32_autocast_context(x_axis.device.type):
        first = _unit(x_axis)
        # Remove the component of xy_plane along `first`, then normalize.
        projection = (first * xy_plane).sum(dim=-1, keepdim=True)
        second = _unit(xy_plane - first * projection)
        third = torch.cross(first, second, dim=-1)
        basis = torch.stack([first, second, third], dim=-1)
    return basis
def build_affine3d_from_coordinates(
    coords: torch.Tensor,  # (N, CA, C).
) -> tuple[Affine3D, torch.Tensor]:
    """Build one backbone frame per residue from N/CA/C coordinates.

    Each valid residue gets a frame with translation at CA, x-axis pointing
    from C towards CA, and N lying in the xy-plane.

    A residue is invalid if any of its nine coordinates is non-finite or has
    magnitude >= ``_MAX_SUPPORTED_DISTANCE``. Invalid residues instead get a
    "black hole" frame built the same way from the batch element's average
    N/CA/C positions over its valid residues. If a batch element has no valid
    residue at all, that frame is an identity rotation at the origin.

    Args:
        coords: tensor of shape ``(B, S, 3, 3)``: batch, residue, backbone atom
            (N, CA, C in that order), coordinate.

    Returns:
        ``(affine, coord_mask)``. ``affine`` has batch shape ``(B, S)`` with
        matrix rotations. ``coord_mask`` is a ``(B, S)`` bool tensor that is
        True for valid residues.

    Raises:
        ValueError: if ``coords`` is not of shape ``(B, S, 3, *)``.
    """
    if coords.ndim != 4 or coords.shape[-2] != 3:
        raise ValueError(
            f"Expected coords of shape (B, S, 3, 3), got {tuple(coords.shape)}"
        )

    _MAX_SUPPORTED_DISTANCE = 1e6
    # Bound the magnitude on both sides: very negative coordinates are as
    # unsupported as very positive ones (previously only `coords < max` was checked).
    coord_mask = torch.all(
        torch.all(
            torch.isfinite(coords) & (coords.abs() < _MAX_SUPPORTED_DISTANCE), dim=-1
        ),
        dim=-1,
    )

    def atom3_to_backbone_affine(bb_positions: torch.Tensor) -> Affine3D:
        N, CA, C = bb_positions.unbind(dim=-2)
        return Affine3D.from_graham_schmidt(C, CA, N)

    coords = coords.clone().float()
    coords[~coord_mask] = 0

    # NOTE(thayes): If you have already normalized the coordinates, then
    # the black hole affine translations will be zeros and the rotations will be
    # the identity.
    # Per-batch mean of N/CA/C over valid residues -> shape (B, 3, 3); the 1e-8
    # avoids 0/0 when a batch element has no valid residue.
    average_per_n_ca_c = coords.masked_fill(~coord_mask[..., None, None], 0).sum(1) / (
        coord_mask.sum(-1)[..., None, None] + 1e-8
    )
    affine_from_average = atom3_to_backbone_affine(
        average_per_n_ca_c.float()
    ).as_matrix()

    B, S, _, _ = coords.shape
    assert isinstance(B, int)
    assert isinstance(S, int)
    # Broadcast the per-batch average frame over every residue position.
    affine_rot_mats = affine_from_average.rot.tensor[..., None, :].expand(B, S, 9)
    affine_trans = affine_from_average.trans[..., None, :].expand(B, S, 3)

    # We use the identity rotation whereever we have no coordinates. This is
    # important because otherwise the rotation matrices will be all zeros, which
    # will cause collapse in the distance/direction attention mechanism.
    identity_rot = RotationMatrix.identity(
        (B, S), dtype=torch.float32, device=coords.device, requires_grad=False
    )
    affine_rot_mats = affine_rot_mats.where(
        coord_mask.any(-1)[..., None, None], identity_rot.tensor
    )
    black_hole_affine = Affine3D(affine_trans, RotationMatrix(affine_rot_mats))

    # Per-residue frames, falling back to the black-hole frame where invalid.
    affine = atom3_to_backbone_affine(coords.float())
    affine = Affine3D.from_tensor(
        affine.tensor.where(coord_mask[..., None], black_hole_affine.tensor)
    )

    return affine, coord_mask