# ESM2 Space — app.py
# (Hugging Face Hub page chrome removed: author "gabboud", commit bd04feb,
#  "change default max runtime to 300 and maximum to 1h".)
import gradio as gr
import torch
import numpy as np
from Bio import SeqIO
import tempfile
import os
import json
from pathlib import Path
import zipfile
import spaces
from utils.download_models import *
from utils.handle_files import parse_fasta_files
from utils.pipelines import generate_embeddings, full_embedding_pipeline, full_ppl_pipeline
# Hugging Face model IDs mapped to short display names; the three ESM2
# checkpoints differ only in parameter count (8M / 35M / 650M).
MODELS = {
    "facebook/esm2_t6_8M_UR50D": "ESM2-8M",
    "facebook/esm2_t12_35M_UR50D": "ESM2-35M",
    "facebook/esm2_t33_650M_UR50D": "ESM2-650M",
}

print("Downloading ESM2 models...")
# Pre-fetch every checkpoint into the local cache, then load all
# (model, tokenizer) pairs up front so the UI can serve any dropdown choice.
cache_dirs = cache_all_models(MODELS)
models_and_tokenizers = load_all_models(MODELS)
# Create Gradio interface: model/batch controls, FASTA upload, and two task
# tabs (embedding generation and pseudo-perplexity scoring).
with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
    gr.Markdown("""
# ESM2 for candidate sequence filtering 🤖
Once one has generated de novo protein sequences using a tool like LigandMPNN, one must rank them to select promising candidates for experimental validation. One powerful approach is to use <a href="https://www.science.org/doi/10.1126/science.ade2574" target="_blank">protein language models like Meta's ESM2.</a>
These language models rely on a BERT-like architecture and a Masked Language Modeling (MLM) objective to learn rich representations of protein sequences. Note that this Space pairs well with the companion <a href="https://huggingface.co/spaces/hugging-science/RFdiffusion3" target="_blank">RFdiffusion3</a>, <a href="https://huggingface.co/spaces/hugging-science/LigandMPNN" target="_blank">LigandMPNN</a> and RosettaFold3 Spaces for a full de novo design pipeline!
ESM is used for two main purposes:
1. **Generating embeddings**: ESM's hidden layers creates high-dimensional representations of protein sequences that capture structural and functional information.
These embeddings can be used as input features for downstream machine learning models to predict function, properties or even for folding.
Embeddings can also be used with dimensionality reduction techniques like t-SNE to visualize to identify clusters or compare against known proteins.
2. **Calculating pseudo-perplexity scores (PPL)**: The lower this score is for a given input sequence, the more "natural" or "plausible" it is according to the model's learned distribution.
Such scores are often used as a filtering criterion in de novo design, as sequences with lower PPL are more likely to express properly in the lab and fold into stable structures.
PPL scores provide an orthogonal evaluation metric to structure-based methods like RosettaFold.
## How to use this Space:
- **Choose the ESM2 model:** models mainly differ by the number of parameters (8M, 35M, 650M). Larger models produce better PPL scores and richer embeddings but have longer runtimes.
- **Upload one or more FASTA files** containing your candidate sequences.
- **Choose the batch size:** it controls how many sequences are processed together. Larger batch sizes can speed up processing but require more GPU memory.
- **Choose between generating embeddings or calculating pseudo-perplexity scores.**
Note that calculating PPL scores is much more computationally intensive than generating embeddings, it scales cubically with sequence length $L$. This is because calculating PPL requires $L$ forward passes through the model, each with a different token masked out.
For long sequences or large numbers of sequences, we recommend using the approximate PPL calculation, which masks 10% of tokens at a time and thus only scales quadratically with sequence length. This provides a good tradeoff between accuracy and runtime.
""")

    # --- Input controls -------------------------------------------------
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value=list(MODELS.keys())[0],
            label="Select Model"
        )
        batch_size = gr.Slider(
            minimum=1,
            maximum=128,
            step=1,
            value=32,
            label="Batch Size"
        )
        # Wall-clock budget forwarded to the pipelines: default 5 min,
        # hard cap 1 hour.
        max_duration = gr.Number(
            value=300,
            label="Max Duration (seconds)",
            precision=0,
            minimum=1,
            maximum=3600
        )

    # --- File upload and the two task tabs ------------------------------
    with gr.Row():
        with gr.Column():
            input_files = gr.File(
                label="Upload FASTA files",
                file_count="multiple",
                file_types=[".fasta", ".fa", ".faa"]
            )
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Generate Embeddings"):
                    submit_btn = gr.Button("Generate Embeddings", variant="primary", size="lg")
                    status_output = gr.Textbox(
                        label="Waiting for embeddings generation...",
                        interactive=False,
                        lines=6
                    )
                    download_output = gr.File(
                        label="Download Output Files",
                        file_count="multiple"
                    )
                with gr.TabItem("Calculate Pseudo-Perplexity scores"):
                    with gr.Row():
                        ppl_button = gr.Button("Calculate Exact PPL", variant="primary", size="lg")
                        ppl_approx_button = gr.Button("Calculate Approximate PPL", variant="primary", size="lg")
                    ppl_status = gr.Textbox(
                        label="Waiting for pseudo-perplexity calculation...",
                        interactive=False,
                        lines=6
                    )
                    ppl_download = gr.File(
                        label="Download Pseudo-Perplexity Output",
                        file_count="multiple"
                    )

    def run_pipeline_with_selected_model(fasta_files, model_key, batch_size_value, max_duration, task="embedding"):
        """Dispatch to the embedding or PPL pipeline for the selected model.

        Args:
            fasta_files: Uploaded FASTA file objects from the gr.File input
                (falsy when nothing was uploaded).
            model_key: Hugging Face model ID chosen in the dropdown.
            batch_size_value: Number of sequences processed per batch.
            max_duration: Wall-clock limit in seconds passed to the pipeline.
            task: One of "embedding", "ppl" (exact), or "ppl-approx".

        Returns:
            A (files, status text) tuple matching the Gradio outputs.

        Raises:
            ValueError: If ``task`` is not one of the recognized values.
        """
        if not fasta_files:
            # Leave the file output untouched; report the problem in the status box.
            return gr.update(), "No FASTA files uploaded. Please upload at least one FASTA file for inference."
        model, tokenizer = models_and_tokenizers[model_key]
        if task == "embedding":
            return full_embedding_pipeline(fasta_files, model, tokenizer, batch_size_value, max_duration)
        if task == "ppl":
            # None for the masking fraction selects exact PPL — presumably one
            # forward pass per token (see the description above); confirm in
            # utils.pipelines.full_ppl_pipeline.
            return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, None, max_duration)
        if task == "ppl-approx":
            # Approximate PPL: mask 10% of tokens per forward pass.
            return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, 0.1, max_duration)
        # Previously an unrecognized task fell through and returned None, which
        # Gradio surfaces as a confusing output-count error; fail loudly instead.
        raise ValueError(f"Unknown task: {task!r}")

    # Wire each button to the shared dispatcher; gr.State pins the task kind
    # per button.
    submit_btn.click(
        fn=run_pipeline_with_selected_model,
        inputs=[input_files, model_dropdown, batch_size, max_duration, gr.State("embedding")],
        outputs=[download_output, status_output]
    )
    ppl_button.click(
        fn=run_pipeline_with_selected_model,
        inputs=[input_files, model_dropdown, batch_size, max_duration, gr.State("ppl")],
        outputs=[ppl_download, ppl_status]
    )
    ppl_approx_button.click(
        fn=run_pipeline_with_selected_model,
        inputs=[input_files, model_dropdown, batch_size, max_duration, gr.State("ppl-approx")],
        outputs=[ppl_download, ppl_status]
    )

    gr.Markdown("""
<u>Citation:</u> Zeming Lin et al. ,Evolutionary-scale prediction of atomic-level protein structure with a language model.Science379,1123-1130(2023).DOI:10.1126/science.ade2574
""")

if __name__ == "__main__":
    demo.launch()