File size: 7,507 Bytes
b3297e4
2898034
 
8a3cce7
2898034
 
 
 
 
c3ad370
4dcb469
 
ae38197
b3297e4
4dcb469
2898034
4dcb469
 
c767ebc
 
4dcb469
2898034
4dcb469
 
2898034
 
 
 
 
fdc5e1b
 
5f4bd43
c9a01db
 
 
fdc5e1b
 
 
 
 
 
2898034
fdc5e1b
 
 
 
 
2898034
fdc5e1b
 
2898034
 
4dcb469
 
ae38197
 
4dcb469
 
a749e8f
 
 
 
 
 
 
f0a8bfb
bd04feb
f0a8bfb
 
 
bd04feb
f0a8bfb
ae38197
 
 
 
 
 
 
 
4dcb469
ae38197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecb4e6c
0aa10ea
fdc5e1b
ae38197
 
 
 
 
 
 
 
 
 
 
f0a8bfb
ae38197
 
 
 
 
 
f0a8bfb
ae38197
f0a8bfb
ecb4e6c
f0a8bfb
4dcb469
2898034
ae38197
f0a8bfb
2898034
 
4dcb469
0aa10ea
 
 
 
 
ae38197
ecb4e6c
 
f0a8bfb
ecb4e6c
 
 
ad7ce30
 
 
 
4dcb469
2898034
 
 
 
b3297e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import gradio as gr
import torch
import numpy as np
from Bio import SeqIO
import tempfile
import os
import json
from pathlib import Path
import zipfile
import spaces
from utils.download_models import *
from utils.handle_files import parse_fasta_files
from utils.pipelines import generate_embeddings, full_embedding_pipeline, full_ppl_pipeline

print("Downloading ESM2 models...")

# Hugging Face checkpoint id -> human-readable label used in the UI.
# Sizes: 8M, 35M and 650M parameters; larger = better PPL/embeddings, slower.
MODELS = {
    "facebook/esm2_t6_8M_UR50D": "ESM2-8M",
    "facebook/esm2_t12_35M_UR50D": "ESM2-35M",
    "facebook/esm2_t33_650M_UR50D": "ESM2-650M"
}

# Cache weights and load every (model, tokenizer) pair up front so the dropdown
# can switch models without reloading. cache_all_models / load_all_models
# presumably come from utils.download_models (star import) — TODO confirm.
cache_dirs = cache_all_models(MODELS)
models_and_tokenizers = load_all_models(MODELS)


# Create Gradio interface. The intro Markdown documents the two supported
# tasks (embedding generation and pseudo-perplexity scoring) for end users.
# Fixes vs. original: subject-verb agreement ("layers create"), a garbled
# "to visualize to identify" phrase, and a comma splice in the PPL note.
with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
    gr.Markdown("""
    # ESM2 for candidate sequence filtering 🤖

    Once one has generated de novo protein sequences using a tool like LigandMPNN, one must rank them to select promising candidates for experimental validation. One powerful approach is to use <a href="https://www.science.org/doi/10.1126/science.ade2574" target="_blank">protein language models like Meta's ESM2.</a>
    These language models rely on a BERT-like architecture and a Masked Language Modeling (MLM) objective to learn rich representations of protein sequences. Note that this Space pairs well with the companion <a href="https://huggingface.co/spaces/hugging-science/RFdiffusion3" target="_blank">RFdiffusion3</a>, <a href="https://huggingface.co/spaces/hugging-science/LigandMPNN" target="_blank">LigandMPNN</a> and RosettaFold3 Spaces for a full de novo design pipeline!

    ESM is used for two main purposes:
    1. **Generating embeddings**: ESM's hidden layers create high-dimensional representations of protein sequences that capture structural and functional information. 
        These embeddings can be used as input features for downstream machine learning models to predict function, properties or even for folding.
        Embeddings can also be used with dimensionality reduction techniques like t-SNE to visualize sequences, identify clusters, or compare against known proteins.
    2. **Calculating pseudo-perplexity scores (PPL)**: The lower this score is for a given input sequence, the more "natural" or "plausible" it is according to the model's learned distribution.
        Such scores are often used as a filtering criterion in de novo design, as sequences with lower PPL are more likely to express properly in the lab and fold into stable structures.
        PPL scores provide an orthogonal evaluation metric to structure-based methods like RosettaFold.
    
    ## How to use this Space:
    - **Choose the ESM2 model:** models mainly differ by the number of parameters (8M, 35M, 650M). Larger models produce better PPL scores and richer embeddings but have longer runtimes.
    - **Upload one or more FASTA files** containing your candidate sequences.
    - **Choose the batch size:** it controls how many sequences are processed together. Larger batch sizes can speed up processing but require more GPU memory.
    - **Choose between generating embeddings or calculating pseudo-perplexity scores.**
    
    Note that calculating PPL scores is much more computationally intensive than generating embeddings: it scales cubically with sequence length $L$. This is because calculating PPL requires $L$ forward passes through the model, each with a different token masked out.
    For long sequences or large numbers of sequences, we recommend using the approximate PPL calculation, which masks 10% of tokens at a time and thus only scales quadratically with sequence length. This provides a good tradeoff between accuracy and runtime.
    """)
    
    # Top control row: model selector, batch size, and runtime budget.
    # NOTE: component variable names are referenced by the event bindings below;
    # creation order determines on-screen layout.
    with gr.Row():
        # Dropdown keyed by HF model id; labels come from the MODELS mapping.
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value=list(MODELS.keys())[0],  # default: first entry (the 8M model)
            label="Select Model"
        )
        # Sequences processed per forward pass; larger needs more GPU memory.
        batch_size = gr.Slider(
            minimum=1,
            maximum=128,
            step=1,
            value=32,
            label="Batch Size"
        )
        # Wall-clock budget in whole seconds, forwarded to the pipelines.
        max_duration = gr.Number(
            value=300,
            label="Max Duration (seconds)",
            precision=0,
            minimum=1,
            maximum=3600
        )
    
    # Main row: FASTA upload on the left, task tabs (embeddings / PPL) on the right.
    with gr.Row():
        with gr.Column():
            # Accepts multiple FASTA files; extensions are filtered client-side.
            input_files = gr.File(
                label="Upload FASTA files",
                file_count="multiple",
                file_types=[".fasta", ".fa", ".faa"]
            )

        
        with gr.Column():
            with gr.Tabs():
                # Tab 1: embedding generation — one button, status box, downloads.
                with gr.TabItem("Generate Embeddings"):
                    submit_btn = gr.Button("Generate Embeddings", variant="primary", size="lg")
                    status_output = gr.Textbox(
                        label="Waiting for embeddings generation...",
                        interactive=False,
                        lines=6
                    )

                    download_output = gr.File(
                        label="Download Output Files",
                        file_count="multiple"
                    )
                # Tab 2: pseudo-perplexity — exact and approximate variants share
                # the same status box and download component.
                with gr.TabItem("Calculate Pseudo-Perplexity scores"):
                    with gr.Row():
                        ppl_button = gr.Button("Calculate Exact PPL", variant="primary", size="lg")
                        ppl_approx_button = gr.Button("Calculate Approximate PPL", variant="primary", size="lg")
                    ppl_status = gr.Textbox(
                        label="Waiting for pseudo-perplexity calculation...",
                        interactive=False,
                        lines=6
                    )
                    ppl_download = gr.File(
                        label="Download Pseudo-Perplexity Output",
                        file_count="multiple"
                    )

    def run_pipeline_with_selected_model(fasta_files, model_key, batch_size_value, max_duration, task="embedding"):
        """Dispatch the selected pipeline (embeddings or PPL) for the chosen model.

        Args:
            fasta_files: Uploaded FASTA file objects from the ``gr.File`` input.
            model_key: Hugging Face model id selected in the dropdown.
            batch_size_value: Number of sequences processed per batch.
            max_duration: Wall-clock budget in seconds for the pipeline run.
            task: One of ``"embedding"``, ``"ppl"`` (exact, every token masked
                once) or ``"ppl-approx"`` (10% of tokens masked per pass).

        Returns:
            A ``(files, status_message)`` pair matching the two Gradio outputs.

        Raises:
            ValueError: If ``task`` is not one of the three supported values.
        """
        if not fasta_files:
            # Leave the file output untouched; surface the problem in the status box.
            return gr.update(), "No FASTA files uploaded. Please upload at least one FASTA file for inference."
        model, tokenizer = models_and_tokenizers[model_key]
        if task == "embedding":
            return full_embedding_pipeline(fasta_files, model, tokenizer, batch_size_value, max_duration)
        if task == "ppl":
            # mask_fraction=None selects the exact (one-mask-per-token) PPL.
            return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, None, max_duration)
        if task == "ppl-approx":
            # 0.1 = mask 10% of tokens per forward pass (approximate PPL).
            return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, 0.1, max_duration)
        # Previously an unknown task fell through and returned None, silently
        # breaking the two-output Gradio contract; fail loudly instead.
        raise ValueError(f"Unknown task: {task!r}")

    # Wire each action button to the shared pipeline wrapper. The gr.State
    # value carried in the inputs selects which task the wrapper dispatches to;
    # both PPL variants share one status box and one download component.
    for _button, _task, _outputs in (
        (submit_btn, "embedding", [download_output, status_output]),
        (ppl_button, "ppl", [ppl_download, ppl_status]),
        (ppl_approx_button, "ppl-approx", [ppl_download, ppl_status]),
    ):
        _button.click(
            fn=run_pipeline_with_selected_model,
            inputs=[input_files, model_dropdown, batch_size, max_duration, gr.State(_task)],
            outputs=_outputs,
        )

    gr.Markdown(""" 
                <u>Citation:</u> Zeming Lin et al. ,Evolutionary-scale prediction of atomic-level protein structure with a language model.Science379,1123-1130(2023).DOI:10.1126/science.ade2574
                """)




# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()