---
library_name: transformers
tags: []
---

# FastESM

A faster half-precision version of ESM2-650 that leverages FlashAttention2

FastESM is a fully Hugging Face compatible version of ESM2, rewritten with PyTorch's newer scaled dot-product attention (SDPA), which dispatches to FlashAttention2 when possible.

To produce the FastESM weights, we trained ESM2-650 for 50,000 additional steps in fp16 mixed precision on OMG50 at sequence lengths up to 2048.

Outputting attention maps and predicting contacts are not possible with SDPA. Various other optimizations also make the base implementation slightly different from the HF one.

## Use with 🤗 transformers

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

sequence = 'MSEQWENCE'
tokenized = tokenizer(sequence, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape) # (1, 11, 1280)
```
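If you only need one vector per protein, a common option is to mean-pool the residue embeddings with the attention mask (the `embed_dataset` helper below does this for you when `pooling_type='mean'`). A minimal sketch, reusing `tokenized` and `embeddings` from the example above:

```python
# Sketch: mean-pool residue embeddings into one per-protein vector,
# masking out padding tokens.
mask = tokenized['attention_mask'].unsqueeze(-1)            # (1, L, 1)
pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1)   # (1, 1280)
print(pooled.shape)
```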

## Embed entire datasets with no new code

To embed a list of protein sequences quickly, just call `embed_dataset`. Sequences are sorted by length to reduce padding tokens, so the progress bar's initial estimate is usually much longer than the actual runtime.

```python
# Return a dictionary of sequences and embeddings
embeddings = model.embed_dataset(
    sequences=sequences, # list of protein strings
    batch_size=16, # embedding batch size
    max_len=2048, # truncate to max_len
    full_embeddings=True, # return residue-wise embeddings
    full_precision=False, # store embeddings as float32 if True
    pooling_type='mean', # use mean pooling if protein-wise embeddings
    num_workers=0, # data loading num workers
    sql=False, # return dictionary of sequences and embeddings
)

# Store the results in a local SQL database instead of returning them
_ = model.embed_dataset(
    sequences=sequences, # list of protein strings
    batch_size=16, # embedding batch size
    max_len=2048, # truncate to max_len
    full_embeddings=True, # return residue-wise embeddings
    full_precision=False, # store embeddings as float32 if True
    pooling_type='mean', # use mean pooling if protein-wise embeddings
    num_workers=0, # data loading num workers
    sql=True, # store sequences in local SQL database
    sql_db_path='embeddings.db', # path to .db file of choice
)
```

## Comparison of half precisions

Presumably because we trained in fp16 mixed precision, fp16 outputs are closer to the fp32 weights than bf16 outputs. Therefore, we recommend loading in fp16.

When summing the MSE of 1000 sequences vs. the fp32 weights:

- Average MSE for FP16: 0.00000140
- Average MSE for BF16: 0.00004125
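A minimal sketch of how such a comparison can be reproduced (our exact evaluation script and sequence set are not included here; the sequences below are placeholders):

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'Synthyra/FastESM2_650'
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # half precision is intended for GPU

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
fp32_model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True).eval().to(device)
fp16_model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval().to(device)

sequences = ['MSEQWENCE', 'MPROTEIN']  # replace with your own evaluation set
total_mse = 0.0
with torch.no_grad():
    for seq in sequences:
        inputs = tokenizer(seq, return_tensors='pt').to(device)
        ref = fp32_model(**inputs).last_hidden_state
        out = fp16_model(**inputs).last_hidden_state.float()
        total_mse += torch.mean((ref - out) ** 2).item()

print(f'Average MSE for FP16: {total_mse / len(sequences):.8f}')
```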

## FlashAttention2

Requires PyTorch 2.5+ for the most savings; see the PyTorch SDPA documentation.
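To make sure SDPA actually dispatches to FlashAttention2 (a CUDA device and half-precision inputs are required), you can restrict the backend yourself. A minimal sketch using PyTorch's `sdpa_kernel` context manager (available in PyTorch 2.3+), reusing `model` and `tokenized` from the example above:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

model = model.cuda()              # FlashAttention2 only runs on CUDA devices
tokenized = tokenized.to('cuda')

# Restrict SDPA to the FlashAttention backend; the forward pass will error
# instead of silently falling back if FlashAttention cannot be used here.
with torch.no_grad(), sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    embeddings = model(**tokenized).last_hidden_state
```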

## Citation