# ESM3_small / modeling_esm3.py
# Uploaded by lhallee with huggingface_hub (commit 3797259, 62.8 kB).
from __future__ import annotations
"""
Hugging Face compatible ESM3 wrapper.
This module keeps Biohub's ESM3 implementation as the execution core and adds
the FastPLMs conventions around it: AutoModel loading, sequence-only
`input_ids` forwarding, and direct multimodal track arguments.
"""
import sys
import math
import functools
import importlib
import os
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Union
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import ModelOutput
try:
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
)
except ImportError:
BlockMask = None
create_block_mask = None
flex_attention = None
try:
from fastplms.test_time_training import FastPLMTestTimeTrainingMixin
except ImportError:
pass # Running as HF Hub composite; shared definitions are above
# Canonical key of the only supported checkpoint (the open small ESM3 model).
ESM3_OPEN_SMALL = "esm3_sm_open_v1"
# Every name accepted by _resolve_esm3_checkpoint_key; each alias resolves to
# ESM3_OPEN_SMALL.
ESM3_OPEN_SMALL_ALIASES = {
    "ESM3_small",
    "esm3_small",
    "esm3_sm_open_v1",
    "esm3-open-2024-03",
    "esm3-sm-open-v1",
    "esm3-open",
}
# Special ids of the sequence track; each is an index into SEQUENCE_VOCAB
# (0 -> "<cls>", 1 -> "<pad>", 2 -> "<eos>", 31 -> "|", 32 -> "<mask>").
SEQUENCE_BOS_TOKEN = 0
SEQUENCE_PAD_TOKEN = 1
SEQUENCE_EOS_TOKEN = 2
SEQUENCE_CHAINBREAK_TOKEN = 31
SEQUENCE_MASK_TOKEN = 32

# Structure-track special ids occupy the five slots right after the VQ-VAE
# codebook ids (4096 .. 4100).
VQVAE_CODEBOOK_SIZE = 4096
(
    STRUCTURE_MASK_TOKEN,
    STRUCTURE_EOS_TOKEN,
    STRUCTURE_BOS_TOKEN,
    STRUCTURE_PAD_TOKEN,
    STRUCTURE_CHAINBREAK_TOKEN,
) = range(VQVAE_CODEBOOK_SIZE, VQVAE_CODEBOOK_SIZE + 5)

# The remaining tracks all pad with id 0.
SASA_PAD_TOKEN = SS8_PAD_TOKEN = INTERPRO_PAD_TOKEN = RESIDUE_PAD_TOKEN = 0

# Trailing-dimension sizes of the residue-annotation and function-token inputs.
MAX_RESIDUE_ANNOTATIONS = 16
FUNCTION_TOKENS_DEPTH = 8
# Sequence-track vocabulary: a token's id is its position in this list.
# Four special tokens, then one entry per single-character symbol, then <mask>.
SEQUENCE_VOCAB = [
    "<cls>",
    "<pad>",
    "<eos>",
    "<unk>",
    *"LAGVSERTIDPKQNFYMHWCXBUZO.-|",
    "<mask>",
]
# Values accepted for the ``attn_backend`` option (checked by
# resolve_attention_backend).
_SUPPORTED_ATTENTION_BACKENDS = ("auto", "flex", "sdpa")
# Memoized torch.compile'd flex_attention; filled lazily by
# _get_flex_attention_fn.
_compiled_flex_attention = None
class AttentionBackend(Enum):
    """Attention kernels selectable for MultiHeadAttention.

    ``AUTO`` is only a request value: resolve_attention_backend turns it into
    ``FLEX`` (when flex attention imported successfully) or ``SDPA``.
    """

    AUTO = "auto"
    FLEX = "flex"
    SDPA = "sdpa"
def _get_flex_attention_fn():
    """Return the flex-attention callable to use, or ``None`` if unavailable.

    Normally this is a ``torch.compile``'d ``flex_attention`` built on first use
    and cached in ``_compiled_flex_attention``. If torch's flex-attention module
    sets ``_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG``, the eager function is
    returned instead.
    """
    global _compiled_flex_attention
    if flex_attention is not None:
        skip_compile = getattr(
            torch.nn.attention.flex_attention,
            "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG",
            False,
        )
        if skip_compile:
            return flex_attention
        if _compiled_flex_attention is None:
            _compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
        return _compiled_flex_attention
    return None
def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a requested backend to a concrete ``AttentionBackend``.

    Args:
        requested_backend: One of ``_SUPPORTED_ATTENTION_BACKENDS``
            (``"auto"``, ``"flex"``, ``"sdpa"``) or an ``AttentionBackend``
            member.

    Returns:
        ``AttentionBackend.FLEX`` or ``AttentionBackend.SDPA``. ``"auto"``
        picks flex whenever flex attention could be imported.

    Raises:
        AssertionError: If the backend is unsupported, or flex is requested
            but unavailable. These are raised explicitly (not via ``assert``)
            so the checks still run under ``python -O``.
    """
    if isinstance(requested_backend, AttentionBackend):
        requested_backend = requested_backend.value
    if requested_backend not in _SUPPORTED_ATTENTION_BACKENDS:
        raise AssertionError(
            f"Unsupported ESM3 attention backend: {requested_backend}. "
            f"Expected one of {_SUPPORTED_ATTENTION_BACKENDS}."
        )
    flex_available = flex_attention is not None
    if requested_backend == AttentionBackend.AUTO.value:
        return AttentionBackend.FLEX if flex_available else AttentionBackend.SDPA
    if requested_backend == AttentionBackend.FLEX.value:
        if not flex_available:
            raise AssertionError("Flex Attention is not available in this environment.")
        return AttentionBackend.FLEX
    if requested_backend == AttentionBackend.SDPA.value:
        return AttentionBackend.SDPA
    raise AssertionError(f"Unsupported ESM3 attention backend: {requested_backend}")
# Per-checkpoint settings keyed by the canonical name that
# _resolve_esm3_checkpoint_key returns. Besides ``repo_id``, the entries use
# the same names as FastESM3Config's size arguments.
_ESM3_CHECKPOINT_SPECS = {
    ESM3_OPEN_SMALL: {
        "repo_id": "biohub/esm3-sm-open-v1",
        "hidden_size": 1536,
        "num_attention_heads": 24,
        "num_vector_heads": 256,
        "num_hidden_layers": 48,
    },
}
class FastESM3Config(PretrainedConfig):
    """Hugging Face configuration for the FastPLMs ESM3 wrapper.

    Args:
        vocab_size: Stored on the config. The sequence embedding and sequence
            head in this file are hard-coded to 64 rows/logits regardless.
        hidden_size: Model width. Must be divisible by FUNCTION_TOKENS_DEPTH
            (the function-token embeddings are split into that many chunks) and
            by ``num_attention_heads``.
        num_attention_heads: Heads of the standard self-attention.
        num_vector_heads: Heads of the geometric attention (passed as
            ``v_heads``).
        num_hidden_layers: Number of transformer blocks.
        initializer_range: Stored for HF conventions; not read in this section.
        attn_backend: ``"auto"``, ``"flex"`` or ``"sdpa"``. It is validated
            later, when the attention modules are built.
        model_name: Checkpoint name or alias, normalized by
            ``_resolve_esm3_checkpoint_key`` (which raises ValueError for
            unknown names).
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "fast_esm3"

    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 1536,
        num_attention_heads: int = 24,
        num_vector_heads: int = 256,
        num_hidden_layers: int = 48,
        initializer_range: float = 0.02,
        attn_backend: str = "sdpa",
        model_name: str = ESM3_OPEN_SMALL,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # NOTE(review): these shape checks are skipped under ``python -O``.
        assert hidden_size % FUNCTION_TOKENS_DEPTH == 0
        assert hidden_size % num_attention_heads == 0
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_vector_heads = num_vector_heads
        self.num_hidden_layers = num_hidden_layers
        self.initializer_range = initializer_range
        self.attn_backend = attn_backend
        self.model_name = _resolve_esm3_checkpoint_key(model_name)
        # Assigned after super().__init__, so it overrides any
        # tie_word_embeddings value passed through kwargs.
        self.tie_word_embeddings = False
@dataclass
class FastESM3Output(ModelOutput):
    """Model output holding the per-track ESM3 logits.

    The track fields (``sequence_logits`` through ``attentions``) share names
    with ESM3CoreOutput. How the wrapper fills ``loss``, ``logits`` and
    ``last_hidden_state`` is not visible in this section.
    """

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None
    sequence_logits: Optional[torch.Tensor] = None
    structure_logits: Optional[torch.Tensor] = None
    secondary_structure_logits: Optional[torch.Tensor] = None
    sasa_logits: Optional[torch.Tensor] = None
    function_logits: Optional[torch.Tensor] = None
    residue_logits: Optional[torch.Tensor] = None
    embeddings: Optional[torch.Tensor] = None
    hidden_states: Optional[tuple[torch.Tensor, ...]] = None
    attentions: Optional[tuple[torch.Tensor, ...]] = None
class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer for the ESM3 sequence track.

    Token ids are positions in SEQUENCE_VOCAB. The underlying model is a BPE
    with no merges, so plain text is tokenized symbol by symbol. The
    multi-character tokens passed as special tokens (``<cls>``, ``<pad>``,
    ``<mask>``, ``<eos>``) and the chain-break token are registered as special
    tokens. Encodings are wrapped as ``<cls> ... <eos>`` by the post-processor.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token: str = "<unk>",
        cls_token: str = "<cls>",
        pad_token: str = "<pad>",
        mask_token: str = "<mask>",
        eos_token: str = "<eos>",
        chain_break_token: str = "|",
        **kwargs,
    ):
        token_to_id = {token: index for index, token in enumerate(SEQUENCE_VOCAB)}
        # Merge-free BPE: each vocabulary entry is an atomic token.
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        # Backing field for the ``chain_break_token`` property.
        self.cb_token = chain_break_token
        tokenizer.add_special_tokens(special_tokens)
        # Single: <cls> A <eos>; pair: <cls> A <eos> B <eos> (B gets type id 1).
        # NOTE(review): the template hard-codes "<cls>"/"<eos>" instead of the
        # cls_token/eos_token arguments, so custom values are not reflected here.
        tokenizer.post_processor = TemplateProcessing(
            single="<cls> $A <eos>",
            pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=[chain_break_token],
            **kwargs,
        )

    @property
    def bos_token(self) -> str:
        # The vocabulary has no separate BOS token; <cls> (id 0) fills that role.
        return self.cls_token

    @property
    def bos_token_id(self) -> int:
        return self.cls_token_id

    @property
    def chain_break_token(self) -> str:
        return self.cb_token

    @property
    def chain_break_token_id(self) -> int:
        token_id = self.convert_tokens_to_ids(self.chain_break_token)
        assert isinstance(token_id, int)
        return token_id

    @property
    def all_token_ids(self) -> list[int]:
        """Every id from 0 to ``vocab_size - 1``."""
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self) -> list[int]:
        return self.all_special_ids
@dataclass
class FastESM3TokenizerCollection:
    """Per-track tokenizers handed to ESM3Core.

    Only ``sequence`` is required. ESM3Core reads ``sequence.mask_token_id`` as
    the fill value when no sequence track is given. The other slots are untyped
    placeholders that default to ``None``.
    """

    sequence: EsmSequenceTokenizer
    structure: Optional[object] = None
    secondary_structure: Optional[object] = None
    sasa: Optional[object] = None
    function: Optional[object] = None
    residue_annotations: Optional[object] = None
def rbf(values: torch.Tensor, v_min: float, v_max: float, n_bins: int = 16) -> torch.Tensor:
    """Expand ``values`` into Gaussian radial-basis features.

    ``n_bins`` evenly spaced centers cover ``[v_min, v_max]``, and the width is
    ``(v_max - v_min) / n_bins``.

    Returns:
        A tensor of shape ``values.shape + (n_bins,)`` holding
        ``exp(-((v - center) / width) ** 2)``.
    """
    bin_width = (v_max - v_min) / n_bins
    grid = torch.linspace(
        v_min, v_max, n_bins, device=values.device, dtype=values.dtype
    )
    # Prepend singleton axes so the grid broadcasts over every input dimension.
    grid = grid.reshape(*([1] * values.dim()), n_bins)
    scaled = (values[..., None] - grid) / bin_width
    return torch.exp(-scaled.pow(2))
def RegressionHead(
    d_model: int,
    output_dim: int,
    hidden_dim: Optional[int] = None,
) -> nn.Module:
    """Build an MLP head: Linear -> GELU -> LayerNorm -> Linear.

    Args:
        d_model: Input feature size.
        output_dim: Output feature size.
        hidden_dim: Width of the intermediate layer. Defaults to ``d_model``.

    Returns:
        An ``nn.Sequential``. Its child indices (0-3) appear in state-dict keys,
        so the layer order must stay fixed.
    """
    width = d_model if hidden_dim is None else hidden_dim
    layers = [
        nn.Linear(d_model, width),
        nn.GELU(),
        nn.LayerNorm(width),
        nn.Linear(width, output_dim),
    ]
    return nn.Sequential(*layers)
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs of ``x`` by 90 degrees (rotary-embedding helper).

    Non-interleaved: the last dim is split into halves ``(a, b)`` and the
    result is ``(-b, a)``. Interleaved: each adjacent pair ``(x0, x1)`` becomes
    ``(-x1, x0)``.
    """
    if interleaved:
        even = x[..., 0::2]
        odd = x[..., 1::2]
        # Stack to (..., d/2, 2) and flatten, so the pairs stay adjacent.
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((-second, first), dim=-1)
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: bool = False,
) -> torch.Tensor:
    """Apply rotary position embeddings to the leading rotary features of ``x``.

    Args:
        x: Tensor whose dim 1 is the sequence axis and whose last dim holds
            features. In this file it is ``(batch, seq, heads, head_dim)``.
        cos, sin: ``(seq_cached, rotary_dim // 2)`` tables. Rows beyond
            ``x.size(1)`` are ignored.
        interleaved: If True, rotary pairs are adjacent features
            ``(x0, x1), (x2, x3), ...``. Otherwise pair ``i`` is
            ``(x_i, x_{i + rotary_dim // 2})``.

    Returns:
        Tensor shaped like ``x``. Features past ``rotary_dim`` pass through
        unchanged.
    """
    rotary_dim = cos.shape[-1] * 2
    assert rotary_dim <= x.shape[-1]
    seqlen = x.size(1)
    cos = cos[:seqlen]
    sin = sin[:seqlen]
    if interleaved:
        # Pair i occupies features (2i, 2i + 1), so each frequency is repeated
        # in place: [f0, f0, f1, f1, ...]. The previous "(2 d)" tiling paired
        # the wrong frequencies here.
        cos = einops.repeat(cos, "s d -> s 1 (d 2)")
        sin = einops.repeat(sin, "s d -> s 1 (d 2)")
    else:
        # Pair i is (i, i + d), so the table is tiled across the two halves.
        cos = einops.repeat(cos, "s d -> s 1 (2 d)")
        sin = einops.repeat(sin, "s d -> s 1 (2 d)")
    rotary_part = x[..., :rotary_dim]
    return torch.cat(
        [
            rotary_part * cos + rotate_half(rotary_part, interleaved) * sin,
            x[..., rotary_dim:],
        ],
        dim=-1,
    )
class RotaryEmbedding(nn.Module):
    """Rotary position embedding for per-head query/key tensors.

    Holds lazily built cos/sin tables (``_cos_cached`` / ``_sin_cached``) with
    one row per position and one column per ``inv_freq`` entry. The tables are
    rebuilt when a longer sequence, a different device or dtype, or (while
    training) an inference-mode cache is encountered.
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        interleaved: bool = False,
        scale_base: Optional[float] = None,
        scaling_factor: float = 1.0,
        pos_idx_in_fp32: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device
        # Cache state; filled on the first forward call.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)register the ``inv_freq`` and ``scale`` buffers."""
        inv_freq = self._compute_inv_freq(self.device)
        # Non-persistent: recomputed from dim/base, so excluded from state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
        # ``scale`` is only built when scale_base is set. That configuration is
        # rejected by _update_cos_sin_cache (NotImplementedError), so with the
        # default scale_base=None this buffer is None.
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Return ``1 / base ** (2i / dim)`` for ``i = 0, 1, ...`` (float32)."""
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(
        self,
        seqlen: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Rebuild the cos/sin tables for ``seqlen`` positions if stale.

        A rebuild happens when the cache is shorter than ``seqlen``, missing,
        on another device or dtype, or an inference-mode tensor while the
        module is training. A rebuilt table has exactly ``seqlen`` rows.

        Raises:
            NotImplementedError: If the ``scale`` buffer is set (scaled rotary).
        """
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # Position indices are divided by scaling_factor before use.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                inv_freq = self.inv_freq.to(torch.float32)
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            # freqs[p, i] = t[p] * inv_freq[i]
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                raise NotImplementedError("Scaled rotary embeddings are not used by ESM3.")

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seqlen_offset: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k`` according to their positions.

        Args:
            q, k: Tensors whose dim 1 is the sequence axis. The cache is sized
                from ``q``; ``k`` is assumed to have the same length.
            seqlen_offset: Position index of the first element along dim 1.

        Returns:
            The rotated ``(q, k)``.
        """
        self._update_cos_sin_cache(
            q.shape[1] + seqlen_offset,
            device=q.device,
            dtype=q.dtype,
        )
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        return (
            apply_rotary_emb_torch(
                q,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self.interleaved,
            ),
            apply_rotary_emb_torch(
                k,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self.interleaved,
            ),
        )
def fp32_autocast_context(device_type: str):
    """Return a context manager that disables autocast for ``device_type``.

    Math inside the context runs in the dtypes of its inputs instead of an
    autocast-chosen lower precision. ``torch.autocast`` raises RuntimeError
    for device types it does not support (for example ``"meta"`` on many torch
    versions). Those devices have no autocast to disable, so a no-op context
    is returned instead of failing.
    """
    import contextlib

    try:
        return torch.autocast(device_type=device_type, enabled=False)
    except RuntimeError:
        return contextlib.nullcontext()
class RotationMatrix:
    """Batch of 3x3 rotation matrices stored in float32.

    ``shape`` is the batch shape; the trailing ``3 x 3`` is excluded.
    """

    def __init__(self, rots: torch.Tensor):
        # Accept either (..., 3, 3) matrices or row-major flattened (..., 9).
        if rots.shape[-1] == 9:
            rots = rots.unflatten(-1, (3, 3))
        assert rots.shape[-1] == 3
        assert rots.shape[-2] == 3
        self._rots = rots.to(torch.float32)

    @classmethod
    def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> "RotationMatrix":
        """Identity rotations broadcast (expanded, not copied) to ``shape``."""
        rots = torch.eye(3, **tensor_kwargs)
        rots = rots.view(*[1 for _ in range(len(shape))], 3, 3)
        rots = rots.expand(*shape, -1, -1)
        return cls(rots)

    @staticmethod
    def _normalize_index(idx) -> tuple:
        """Turn a batch index into a tuple that can be extended with ``:, :``.

        Tuples are used as-is. Lists keep their legacy meaning of one entry per
        batch dim. Anything else (int, None, slice, Ellipsis, tensor) is one
        index; calling ``tuple()`` on those fails or mis-indexes.
        """
        if isinstance(idx, tuple):
            return idx
        if isinstance(idx, list):
            return tuple(idx)
        return (idx,)

    def __getitem__(self, idx) -> "RotationMatrix":
        indices = self._normalize_index(idx)
        # Index batch dims only; the trailing 3x3 matrix dims are kept whole.
        return RotationMatrix(self._rots[indices + (slice(None), slice(None))])

    @property
    def shape(self) -> torch.Size:
        return self._rots.shape[:-2]

    @property
    def tensor(self) -> torch.Tensor:
        """Rotations flattened to ``(..., 9)`` (row-major)."""
        return self._rots.flatten(-2)

    @property
    def device(self) -> torch.device:
        return self._rots.device

    def as_matrix(self) -> "RotationMatrix":
        return self

    def apply(self, p: torch.Tensor) -> torch.Tensor:
        """Rotate points ``p`` of shape ``(..., 3)``; the math runs in float32."""
        with fp32_autocast_context(self.device.type):
            p = p.to(self._rots.dtype)
            # A size-1 dim before the matrix (e.g. from ``rot[..., None]``)
            # means one rotation shared across that axis. Use a broadcasted
            # matmul for it. Un-batched (3, 3) rotations take the einsum path
            # instead of failing on ``shape[-3]``.
            if self._rots.dim() >= 3 and self._rots.shape[-3] == 1:
                return p @ self._rots.transpose(-1, -2).squeeze(-3)
            return torch.einsum("...ij,...j", self._rots, p)

    def invert(self) -> "RotationMatrix":
        """Inverse rotation (the transpose, since rotations are orthonormal)."""
        return RotationMatrix(self._rots.transpose(-1, -2))

    @staticmethod
    def from_graham_schmidt(
        x_axis: torch.Tensor,
        xy_plane: torch.Tensor,
        eps: float = 1e-12,
    ) -> "RotationMatrix":
        """Build a rotation whose columns are an orthonormal basis.

        The first column is normalized ``x_axis``. The second is ``xy_plane``
        with its ``x_axis`` component removed, then normalized. The third is
        their cross product. ``eps`` guards the normalizations.
        """
        with fp32_autocast_context(x_axis.device.type):
            e1 = xy_plane
            denom = torch.sqrt((x_axis**2).sum(dim=-1, keepdim=True) + eps)
            x_axis = x_axis / denom
            dot = (x_axis * e1).sum(dim=-1, keepdim=True)
            e1 = e1 - x_axis * dot
            denom = torch.sqrt((e1**2).sum(dim=-1, keepdim=True) + eps)
            e1 = e1 / denom
            e2 = torch.cross(x_axis, e1, dim=-1)
            return RotationMatrix(torch.stack([x_axis, e1, e2], dim=-1))
@dataclass(frozen=True)
class Affine3D:
trans: torch.Tensor
rot: RotationMatrix
def __post_init__(self) -> None:
assert self.trans.shape[:-1] == self.rot.shape
def __getitem__(self, idx) -> "Affine3D":
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
return Affine3D(
trans=self.trans[indices + (slice(None),)],
rot=self.rot[idx],
)
@property
def shape(self) -> torch.Size:
return self.trans.shape[:-1]
@property
def dtype(self) -> torch.dtype:
return self.trans.dtype
@property
def device(self) -> torch.device:
return self.trans.device
@property
def tensor(self) -> torch.Tensor:
return torch.cat([self.rot.tensor, self.trans], dim=-1)
def as_matrix(self) -> "Affine3D":
return Affine3D(trans=self.trans, rot=self.rot.as_matrix())
def apply(self, p: torch.Tensor) -> torch.Tensor:
return self.rot.apply(p) + self.trans
@staticmethod
def from_tensor(t: torch.Tensor) -> "Affine3D":
match t.shape[-1]:
case 12:
trans = t[..., -3:]
rot = RotationMatrix(t[..., :-3].unflatten(-1, (3, 3)))
case _:
raise RuntimeError(
f"Cannot detect rotation format from {t.shape[-1] - 3}-d flat vector"
)
return Affine3D(trans, rot)
@staticmethod
def from_graham_schmidt(
neg_x_axis: torch.Tensor,
origin: torch.Tensor,
xy_plane: torch.Tensor,
eps: float = 1e-10,
) -> "Affine3D":
x_axis = origin - neg_x_axis
xy_plane = xy_plane - origin
return Affine3D(
trans=origin,
rot=RotationMatrix.from_graham_schmidt(x_axis, xy_plane, eps),
)
def build_affine3d_from_coordinates(coords: torch.Tensor) -> tuple[Affine3D, torch.Tensor]:
    """Build one backbone frame per residue from three-atom coordinates.

    Args:
        coords: ``(batch, seq, 3, 3)`` tensor holding three backbone atoms per
            residue, unpacked below as ``n_atom, ca_atom, c_atom``. Non-finite
            values mark missing atoms.

    Returns:
        ``(affine, coord_mask)``:

        * ``affine`` has batch shape ``(batch, seq)``.
        * ``coord_mask`` is a ``(batch, seq)`` bool tensor. It is True where
          all three atoms are finite and every coordinate is below
          ``max_supported_distance``.
        * Residues without valid coordinates get a fallback ("black hole")
          frame built from the mean backbone of the valid residues in the same
          batch entry. Its rotation is the identity if that entry has no valid
          residue at all.
    """
    max_supported_distance = 1e6
    # NOTE(review): only an upper bound is checked, so very large negative
    # coordinates still count as valid.
    coord_mask = torch.all(
        torch.all(torch.isfinite(coords) & (coords < max_supported_distance), dim=-1),
        dim=-1,
    )
    def atom3_to_backbone_affine(bb_positions: torch.Tensor) -> Affine3D:
        # Origin at ca_atom, x axis from c_atom toward ca_atom, and n_atom in
        # the xy plane.
        n_atom, ca_atom, c_atom = bb_positions.unbind(dim=-2)
        return Affine3D.from_graham_schmidt(c_atom, ca_atom, n_atom)
    coords = coords.clone().float()
    # Zero invalid residues so NaN/inf cannot leak into the average below.
    coords[~coord_mask] = 0
    # Mean atom positions over valid residues, per batch entry: (batch, 3, 3).
    # The 1e-8 avoids division by zero when no residue is valid.
    average_per_n_ca_c = coords.masked_fill(~coord_mask[..., None, None], 0).sum(1) / (
        coord_mask.sum(-1)[..., None, None] + 1e-8
    )
    affine_from_average = atom3_to_backbone_affine(
        average_per_n_ca_c.float()
    ).as_matrix()
    batch_size, seq_len, _, _ = coords.shape
    # Broadcast the averaged frame to every position (9-d flattened rotation).
    affine_rot_mats = affine_from_average.rot.tensor[..., None, :].expand(
        batch_size,
        seq_len,
        9,
    )
    affine_trans = affine_from_average.trans[..., None, :].expand(batch_size, seq_len, 3)
    identity_rot = RotationMatrix.identity(
        (batch_size, seq_len),
        dtype=torch.float32,
        device=coords.device,
        requires_grad=False,
    )
    # Batch entries with no valid residue use the identity rotation rather
    # than the degenerate frame of an all-zero average.
    affine_rot_mats = affine_rot_mats.where(
        coord_mask.any(-1)[..., None, None],
        identity_rot.tensor,
    )
    black_hole_affine = Affine3D(affine_trans, RotationMatrix(affine_rot_mats))
    affine = atom3_to_backbone_affine(coords.float())
    # Keep real frames where coordinates are valid; use the fallback elsewhere.
    affine = Affine3D.from_tensor(
        affine.tensor.where(coord_mask[..., None], black_hole_affine.tensor)
    )
    return affine, coord_mask
class MultiHeadAttention(nn.Module):
    """Pre-LayerNorm multi-head self-attention with QK LayerNorm and rotary embeddings.

    When ``seq_id`` is given, attention is restricted to positions that share a
    ``seq_id``. The kernel is picked by ``attn_backend`` (see
    resolve_attention_backend). Requesting attention weights always switches
    to an explicit eager softmax path.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        bias: bool = False,
        qk_layernorm: bool = True,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        # Assumes d_model is divisible by n_heads (FastESM3Config asserts this).
        self.d_head = self.d_model // self.n_heads
        self.scale = self.d_head**-0.5
        self.attn_backend = resolve_attention_backend(attn_backend)
        # LayerNorm followed by one fused projection that yields Q, K and V.
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 3, bias=bias),
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        if qk_layernorm:
            # Normalizes over the full model width (all heads together).
            self.q_ln = nn.LayerNorm(d_model, bias=bias)
            self.k_ln = nn.LayerNorm(d_model, bias=bias)
        else:
            self.q_ln = nn.Identity()
            self.k_ln = nn.Identity()
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings per head to flat ``(..., d_model)`` q/k."""
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(
        self,
        x: torch.Tensor,
        seq_id: Optional[torch.Tensor],
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Self-attention over ``x``.

        Args:
            x: ``(batch, seq, d_model)`` input (the einops pattern below fixes
                this layout).
            seq_id: Optional ``(batch, seq)`` ids. Query i may attend to key j
                only when ``seq_id[b, i] == seq_id[b, j]``. ``None`` means full
                attention.
            output_attentions: If True, use an explicit softmax and return the
                weights.

        Returns:
            ``(out, attn_weights)``. ``out`` has the shape of ``x``.
            ``attn_weights`` is ``(batch, heads, seq, seq)``, or ``None`` when
            weights were not requested.
        """
        qkv = self.layernorm_qkv(x)
        query, key, value = torch.chunk(qkv, 3, dim=-1)
        # Cast back in case the LayerNorm returns another dtype (e.g. under autocast).
        query = self.q_ln(query).to(query.dtype)
        key = self.k_ln(key).to(query.dtype)
        query, key = self._apply_rotary(query, key)
        reshaper = functools.partial(
            einops.rearrange,
            pattern="b s (h d) -> b h s d",
            h=self.n_heads,
        )
        query, key, value = map(reshaper, (query, key, value))
        if seq_id is not None:
            # Bool (batch, 1, seq, seq) mask, True where positions share a seq_id.
            # The diagonal is always True, so no row is fully masked.
            mask = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
            mask = mask.unsqueeze(1)
        else:
            mask = None
        if output_attentions:
            # Eager path so the softmax weights can be returned.
            attn_scores = torch.einsum("bhld,bhsd->bhls", query, key) * self.scale
            if mask is not None:
                attn_scores = attn_scores.masked_fill(~mask, float("-inf"))
            attn_weights = torch.softmax(attn_scores, dim=-1)
            context = torch.einsum("bhls,bhsd->bhld", attn_weights, value)
        else:
            attn_weights = None
            if self.attn_backend == AttentionBackend.FLEX:
                block_mask = self._create_flex_block_mask(seq_id, query)
                fn = _get_flex_attention_fn()
                assert fn is not None, "Flex Attention is not available in this environment."
                context = fn(
                    query,
                    key,
                    value,
                    block_mask=block_mask,
                    scale=self.scale,
                )
            elif self.attn_backend == AttentionBackend.SDPA:
                # SDPA treats True in a boolean attn_mask as "may attend".
                context = F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=mask,
                    scale=self.scale,
                )
            else:
                raise AssertionError(f"Unsupported resolved ESM3 backend: {self.attn_backend}")
        context = einops.rearrange(context, "b h s d -> b s (h d)")
        return self.out_proj(context), attn_weights

    @staticmethod
    def _create_flex_block_mask(
        seq_id: Optional[torch.Tensor],
        query: torch.Tensor,
    ) -> Optional["BlockMask"]:
        """Build a flex-attention BlockMask equivalent to the seq_id mask.

        Returns ``None`` (full attention) when ``seq_id`` is ``None``. The mask
        is rebuilt on every call. It ignores ``head_idx``, so it is created with
        a head count of 1.
        """
        if seq_id is None:
            return None
        assert create_block_mask is not None, (
            "Flex Attention requested but torch.create_block_mask is unavailable."
        )
        batch_size, _, seq_len, _ = query.shape
        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            return seq_id[batch_idx, q_idx] == seq_id[batch_idx, kv_idx]
        return create_block_mask(
            mask_mod,
            batch_size,
            1,
            seq_len,
            seq_len,
            device=query.device,
        )
class GeometricReasoningOriginalImpl(nn.Module):
    """Frame-based geometric attention used in the first ESM3 layer(s).

    Projects the normalized residual stream into per-head 3-vectors and moves
    them into global coordinates using each residue's backbone frame. Each
    query/key pair is scored by a learned, per-head mix of two terms:

    * rotation agreement: dot products of the rotated query/key vectors;
    * negative distance between the transformed query/key points.

    The attended value vectors are rotated back into the local frame and
    projected to ``c_s`` features.
    """

    def __init__(
        self,
        c_s: int,
        v_heads: int,
        num_vector_messages: int = 1,
        mask_and_zero_frameless: bool = True,
        bias: bool = False,
    ):
        super().__init__()
        self.c_s = c_s
        self.v_heads = v_heads
        self.num_vector_messages = num_vector_messages
        self.mask_and_zero_frameless = mask_and_zero_frameless
        self.s_norm = nn.LayerNorm(c_s, bias=bias)
        # The projection produces four groups of v_heads 3-vectors
        # (rotation q/k, distance q/k) plus v_heads * num_vector_messages
        # value 3-vectors.
        dim_proj = 4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages
        self.proj = nn.Linear(c_s, dim_proj, bias=bias)
        channels_out = self.v_heads * 3 * self.num_vector_messages
        self.out_proj = nn.Linear(channels_out, c_s, bias=bias)
        # Per-head term weights, passed through softplus in forward.
        self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
        self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))

    def forward(
        self,
        s: torch.Tensor,
        affine: Affine3D,
        affine_mask: torch.Tensor,
        sequence_id: Optional[torch.Tensor],
        chain_id: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the geometric-attention update for ``s``.

        Args:
            s: ``(batch, seq, c_s)`` residual stream (layout fixed by the
                ``b s h d`` rearranges below).
            affine: Per-residue backbone frames.
            affine_mask: ``(batch, seq)`` bool, True where a residue has a frame.
                Keys without a frame get a large negative bias.
            sequence_id: Optional ``(batch, seq)`` ids; defaults to all zeros.
            chain_id: ``(batch, seq)`` chain ids. Keys on another chain get a
                large negative bias.

        Returns:
            ``(batch, seq, c_s)`` update for the residual stream.
        """
        if sequence_id is None:
            sequence_id = torch.zeros_like(s[..., 0], dtype=torch.int64)
        # NOTE(review): this bias is +1.0 for same-sequence pairs and 0.0 for
        # cross-sequence pairs. It does NOT mask other sequences; only frameless
        # keys and other-chain keys get the large negative fill. Kept unchanged
        # to preserve numerics; confirm the intent.
        attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
        attn_bias = attn_bias.unsqueeze(1).float()
        attn_bias = attn_bias.masked_fill(
            ~affine_mask[:, None, None, :],
            torch.finfo(attn_bias.dtype).min,
        )
        chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
        # Fill value uses s's dtype minimum, but attn_bias itself is float32 here.
        attn_bias = attn_bias.masked_fill(
            chain_id_mask.unsqueeze(1),
            torch.finfo(s.dtype).min,
        )
        ns = self.s_norm(s)
        vec_rot, vec_dist = self.proj(ns).split(
            [
                self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
                self.v_heads * 2 * 3,
            ],
            dim=-1,
        )
        # Rotation/value vectors are only rotated into global orientation.
        # ``[..., None]`` adds a head axis so each residue's frame is shared by
        # all heads.
        query_rot, key_rot, value = (
            affine.rot[..., None]
            .apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
            .split(
                [self.v_heads, self.v_heads, self.v_heads * self.num_vector_messages],
                dim=-2,
            )
        )
        # Distance points get the full rigid transform (rotation + translation).
        query_dist, key_dist = (
            affine[..., None]
            .apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3))
            .chunk(2, dim=-2)
        )
        query_dist = rearrange(query_dist, "b s h d -> b h s 1 d")
        key_dist = rearrange(key_dist, "b s h d -> b h 1 s d")
        query_rot = rearrange(query_rot, "b s h d -> b h s d")
        key_rot = rearrange(key_rot, "b s h d -> b h d s")
        value = rearrange(
            value,
            "b s (h m) d -> b h s (m d)",
            m=self.num_vector_messages,
        )
        # Pairwise (batch, heads, seq, seq) terms, both scaled by 1/sqrt(3).
        distance_term = (query_dist - key_dist).norm(dim=-1) / math.sqrt(3)
        rotation_term = query_rot.matmul(key_rot) / math.sqrt(3)
        distance_term_weight = rearrange(
            F.softplus(self.distance_scale_per_head),
            "h -> h 1 1",
        )
        rotation_term_weight = rearrange(
            F.softplus(self.rotation_scale_per_head),
            "h -> h 1 1",
        )
        attn_weight = (
            rotation_term * rotation_term_weight - distance_term * distance_term_weight
        )
        # Align attn_bias to the trailing (s_q, s_k) window if it is larger;
        # this is a no-op when the shapes already match.
        s_q = attn_weight.size(2)
        s_k = attn_weight.size(3)
        offset_q = max(0, attn_bias.size(2) - s_q)
        offset_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, offset_q:, offset_k:]
        attn_weight = torch.softmax(attn_weight + attn_bias, dim=-1)
        attn_out = attn_weight.matmul(value)
        # Rotate aggregated vectors back into each query residue's local frame.
        attn_out = (
            affine.rot[..., None]
            .invert()
            .apply(
                rearrange(
                    attn_out,
                    "b h s (m d) -> b s (h m) d",
                    m=self.num_vector_messages,
                )
            )
        )
        attn_out = rearrange(
            attn_out,
            "b s (h m) d -> b s (h m d)",
            m=self.num_vector_messages,
        )
        if self.mask_and_zero_frameless:
            # Residues without a frame contribute no geometric update.
            attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0)
        # RotationMatrix.apply works in float32; match the projection's dtype.
        attn_out = attn_out.to(self.out_proj.weight.dtype)
        return self.out_proj(attn_out)
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Return the SwiGLU hidden width for ``expansion_ratio * d_model``.

    The target width is padded by 255 and floored to a multiple of 256. For
    integral targets this rounds up to the next multiple of 256.
    """
    target_width = expansion_ratio * d_model
    multiple = 256
    num_blocks = (target_width + multiple - 1) // multiple
    return int(num_blocks * multiple)
class SwiGLU(nn.Module):
    """SwiGLU gate: split the last dim into ``(gate, value)`` and return ``silu(gate) * value``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)
def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool) -> nn.Module:
    """Pre-LayerNorm SwiGLU feed-forward block.

    Layout: LayerNorm -> Linear(d_model, 2h) -> SwiGLU -> Linear(h, d_model),
    with ``h = swiglu_correction_fn(expansion_ratio, d_model)``. ``bias``
    applies to the two Linear layers only.
    """
    hidden = swiglu_correction_fn(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, 2 * hidden, bias=bias),
        SwiGLU(),
        nn.Linear(hidden, d_model, bias=bias),
    )
def gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool) -> nn.Module:
    """Pre-LayerNorm GELU feed-forward block.

    Layout: LayerNorm -> Linear(d_model, w) -> GELU -> Linear(w, d_model),
    where ``w = int(expansion_ratio * d_model)``. ``bias`` applies to the two
    Linear layers only.
    """
    width = int(expansion_ratio * d_model)
    layers = [
        nn.LayerNorm(d_model),
        nn.Linear(d_model, width, bias=bias),
        nn.GELU(),
        nn.Linear(width, d_model, bias=bias),
    ]
    return nn.Sequential(*layers)
class UnifiedTransformerBlock(nn.Module):
def __init__(
self,
d_model: int,
n_heads: int,
use_geom_attn: bool = False,
use_plain_attn: bool = True,
v_heads: Optional[int] = None,
bias: bool = False,
expansion_ratio: float = 4.0,
residue_scaling_factor: float = 1.0,
mask_and_zero_frameless: bool = False,
qk_layernorm: bool = True,
ffn_type: str = "swiglu",
attn_backend: str = "sdpa",
):
super().__init__()
self.use_plain_attn = use_plain_attn
if self.use_plain_attn:
self.attn = MultiHeadAttention(
d_model,
n_heads,
bias,
qk_layernorm=qk_layernorm,
attn_backend=attn_backend,
)
self.use_geom_attn = use_geom_attn
if self.use_geom_attn:
assert v_heads is not None
self.geom_attn = GeometricReasoningOriginalImpl(
c_s=d_model,
v_heads=v_heads,
bias=bias,
mask_and_zero_frameless=mask_and_zero_frameless,
)
if ffn_type == "swiglu":
self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias)
elif ffn_type == "gelu":
self.ffn = gelu_ln_ffn(d_model, expansion_ratio, bias)
else:
raise ValueError(f"Unknown ffn_type: {ffn_type}")
self.scaling_factor = residue_scaling_factor
def forward(
self,
x: torch.Tensor,
sequence_id: Optional[torch.Tensor],
frames: Affine3D,
frames_mask: torch.Tensor,
chain_id: torch.Tensor,
output_attentions: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
attn_weights = None
if self.use_plain_attn:
r1, attn_weights = self.attn(
x,
sequence_id,
output_attentions=output_attentions,
)
x = x + r1 / self.scaling_factor
if self.use_geom_attn:
r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id)
x = x + r2 / self.scaling_factor
r3 = self.ffn(x) / self.scaling_factor
x = x + r3
return x, attn_weights
class TransformerStack(nn.Module):
    """Stack of UnifiedTransformerBlocks followed by a bias-free LayerNorm.

    The first ``n_layers_geom`` blocks also run geometric attention. With
    ``scale_residue`` every block divides its residual updates by
    ``sqrt(n_layers / 36)``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        v_heads: Optional[int],
        n_layers: int,
        n_layers_geom: int = 1,
        scale_residue: bool = True,
        mask_and_zero_frameless: bool = False,
        bias: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",
        expansion_ratio: float = 8 / 3,
        attn_backend: str = "sdpa",
    ):
        super().__init__()

        def build_block(layer_index: int):
            return UnifiedTransformerBlock(
                d_model,
                n_heads,
                v_heads=v_heads,
                use_geom_attn=layer_index < n_layers_geom,
                residue_scaling_factor=(
                    math.sqrt(n_layers / 36) if scale_residue else 1.0
                ),
                expansion_ratio=expansion_ratio,
                mask_and_zero_frameless=mask_and_zero_frameless,
                bias=bias,
                qk_layernorm=qk_layernorm,
                ffn_type=ffn_type,
                attn_backend=attn_backend,
            )

        self.blocks = nn.ModuleList([build_block(i) for i in range(n_layers)])
        self.norm = nn.LayerNorm(d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        sequence_id: Optional[torch.Tensor] = None,
        affine: Optional[Affine3D] = None,
        affine_mask: Optional[torch.Tensor] = None,
        chain_id: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        tuple[torch.Tensor, ...],
        Optional[tuple[torch.Tensor, ...]],
    ]:
        """Run every block.

        ``chain_id`` defaults to all ones over ``x``'s leading dims. ``affine``
        and ``affine_mask`` are required.

        Returns:
            ``(normed, pre_norm, per_block_outputs, attentions)``.
            ``attentions`` is ``None`` unless ``output_attentions`` is set.
        """
        *batch_dims, _ = x.shape
        if chain_id is None:
            chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
        assert affine is not None
        assert affine_mask is not None
        block_outputs = []
        block_attentions = []
        for block in self.blocks:
            x, weights = block(
                x,
                sequence_id,
                affine,
                affine_mask,
                chain_id,
                output_attentions=output_attentions,
            )
            block_outputs.append(x)
            if output_attentions and weights is not None:
                block_attentions.append(weights)
        return (
            self.norm(x),
            x,
            tuple(block_outputs),
            tuple(block_attentions) if output_attentions else None,
        )
class EncodeInputs(nn.Module):
    """Embed every ESM3 input track and sum the results into one hidden state."""

    def __init__(self, d_model: int):
        super().__init__()
        # 64 rows, matching the 64-logit sequence head in OutputHeads.
        self.sequence_embed = nn.Embedding(64, d_model)
        # pLDDT inputs are expanded with a 16-bin RBF in forward before projection.
        self.plddt_projection = nn.Linear(16, d_model)
        self.structure_per_res_plddt_projection = nn.Linear(16, d_model)
        # 4096 VQ-VAE codes plus 5 special ids (STRUCTURE_*_TOKEN = 4096..4100).
        self.structure_tokens_embed = nn.Embedding(4096 + 5, d_model)
        # NOTE(review): 8 + 3 and 16 + 3 presumably mean 8 SS8 classes / 16 SASA
        # bins plus 3 special tokens; confirm.
        self.ss8_embed = nn.Embedding(8 + 3, d_model)
        self.sasa_embed = nn.Embedding(16 + 3, d_model)
        # One 260-row table per function-token depth. The 8 outputs of width
        # d_model // 8 are concatenated back to d_model. Id 0 is padding.
        self.function_embed = nn.ModuleList(
            [nn.Embedding(260, d_model // 8, padding_idx=0) for _ in range(8)]
        )
        # Sums the embeddings of all annotations at a position; id 0 is padding
        # and contributes nothing.
        self.residue_embed = nn.EmbeddingBag(1478, d_model, mode="sum", padding_idx=0)

    def forward(
        self,
        sequence_tokens: torch.Tensor,
        structure_tokens: torch.Tensor,
        average_plddt: torch.Tensor,
        per_res_plddt: torch.Tensor,
        ss8_tokens: torch.Tensor,
        sasa_tokens: torch.Tensor,
        function_tokens: torch.Tensor,
        residue_annotation_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """Return the element-wise sum of all track embeddings.

        Args:
            sequence_tokens, structure_tokens, ss8_tokens, sasa_tokens:
                ``(batch, seq)`` integer ids.
            average_plddt, per_res_plddt: ``(batch, seq)`` floats. The RBF
                centers span [0, 1].
            function_tokens: ``(batch, seq, depth)`` ids. Depth slot i is
                looked up in ``function_embed[i]``.
            residue_annotation_tokens: ``(batch, seq, n)`` ids, summed per
                position.

        Returns:
            ``(batch, seq, d_model)`` tensor. Inputs with batch size 1 broadcast
            in the final sum.
        """
        sequence_embed = self.sequence_embed(sequence_tokens)
        rbf_16_fn = functools.partial(rbf, v_min=0.0, v_max=1.0, n_bins=16)
        plddt_embed = self.plddt_projection(
            rbf_16_fn(average_plddt).to(self.plddt_projection.weight.dtype)
        )
        structure_per_res_plddt = self.structure_per_res_plddt_projection(
            rbf_16_fn(per_res_plddt).to(
                self.structure_per_res_plddt_projection.weight.dtype
            )
        )
        structure_embed = self.structure_tokens_embed(structure_tokens)
        ss8_embed = self.ss8_embed(ss8_tokens)
        sasa_embed = self.sasa_embed(sasa_tokens)
        function_embed = torch.cat(
            [
                embed_fn(funcs)
                for embed_fn, funcs in zip(
                    self.function_embed,
                    function_tokens.unbind(-1),
                )
            ],
            -1,
        )
        # EmbeddingBag takes 2-D input: fold (batch, seq) into rows, then unfold.
        batch_size, seq_len, num_annotations = residue_annotation_tokens.shape
        residue_embed = self.residue_embed(
            rearrange(
                residue_annotation_tokens,
                "b l n -> (b l) n",
                b=batch_size,
                l=seq_len,
                n=num_annotations,
            )
        )
        residue_embed = rearrange(
            residue_embed,
            "(b l) d -> b l d",
            b=batch_size,
            l=seq_len,
        )
        return (
            sequence_embed
            + plddt_embed
            + structure_per_res_plddt
            + structure_embed
            + ss8_embed
            + sasa_embed
            + function_embed
            + residue_embed
        )
@dataclass
class ESM3CoreOutput:
    """Raw per-track outputs of ESM3Core, assembled by OutputHeads."""

    # (..., 64) sequence-token logits.
    sequence_logits: torch.Tensor
    # (..., 4096) logits over structure codebook ids.
    structure_logits: torch.Tensor
    # (..., 11) logits from ss8_head.
    secondary_structure_logits: torch.Tensor
    # (..., 19) logits from sasa_head.
    sasa_logits: torch.Tensor
    # (..., 8, 260) logits, one 260-way distribution per function-token depth.
    function_logits: torch.Tensor
    # (..., 1478) residue-annotation logits.
    residue_logits: torch.Tensor
    # Final transformer hidden state before the stack's closing LayerNorm.
    embeddings: torch.Tensor
    # Output of every transformer block (pre final norm), in order.
    hidden_states: tuple[torch.Tensor, ...]
    # Per-block self-attention weights; None unless they were requested.
    attentions: Optional[tuple[torch.Tensor, ...]] = None
class OutputHeads(nn.Module):
    """Prediction heads for every ESM3 output track.

    All heads read the same final (normalized) hidden state. The function
    head's flat ``8 * 260`` output is reshaped to 8 depths of 260 logits each.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.sequence_head = RegressionHead(d_model, 64)
        self.structure_head = RegressionHead(d_model, 4096)
        self.ss8_head = RegressionHead(d_model, 8 + 3)
        self.sasa_head = RegressionHead(d_model, 16 + 3)
        self.function_head = RegressionHead(d_model, 260 * 8)
        self.residue_head = RegressionHead(d_model, 1478)

    def forward(
        self,
        x: torch.Tensor,
        embed: torch.Tensor,
        hidden_states: tuple[torch.Tensor, ...],
        attentions: Optional[tuple[torch.Tensor, ...]] = None,
    ) -> ESM3CoreOutput:
        """Run every head on ``x`` and bundle the results with the passthroughs."""
        sequence_logits = self.sequence_head(x)
        structure_logits = self.structure_head(x)
        ss8_logits = self.ss8_head(x)
        sasa_logits = self.sasa_head(x)
        residue_logits = self.residue_head(x)
        # Depth-major split of the trailing 8 * 260 features into (8, 260).
        function_logits = self.function_head(x).unflatten(-1, (8, -1))
        return ESM3CoreOutput(
            sequence_logits=sequence_logits,
            structure_logits=structure_logits,
            secondary_structure_logits=ss8_logits,
            sasa_logits=sasa_logits,
            function_logits=function_logits,
            residue_logits=residue_logits,
            embeddings=embed,
            hidden_states=hidden_states,
            attentions=attentions,
        )
class ESM3Core(nn.Module):
    """ESM3 backbone: multimodal input encoder, transformer stack, output heads.

    Every input track is optional. Missing tracks are filled with batch-size-1
    mask/pad defaults, which rely on broadcasting against any batched tracks
    that were provided.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        v_heads: int,
        n_layers: int,
        tokenizers: FastESM3TokenizerCollection,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.encoder = EncodeInputs(d_model)
        self.transformer = TransformerStack(
            d_model,
            n_heads,
            v_heads,
            n_layers,
            mask_and_zero_frameless=True,
            attn_backend=attn_backend,
        )
        self.output_heads = OutputHeads(d_model)
        self.tokenizers = tokenizers

    def forward(
        self,
        *,
        sequence_tokens: Optional[torch.Tensor] = None,
        structure_tokens: Optional[torch.Tensor] = None,
        ss8_tokens: Optional[torch.Tensor] = None,
        sasa_tokens: Optional[torch.Tensor] = None,
        function_tokens: Optional[torch.Tensor] = None,
        residue_annotation_tokens: Optional[torch.Tensor] = None,
        average_plddt: Optional[torch.Tensor] = None,
        per_res_plddt: Optional[torch.Tensor] = None,
        structure_coords: Optional[torch.Tensor] = None,
        chain_id: Optional[torch.Tensor] = None,
        sequence_id: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
    ) -> ESM3CoreOutput:
        """Run ESM3 on whichever tracks are provided.

        The sequence length and device are taken from the first non-None
        token/coordinate track (dim 1 is the sequence axis).

        Defaults for missing tracks:

        * sequence: the tokenizer's mask id;
        * structure: STRUCTURE_MASK_TOKEN;
        * coordinates: NaN (no frames);
        * average pLDDT: 1;
        * everything else: pad id 0.

        Raises:
            ValueError: If none of the sequence, structure, ss8, sasa,
                coordinate, function or residue-annotation tracks is given.
        """
        output_attentions = bool(output_attentions)
        # Tracks that can establish the sequence length and device.
        present_inputs = [
            sequence_tokens,
            structure_tokens,
            ss8_tokens,
            sasa_tokens,
            structure_coords,
            function_tokens,
            residue_annotation_tokens,
        ]
        # next() with a default avoids re-raising from inside a StopIteration
        # handler, which would chain a misleading exception context.
        reference = next((x for x in present_inputs if x is not None), None)
        if reference is None:
            raise ValueError("At least one of the inputs must be non-None")
        seq_len, device = reference.shape[1], reference.device

        def defaults(x: Optional[torch.Tensor], token: int) -> torch.Tensor:
            # Batch-size-1 (1, seq_len) fill for a missing (batch, seq) track.
            if x is None:
                return torch.full(
                    (1, seq_len),
                    token,
                    dtype=torch.long,
                    device=device,
                )
            return x

        sequence_tokens = defaults(sequence_tokens, self.tokenizers.sequence.mask_token_id)
        ss8_tokens = defaults(ss8_tokens, SS8_PAD_TOKEN)
        sasa_tokens = defaults(sasa_tokens, SASA_PAD_TOKEN)
        average_plddt = defaults(average_plddt, 1).float()
        per_res_plddt = defaults(per_res_plddt, 0).float()
        chain_id = defaults(chain_id, 0)
        if residue_annotation_tokens is None:
            residue_annotation_tokens = torch.full(
                (1, seq_len, MAX_RESIDUE_ANNOTATIONS),
                RESIDUE_PAD_TOKEN,
                dtype=torch.long,
                device=device,
            )
        if function_tokens is None:
            function_tokens = torch.full(
                (1, seq_len, FUNCTION_TOKENS_DEPTH),
                INTERPRO_PAD_TOKEN,
                dtype=torch.long,
                device=device,
            )
        if structure_coords is None:
            # All-NaN coordinates: every residue is treated as frameless.
            structure_coords = torch.full(
                (1, seq_len, 3, 3),
                float("nan"),
                dtype=torch.float,
                device=device,
            )
        # Keep only the first three atoms per residue (the backbone triplet
        # build_affine3d_from_coordinates unpacks), so wider atom layouts work.
        structure_coords = structure_coords[..., :3, :]
        affine, affine_mask = build_affine3d_from_coordinates(structure_coords)
        structure_tokens = defaults(structure_tokens, STRUCTURE_MASK_TOKEN)
        # -1 becomes the structure mask id, and special sequence positions get
        # the matching structure special id.
        structure_tokens = (
            structure_tokens.masked_fill(structure_tokens == -1, STRUCTURE_MASK_TOKEN)
            .masked_fill(sequence_tokens == SEQUENCE_BOS_TOKEN, STRUCTURE_BOS_TOKEN)
            .masked_fill(sequence_tokens == SEQUENCE_PAD_TOKEN, STRUCTURE_PAD_TOKEN)
            .masked_fill(sequence_tokens == SEQUENCE_EOS_TOKEN, STRUCTURE_EOS_TOKEN)
            .masked_fill(
                sequence_tokens == SEQUENCE_CHAINBREAK_TOKEN,
                STRUCTURE_CHAINBREAK_TOKEN,
            )
        )
        x = self.encoder(
            sequence_tokens,
            structure_tokens,
            average_plddt,
            per_res_plddt,
            ss8_tokens,
            sasa_tokens,
            function_tokens,
            residue_annotation_tokens,
        )
        x, embedding, hidden_states, attentions = self.transformer(
            x,
            sequence_id,
            affine,
            affine_mask,
            chain_id,
            output_attentions=output_attentions,
        )
        return self.output_heads(
            x,
            embedding,
            hidden_states=hidden_states,
            attentions=attentions,
        )
def _resolve_esm3_checkpoint_key(model_name: str) -> str:
    """Map an accepted ESM3 checkpoint alias to its canonical checkpoint key.

    Args:
        model_name: Any name contained in ``ESM3_OPEN_SMALL_ALIASES``.

    Returns:
        ``ESM3_OPEN_SMALL``, the only checkpoint currently supported.

    Raises:
        ValueError: If ``model_name`` is not a recognised alias.
    """
    if model_name not in ESM3_OPEN_SMALL_ALIASES:
        supported_names = sorted(ESM3_OPEN_SMALL_ALIASES)
        raise ValueError(
            f"Unsupported ESM3 checkpoint {model_name}. "
            f"Supported names: {supported_names}"
        )
    return ESM3_OPEN_SMALL
def parse_fasta(fasta_path: str) -> list[str]:
    """Read a FASTA file and return its sequences in file order.

    Header lines (starting with ``>``) are discarded. A record's sequence
    lines are stripped and concatenated, and blank lines are ignored.
    Records without any sequence lines are dropped. Sequence lines that
    appear before the first header form a record of their own.

    Raises:
        AssertionError: If ``fasta_path`` does not exist.
    """
    assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
    records: list[str] = []
    pending: list[str] = []

    def flush_pending() -> None:
        # Emit the record accumulated so far, if it has any residues.
        if pending:
            records.append("".join(pending))
            pending.clear()

    with open(fasta_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                continue
            if text.startswith(">"):
                flush_pending()
            else:
                pending.append(text)
    flush_pending()
    return records
def _ensure_official_esm_on_path() -> None:
for parent in Path(__file__).resolve().parents:
candidate = parent / "official" / "esm"
if (candidate / "esm" / "models" / "esm3.py").exists():
candidate_str = str(candidate)
if candidate_str not in sys.path:
sys.path.insert(0, candidate_str)
return
def _official_esm_pretrained_module():
    """Return the official SDK's ``esm.pretrained`` module.

    Calls ``_ensure_official_esm_on_path`` first so that a vendored checkout
    (if present) is importable.
    """
    _ensure_official_esm_on_path()
    return importlib.import_module("esm.pretrained")


def _make_structure_encoder(device: Union[torch.device, str]) -> nn.Module:
    """Instantiate the SDK's ``ESM3_structure_encoder_v0`` on ``device``."""
    pretrained = _official_esm_pretrained_module()
    return pretrained.ESM3_structure_encoder_v0(device)


def _make_structure_decoder(device: Union[torch.device, str]) -> nn.Module:
    """Instantiate the SDK's ``ESM3_structure_decoder_v0`` on ``device``."""
    pretrained = _official_esm_pretrained_module()
    return pretrained.ESM3_structure_decoder_v0(device)


def _make_function_decoder(device: Union[torch.device, str]) -> nn.Module:
    """Instantiate the SDK's ``ESM3_function_decoder_v0`` on ``device``."""
    pretrained = _official_esm_pretrained_module()
    return pretrained.ESM3_function_decoder_v0(device)
def _build_official_esm3(config: FastESM3Config) -> nn.Module:
    """Construct an ``ESM3Core`` sized by ``config``.

    Dimensions and the attention backend come from the config. The core is
    given a tokenizer collection holding a fresh ``EsmSequenceTokenizer``.
    """
    tokenizer_collection = FastESM3TokenizerCollection(
        sequence=EsmSequenceTokenizer()
    )
    return ESM3Core(
        d_model=config.hidden_size,
        n_heads=config.num_attention_heads,
        v_heads=config.num_vector_heads,
        n_layers=config.num_hidden_layers,
        tokenizers=tokenizer_collection,
        attn_backend=config.attn_backend,
    )
class FastESM3PreTrainedModel(PreTrainedModel):
    """Shared Hugging Face base class for the FastESM3 models.

    Provides the config binding, weight initialisation, a runtime-switchable
    attention backend, and ``from_pretrained_esm`` for loading the original
    open ESM3 checkpoint from the Hub.
    """

    config_class = FastESM3Config
    base_model_prefix = "esm3"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = False
    all_tied_weights_keys = {}

    @classmethod
    def is_remote_code(cls) -> bool:
        return True

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise the direct (non-recursive) parameters of ``module``.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        ``N(0, config.initializer_range)``. Linear biases and the embedding
        padding row are zeroed. ``nn.LayerNorm`` is reset to weight 1 and
        bias 0. A module is left untouched if any of its direct parameters
        is already flagged ``_is_hf_initialized``.
        """
        for parameter in module.parameters(recurse=False):
            if parameter.__dict__.get("_is_hf_initialized", False):
                return
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # With elementwise_affine=False both weight and bias are None;
            # nn.init.ones_(None) would raise, so guard each one.
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            if module.weight is not None:
                nn.init.ones_(module.weight)

    @property
    def attn_backend(self) -> str:
        """Name of the attention backend stored on the config."""
        return self.config.attn_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        """Switch every ``MultiHeadAttention`` submodule to ``backend``."""
        assert backend in _SUPPORTED_ATTENTION_BACKENDS, (
            f"ESM3 currently supports only {_SUPPORTED_ATTENTION_BACKENDS}; got {backend}."
        )
        self.config.attn_backend = backend
        resolved = resolve_attention_backend(backend)
        for module in self.modules():
            if isinstance(module, MultiHeadAttention):
                module.attn_backend = resolved

    @classmethod
    def from_pretrained_esm(
        cls,
        model_name: str = ESM3_OPEN_SMALL,
        device: Union[torch.device, str] = "cpu",
        dtype: Optional[torch.dtype] = None,
    ) -> "FastESM3Model":
        """Build a ``FastESM3Model`` from the original open ESM3 checkpoint.

        Args:
            model_name: Any alias accepted by ``_resolve_esm3_checkpoint_key``.
            device: Device the returned model is moved to.
            dtype: Optional dtype the model is cast to after loading.

        Returns:
            A ``FastESM3Model`` in eval mode.

        Raises:
            ValueError: If ``model_name`` is not a supported alias.
            RuntimeError: If ``device`` is invalid, or if the checkpoint keys
                do not exactly match the model (``strict=True``).
        """
        # Validate the device up front, before any download happens.
        target_device = torch.device(device)
        key = _resolve_esm3_checkpoint_key(model_name)
        spec = _ESM3_CHECKPOINT_SPECS[key]
        config = FastESM3Config(
            hidden_size=spec["hidden_size"],
            num_attention_heads=spec["num_attention_heads"],
            num_vector_heads=spec["num_vector_heads"],
            num_hidden_layers=spec["num_hidden_layers"],
            model_name=key,
        )
        model = FastESM3Model(config)
        checkpoint_root = Path(
            snapshot_download(
                repo_id=spec["repo_id"],
                allow_patterns=["data/weights/esm3_sm_open_v1.pth"],
            )
        )
        # Deserialise on CPU. load_state_dict copies values into the freshly
        # built parameters, which live on the default (normally CPU) device.
        # Mapping straight to `device` would only add a second, temporary copy
        # of every weight there. weights_only=True limits unpickling to
        # tensors/containers, which is all a state dict should need, for a
        # file fetched over the network.
        state_dict = torch.load(
            checkpoint_root / "data" / "weights" / "esm3_sm_open_v1.pth",
            map_location="cpu",
            weights_only=True,
        )
        # strict=True raises RuntimeError on any missing or unexpected key.
        model.esm3.load_state_dict(state_dict, strict=True)
        del state_dict
        model = model.to(target_device)
        if dtype is not None:
            model = model.to(dtype=dtype)
        model.eval()
        return model
class FastESM3Model(FastPLMTestTimeTrainingMixin, FastESM3PreTrainedModel):
    """Hugging Face model wrapping the ESM3 core.

    ``forward`` accepts either HF-style ``input_ids``/``attention_mask`` or the
    individual ESM3 track tensors. ``encode``, ``decode``, ``generate`` and
    related helpers are delegated to a lazily built official SDK ``ESM3``
    model that receives a copy of this model's weights when it is first built.
    """

    config_class = FastESM3Config

    def __init__(self, config: FastESM3Config, **kwargs):
        super().__init__(config, **kwargs)
        self.tokenizer = EsmSequenceTokenizer()
        self.esm3 = _build_official_esm3(config)
        # Written through __dict__ so nn.Module.__setattr__ does not register
        # the lazily built SDK model as a submodule of this model.
        self.__dict__["_official_sdk_model"] = None
        self.init_ttt({"lora_target_replace_module": "MultiHeadAttention"})

    @property
    def device(self) -> torch.device:
        """Device of the first parameter (assumes all parameters share it)."""
        return next(self.parameters()).device

    @property
    def raw_model(self) -> nn.Module:
        """The wrapped ESM3 core module."""
        return self.esm3

    def _get_official_sdk_model(self) -> nn.Module:
        """Return the official SDK ``ESM3`` model, building it on first use.

        The SDK model is loaded with ``self.esm3``'s current state dict and
        moved to this model's device and parameter dtype.

        NOTE: the result is cached. Later weight updates (e.g. test-time
        training) or device/dtype moves of ``self`` are not propagated to it.
        """
        cached_model = self.__dict__["_official_sdk_model"]
        if cached_model is not None:
            return cached_model
        _ensure_official_esm_on_path()
        esm3_module = importlib.import_module("esm.models.esm3")
        tokenization = importlib.import_module("esm.tokenization")
        sdk_model = esm3_module.ESM3(
            d_model=self.config.hidden_size,
            n_heads=self.config.num_attention_heads,
            v_heads=self.config.num_vector_heads,
            n_layers=self.config.num_hidden_layers,
            structure_encoder_fn=_make_structure_encoder,
            structure_decoder_fn=_make_structure_decoder,
            function_decoder_fn=_make_function_decoder,
            tokenizers=tokenization.get_esm3_model_tokenizers(self.config.model_name),
        )
        # strict=True raises RuntimeError on any missing or unexpected key.
        sdk_model.load_state_dict(self.esm3.state_dict(), strict=True)
        dtype = next(self.esm3.parameters()).dtype
        sdk_model = sdk_model.to(self.device).to(dtype=dtype).eval()
        self.__dict__["_official_sdk_model"] = sdk_model
        return sdk_model

    def get_input_embeddings(self) -> nn.Module:
        return self.esm3.encoder.sequence_embed

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.esm3.encoder.sequence_embed = value

    def tokenize_sequences(
        self,
        sequences: Union[str, list[str]],
        padding: bool = True,
        return_tensors: str = "pt",
        device: Optional[Union[torch.device, str]] = None,
        add_special_tokens: bool = True,
    ) -> dict[str, torch.Tensor]:
        """Tokenize ``sequences`` with the model's sequence tokenizer.

        Returns the tokenizer output unchanged when ``device`` is None.
        Otherwise returns a plain dict with every tensor moved to ``device``.
        """
        tokenized = self.tokenizer(
            sequences,
            padding=padding,
            return_tensors=return_tensors,
            add_special_tokens=add_special_tokens,
        )
        if device is None:
            return tokenized
        return {name: tensor.to(device) for name, tensor in tokenized.items()}

    def forward_sequence(
        self,
        sequences: Union[str, list[str]],
        device: Optional[Union[torch.device, str]] = None,
        **kwargs,
    ) -> FastESM3Output:
        """Tokenize raw sequences and run ``forward`` on them.

        The tokens are placed on the model's device unless ``device`` is given.
        """
        if device is None:
            device = self.device
        tokenized = self.tokenize_sequences(sequences, device=device)
        return self(**tokenized, **kwargs)

    def _embed(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        hidden_state_index: int = -1,
        store_all_hidden_states: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """Run ``forward`` and return per-residue representations.

        Returns:
            * With ``store_all_hidden_states``: every entry of
              ``output.hidden_states`` stacked along dim 1.
            * With ``hidden_state_index == -1``: ``output.last_hidden_state``.
            * Otherwise: ``output.hidden_states[hidden_state_index]``.
        """
        output = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        if store_all_hidden_states:
            assert output.hidden_states is not None, "store_all_hidden_states requires hidden states."
            return torch.stack(tuple(output.hidden_states), dim=1)
        if hidden_state_index == -1:
            return output.last_hidden_state
        assert output.hidden_states is not None, "hidden_state_index selection requires hidden states."
        return output.hidden_states[hidden_state_index]

    def _pool_embeddings(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
        pooling_types: list[str],
    ) -> torch.Tensor:
        """Pool ``(batch, length, dim)`` embeddings over the length axis.

        Each entry of ``pooling_types`` ("mean", "cls" or "max") produces one
        ``(batch, dim)`` tensor. The results are concatenated along the last
        dim. "mean" and "max" only consider positions where ``attention_mask``
        is non-zero; "cls" takes position 0.

        Raises:
            ValueError: On an unknown pooling type.
        """
        pooled = []
        mask = attention_mask.to(dtype=embeddings.dtype).unsqueeze(-1)
        for pooling_type in pooling_types:
            if pooling_type == "mean":
                # Clamp so a fully masked row yields zeros instead of NaN.
                token_counts = mask.sum(dim=1).clamp(min=1.0)
                pooled.append((embeddings * mask).sum(dim=1) / token_counts)
            elif pooling_type == "cls":
                pooled.append(embeddings[:, 0, :])
            elif pooling_type == "max":
                bool_mask = attention_mask.unsqueeze(-1).bool()
                pooled.append(
                    embeddings.masked_fill(~bool_mask, float("-inf")).max(dim=1).values
                )
            else:
                raise ValueError(
                    f"Unsupported ESM3 pooling type {pooling_type}. "
                    "Supported values are 'mean', 'cls', and 'max'."
                )
        return torch.cat(pooled, dim=-1)

    def embed_dataset(
        self,
        sequences: Optional[List[str]] = None,
        tokenizer: Optional[PreTrainedTokenizerFast] = None,
        batch_size: int = 2,
        max_len: int = 512,
        truncate: bool = True,
        full_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        pooling_types: List[str] = ["mean"],
        num_workers: int = 0,
        sql: bool = False,
        save: bool = True,
        sql_db_path: str = "embeddings.db",
        save_path: str = "embeddings.pth",
        fasta_path: Optional[str] = None,
        padding: str = "longest",
        hidden_state_index: int = -1,
        store_all_hidden_states: bool = False,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Embed a collection of sequences and return them keyed by sequence.

        Sequences come from ``sequences`` and/or ``fasta_path``. When
        ``truncate`` is set they are cut to ``max_len`` residues, then
        de-duplicated. Result keys are the (possibly truncated) sequences.

        With ``full_embeddings`` each value holds the per-token embeddings at
        positions where ``attention_mask`` is set. With
        ``store_all_hidden_states`` it holds the stacked hidden states instead.
        Otherwise each value is the concatenation of ``pooling_types``
        poolings. Values are cast to ``embed_dtype`` and moved to CPU, and the
        dict is written with ``torch.save`` to ``save_path`` when ``save`` is
        set. ``num_workers`` and ``sql_db_path`` are accepted but unused.
        SQLite output is not supported. The model's train/eval mode is
        restored on exit, even if embedding fails.
        """
        del num_workers, sql_db_path
        assert not sql, "ESM3 embed_dataset currently supports .pth saves, not SQLite."
        assert isinstance(hidden_state_index, int), "hidden_state_index must be an integer."
        assert full_embeddings or not store_all_hidden_states, (
            "store_all_hidden_states=True requires full_embeddings=True."
        )
        if tokenizer is None:
            tokenizer = self.tokenizer
        if fasta_path is not None:
            fasta_sequences = parse_fasta(fasta_path)
            sequences = list(sequences or []) + fasta_sequences
        assert sequences is not None and len(sequences) > 0, (
            "Must provide at least one sequence via `sequences` or `fasta_path`."
        )
        # Truncate first, then de-duplicate keeping first-occurrence order.
        unique_sequences = list(
            dict.fromkeys(
                sequence[:max_len] if truncate else sequence for sequence in sequences
            )
        )
        # Longest first, which groups similar lengths into the same batch.
        unique_sequences.sort(key=len, reverse=True)
        embeddings_by_sequence: Dict[str, torch.Tensor] = {}
        was_training = self.training
        self.eval()
        try:
            for batch_start in range(0, len(unique_sequences), batch_size):
                batch_sequences = unique_sequences[batch_start : batch_start + batch_size]
                tokenized = tokenizer(
                    batch_sequences,
                    padding=padding,
                    truncation=truncate,
                    # +2 leaves room for the special tokens around each sequence.
                    max_length=max_len + 2,
                    return_tensors="pt",
                )
                tokenized = {
                    name: tensor.to(self.device) for name, tensor in tokenized.items()
                }
                with torch.inference_mode():
                    residue_embeddings = self._embed(
                        **tokenized,
                        hidden_state_index=hidden_state_index,
                        store_all_hidden_states=store_all_hidden_states,
                        **kwargs,
                    )
                attention_mask = tokenized["attention_mask"]
                if full_embeddings:
                    batch_embeddings = residue_embeddings.to(embed_dtype).cpu()
                    for sequence, embedding, mask in zip(
                        batch_sequences,
                        batch_embeddings,
                        attention_mask.cpu(),
                    ):
                        if embedding.ndim == 3:
                            # Stacked hidden states: (num_states, length, dim).
                            embeddings_by_sequence[sequence] = embedding[:, mask.bool(), :]
                        else:
                            embeddings_by_sequence[sequence] = embedding[mask.bool()]
                else:
                    pooled_embeddings = self._pool_embeddings(
                        residue_embeddings,
                        attention_mask,
                        pooling_types,
                    )
                    pooled_embeddings = pooled_embeddings.to(embed_dtype).cpu()
                    for sequence, embedding in zip(batch_sequences, pooled_embeddings):
                        embeddings_by_sequence[sequence] = embedding
        finally:
            # Restore the caller's mode even if a batch raised.
            if was_training:
                self.train()
        if save:
            torch.save(embeddings_by_sequence, save_path)
        return embeddings_by_sequence

    def encode(self, input):
        """Delegate to the official SDK model's ``encode``."""
        return self._get_official_sdk_model().encode(input)

    def decode(self, input):
        """Delegate to the official SDK model's ``decode``."""
        return self._get_official_sdk_model().decode(input)

    def generate(self, input, config):
        """Delegate to the official SDK model's ``generate``."""
        return self._get_official_sdk_model().generate(input, config)

    def batch_generate(self, inputs, configs):
        """Delegate to the official SDK model's ``batch_generate``."""
        return self._get_official_sdk_model().batch_generate(inputs, configs)

    def _ttt_get_trainable_modules(self) -> list[nn.Module]:
        return [self.esm3]

    def forward_and_sample(self, input, sampling_configuration):
        """Delegate to the official SDK model's ``forward_and_sample``."""
        return self._get_official_sdk_model().forward_and_sample(
            input,
            sampling_configuration,
        )

    def logits(self, input=None, config=None, **kwargs):
        """Compute logits via ``forward`` or the official SDK model.

        * ``input is None``: ``forward(**kwargs)``.
        * Tensor ``input``: treated as ``sequence_tokens`` for ``forward``.
        * Anything else: passed to the SDK model's ``logits`` (with
          ``config`` when given).
        """
        if input is None:
            return self.forward(**kwargs)
        if isinstance(input, torch.Tensor):
            return self.forward(sequence_tokens=input, **kwargs)
        if config is None:
            return self._get_official_sdk_model().logits(input)
        return self._get_official_sdk_model().logits(input, config)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sequence_tokens: Optional[torch.Tensor] = None,
        structure_tokens: Optional[torch.Tensor] = None,
        ss8_tokens: Optional[torch.Tensor] = None,
        sasa_tokens: Optional[torch.Tensor] = None,
        function_tokens: Optional[torch.Tensor] = None,
        residue_annotation_tokens: Optional[torch.Tensor] = None,
        average_plddt: Optional[torch.Tensor] = None,
        per_res_plddt: Optional[torch.Tensor] = None,
        structure_coords: Optional[torch.Tensor] = None,
        chain_id: Optional[torch.Tensor] = None,
        sequence_id: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> FastESM3Output:
        """Run the ESM3 core and package its outputs as ``FastESM3Output``.

        ``input_ids`` is used as ``sequence_tokens`` when the latter is not
        given. When no ``sequence_id`` is given, ``attention_mask`` is
        converted to a boolean ``sequence_id``. ``output_hidden_states``,
        ``return_dict`` and extra kwargs are accepted for HF compatibility and
        ignored. When ``labels`` is given, ``loss`` is the cross-entropy of
        ``sequence_logits`` against it, ignoring positions labelled -100.
        """
        del output_hidden_states, return_dict, kwargs
        if sequence_tokens is None:
            sequence_tokens = input_ids
        if sequence_id is None and attention_mask is not None:
            sequence_id = attention_mask.to(dtype=torch.bool)
        output = self.esm3(
            sequence_tokens=sequence_tokens,
            structure_tokens=structure_tokens,
            ss8_tokens=ss8_tokens,
            sasa_tokens=sasa_tokens,
            function_tokens=function_tokens,
            residue_annotation_tokens=residue_annotation_tokens,
            average_plddt=average_plddt,
            per_res_plddt=per_res_plddt,
            structure_coords=structure_coords,
            chain_id=chain_id,
            sequence_id=sequence_id,
            output_attentions=output_attentions,
        )
        loss = None
        if labels is not None:
            sequence_logits = output.sequence_logits
            # reshape (not view) tolerates non-contiguous logits/labels, e.g.
            # sliced label tensors; labels follow the logits' device.
            loss = F.cross_entropy(
                sequence_logits.reshape(-1, sequence_logits.shape[-1]),
                labels.reshape(-1).to(sequence_logits.device),
                ignore_index=-100,
            )
        return FastESM3Output(
            loss=loss,
            logits=output.sequence_logits,
            last_hidden_state=output.embeddings,
            sequence_logits=output.sequence_logits,
            structure_logits=output.structure_logits,
            secondary_structure_logits=output.secondary_structure_logits,
            sasa_logits=output.sasa_logits,
            function_logits=output.function_logits,
            residue_logits=output.residue_logits,
            embeddings=output.embeddings,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )