File size: 6,548 Bytes
bd082dc
 
5d09dff
bd082dc
a26c5b0
bd082dc
 
 
 
 
 
 
f42fb15
58ac1b4
 
 
5d09dff
 
bd082dc
 
 
 
 
6d736db
bd082dc
f42fb15
58ac1b4
bd082dc
 
 
58ac1b4
bd082dc
 
 
 
 
 
f42fb15
 
 
a26c5b0
 
 
 
 
 
 
 
bd082dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bea5b09
296764f
bea5b09
bd082dc
58ac1b4
 
 
 
047aecb
58ac1b4
 
 
 
 
 
 
 
 
 
 
047aecb
6d736db
58ac1b4
bd082dc
 
 
f42fb15
a26c5b0
 
 
 
bd082dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bea5b09
bd082dc
 
 
 
 
 
 
 
 
bea5b09
bd082dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7732b8
bd082dc
 
 
 
 
 
 
 
 
a26c5b0
bd082dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# app.py
# --- IMPORTS ---
import os
import re
import esm
import gradio as gr
import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from protobind_diff.model import ModelGenerator
from protobind_diff.data_loader import InferenceDataset
from huggingface_hub import hf_hub_download
import spaces
from pathlib import Path
import uuid, json, hashlib
from huggingface_hub import CommitScheduler
from datetime import datetime


# Hugging Face Hub details
# Repository the checkpoint and tokenizer files are downloaded from at startup.
REPO_ID = "ai-gero/ProtoBind-Diff"
# File names inside REPO_ID passed to hf_hub_download.
MODEL_FILENAME = "model.ckpt"
TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
# Alias used as the type annotation of the `request` parameter of the
# generation handler.
Request = gr.Request

@spaces.GPU(duration=120)
def generate_smiles_for_sequence(protein_sequence: str, num_samples: int, request: Request):
    """
    Generate candidate ligand SMILES strings for one protein sequence.

    The call is logged, the sequence is sanitised and embedded with the ESM
    model (layer-33 representations), and the embedding is passed to the
    ProtoBind-Diff model through a Lightning ``Trainer.predict`` run.

    Args:
        protein_sequence: Raw text from the UI. It is upper-cased and every
            character other than ``A``-``Z`` is removed before use.
        num_samples: Requested number of generation attempts. It is coerced
            to an integer and rounded down to a multiple of the batch size
            (10), with a minimum of one batch.
        request: Gradio request object; only used to record the client host.

    Returns:
        The unique generated SMILES strings in first-seen order, joined by
        a comma followed by a newline.

    Raises:
        gr.Error: If the sequence is empty, or has fewer than 10 letters
            after sanitisation.
    """
    # `request` or `request.client` can be missing; never let the usage log
    # crash the handler because of that.
    client = getattr(request, "client", None)
    log_run(getattr(client, "host", None) or "unknown", protein_sequence or "")

    if not protein_sequence:
        raise gr.Error("Protein sequence cannot be empty.")
    protein_sequence = re.sub(r"[^A-Z]", "", protein_sequence.upper())
    if len(protein_sequence) < 10:
        raise gr.Error("Protein sequence is too short.")

    # Prefer the GPU, but fall back to CPU so the dtype/precision branches
    # below are reachable and the function also works on GPU-less hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    esm_model.to(device)
    protobind_model.to(device)

    with torch.no_grad():
        batch_converter = alphabet.get_batch_converter()
        _, _, tokens = batch_converter([("protein", protein_sequence)])
        tokens = tokens.to(device)
        representations = esm_model(tokens, repr_layers=[33])["representations"]
        # Drop the first and last token positions (presumably the BOS/EOS
        # tokens added by the batch converter -- confirm against ESM docs).
        embedding = representations[33][:, 1:-1, :]
        embedding = embedding.float() if device == "cpu" else embedding.bfloat16()

    # The slider may deliver a float; force an integer batch count and always
    # run at least one batch.
    batch_size = 10
    n_batches = max(1, int(num_samples) // batch_size)
    dataset = InferenceDataset(embedding, batch_size=batch_size, n_batches=n_batches)
    loader = DataLoader(dataset, batch_size=None)

    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        logger=False,
        precision="16-mixed" if device == "cuda" else "32-true",
    )

    # Guard against a None result so the flattening below cannot fail.
    predictions_batches = trainer.predict(model=protobind_model, dataloaders=loader) or []

    # dict.fromkeys de-duplicates while keeping generation order, so the
    # output is deterministic for a given set of predictions.
    unique_smiles = list(
        dict.fromkeys(smi for batch in predictions_batches for smi in batch[0])
    )

    return ",\n".join(unique_smiles)

def enable_btn(seq: str):
    """
    Enable the submit button only when the sequence would pass validation.

    The text is normalised the same way the generation handler normalises it
    (upper-cased, every character other than ``A``-``Z`` removed), so the
    button is not enabled for input that would immediately be rejected as
    too short. ``None`` is treated as an empty string.

    Args:
        seq: Current content of the protein sequence textbox.

    Returns:
        A ``gr.update`` setting ``interactive`` to True when at least 10
        letters remain after normalisation, otherwise False.
    """
    letters = re.sub(r"[^A-Z]", "", (seq or "").upper())
    return gr.update(interactive=len(letters) >= 10)


def log_run(client_ip: str, seq: str):
    """
    Write one usage record as a uniquely named JSON file in ``LOG_FOLDER``.

    Args:
        client_ip: Host string identifying the caller.
        seq: The protein sequence exactly as it was submitted.

    The record holds a UTC timestamp (second resolution, no offset suffix),
    the client, the sequence and its length.
    """
    # Local import: the module only imports `datetime` from the datetime package.
    from datetime import timezone

    # Same naive-UTC ISO string that the deprecated datetime.utcnow() produced.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None)
    rec = {
        "ts": timestamp.isoformat(timespec="seconds"),
        "client": client_ip,
        "seq": seq,
        "seq_len": len(seq),
    }
    payload = json.dumps(rec)
    # Hold the scheduler lock while writing so a background commit cannot
    # pick up a half-written file.
    with scheduler.lock:
        (LOG_FOLDER / f"{uuid.uuid4()}.json").write_text(payload, encoding="utf-8")


# Local directory that receives one JSON file per logged run; created on
# import if it does not exist yet.
LOG_FOLDER = Path("usage_logs"); LOG_FOLDER.mkdir(exist_ok=True)

# Background uploader that pushes the contents of LOG_FOLDER to a dataset
# repository on the Hub.
# NOTE(review): `every` is documented by huggingface_hub as minutes -- confirm.
# The token is read from the HF_TOKEN environment variable (None if unset).
scheduler = CommitScheduler(
    repo_id="ai-gero/protobind_usage",
    repo_type="dataset",
    folder_path=str(LOG_FOLDER),
    every=1,
    token=os.getenv("HF_TOKEN")
)
# --- GRADIO APP DEFINITION ---

# Load models on app startup
# Both models are created on the CPU at import time; the generation handler
# moves them to its own device when it runs.
device = "cpu"
# ESM protein language model plus its alphabet (used to build the batch converter).
esm_model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t33_650M_UR50D')
esm_model.eval()
esm_model = esm_model.to(device)

# Fetch the tokenizer and checkpoint files from the Hub; hf_hub_download
# returns the local path of each file.
tokenizer_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=TOKENIZER_FILENAME,
)
ckpt_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILENAME,
)
# Restore the ProtoBind-Diff generator from the checkpoint.
# NOTE(review): seq_embedding_dim=1280 presumably matches the width of the ESM
# embeddings, and the meaning of `load=True` is defined by ModelGenerator --
# confirm both against protobind_diff.model.
protobind_model = ModelGenerator.load_from_checkpoint(
    ckpt_path,
    map_location=device,
    tokenizer_path=tokenizer_path,
    seq_embedding_dim=1280,
    load=True,
)
protobind_model.eval()
protobind_model.to(device)
# Define the UI
# Build the Gradio interface: a header, an input column (sequence + sample
# count + submit button), an output column, clickable examples and a footer.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown(
        """
        # ProtoBind-Diff: Protein-Conditioned Ligand Generation
        This Space demonstrates **ProtoBind-Diff**, a diffusion model for generating novel drug-like molecules (ligands)
        conditioned on a target protein sequence. Provide a protein's amino acid sequence to generate potential binding molecules in SMILES format.
        """
    )

    with gr.Row():
        # Left column: user inputs.
        with gr.Column(scale=2):
            protein_sequence = gr.Textbox(
                lines=10,
                label="Protein Amino Acid Sequence",
                placeholder="Enter your protein sequence here (e.g., MGY...)"
            )
            # Slider moves in steps of 10 between 10 and 200.
            num_samples = gr.Slider(
                minimum=10,
                maximum=200,
                value=50,
                step=10,
                label="Generation Attempts",
                info=(
                    "Upper limit on generation attempts. Duplicates and invalid molecules "
                    "are discarded, so the final count of unique molecules may be lower. "
                    "More attempts increase runtime but can improve diversity."
                )
            )
            # Starts disabled; enable_btn toggles it when the sequence text changes.
            submit_btn = gr.Button("Generate Molecules", variant="primary", interactive=False)

        # Right column: generated output.
        with gr.Column(scale=3):
            output_smiles = gr.Textbox(
                lines=15,
                label="Generated SMILES",
                info="A list of unique, valid SMILES strings generated for the target protein.",
                interactive=True
            )


    # Example inputs: each entry is [protein sequence, number of attempts].
    # cache_examples=False, so example outputs are not cached.
    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            ["MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPGYRS",
             50],
            ["MDILCEENTSLSSTTNSLMQLNDDTRLYSNDFNSGEANTSDAFNWTVDSENRTNLSCEGCLSPSCLSLLHLQEKNWSALLTAVVIILTIAGNILVIMAVSLEKKLQNATNYFLMSLAIADMLLGFLVMPVSMLTILYGYRWPLPSKLCAVWIYLDVLFSTASIMHLCAISLDRYVAIQNPIHHSRFNSRTKAFLKIIAVWTISVGISMPIPVFGLQDDSKVFKEGSCLLADDNFVLIGSFVSFFIPLTIMVITYFLTIKSLQKEATLCVSDLGTRAKLASFSFLPQSSLSSEKLFQRSIHREPGSYTGRRTMQSISNEQKACKVLGIVFFLFVVMWCPFFITNIMAVICKESCNEDVIGALLNVFVWIGYLSSAVNPLVYTLFNKTYRSAFSRYIQCQYKENKKPLQLILVNTIPALAYKSSQLQMGQKKNSKQDAKTTDNDCSMVALGKQHSEEASKDNSDGVNEKVSCV",
             100]
        ],
        inputs=[protein_sequence, num_samples],
        cache_examples=False,
    )

    # Footer with attribution and a link to the source repository.
    gr.Markdown(
        """
        ---
        *Model developed by Gero AI. For more details, check out the [original repository](https://github.com/gero-science/ProtoBind-Diff).*
        """
    )
    # Re-evaluate the submit button state whenever the sequence text changes.
    protein_sequence.change(enable_btn, inputs=protein_sequence, outputs=submit_btn)

    # Run the generation pipeline on click and show the result in the output box.
    submit_btn.click(
        fn=generate_smiles_for_sequence,
        inputs=[protein_sequence, num_samples],
        outputs=output_smiles
    )

# Launch the app
if __name__ == "__main__":
    # Configure the request queue (max_size=10) and start the server with
    # share=True in a single chained call.
    demo.queue(max_size=10).launch(share=True)