Spaces:
Sleeping
Sleeping
File size: 6,548 Bytes
bd082dc 5d09dff bd082dc a26c5b0 bd082dc f42fb15 58ac1b4 5d09dff bd082dc 6d736db bd082dc f42fb15 58ac1b4 bd082dc 58ac1b4 bd082dc f42fb15 a26c5b0 bd082dc bea5b09 296764f bea5b09 bd082dc 58ac1b4 047aecb 58ac1b4 047aecb 6d736db 58ac1b4 bd082dc f42fb15 a26c5b0 bd082dc bea5b09 bd082dc bea5b09 bd082dc e7732b8 bd082dc a26c5b0 bd082dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
# app.py
# --- IMPORTS ---
import os
import re
import esm
import gradio as gr
import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from protobind_diff.model import ModelGenerator
from protobind_diff.data_loader import InferenceDataset
from huggingface_hub import hf_hub_download
import spaces
from pathlib import Path
import uuid, json, hashlib
from huggingface_hub import CommitScheduler
from datetime import datetime
# Hugging Face Hub details
# Repository id passed to both hf_hub_download calls below (checkpoint and tokenizer).
REPO_ID = "ai-gero/ProtoBind-Diff"
# File name of the model checkpoint inside REPO_ID.
MODEL_FILENAME = "model.ckpt"
# File name of the SMILES tokenizer definition inside REPO_ID.
TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
# Short alias for gr.Request; used as the annotation of the handler's `request` parameter.
Request = gr.Request
@spaces.GPU(duration=120)
def generate_smiles_for_sequence(protein_sequence: str, num_samples: int, request: Request):
    """Run the full pipeline: log, validate, embed with ESM, sample SMILES.

    Args:
        protein_sequence: Raw text from the UI. It is upper-cased and every
            character outside ``A-Z`` is removed before use.
        num_samples: Requested number of generation attempts. It is
            floor-divided by 10 to get the number of batches of 10.
        request: The Gradio request object; only ``request.client.host`` is
            read, for the usage log.

    Returns:
        The distinct generated SMILES strings joined with ``",\\n"``, in the
        order they were first produced.

    Raises:
        gr.Error: If the sequence is empty, or shorter than 10 letters after
            cleaning.
    """
    # Logging stays ahead of validation so rejected submissions are recorded
    # too. Both the request and its client may be missing (e.g. a direct call),
    # so fall back to "unknown" instead of failing with AttributeError.
    client = getattr(request, "client", None)
    client_host = getattr(client, "host", None) or "unknown"
    log_run(client_host, protein_sequence or "")

    if not protein_sequence:
        raise gr.Error("Protein sequence cannot be empty.")

    protein_sequence = re.sub(r"[^A-Z]", "", protein_sequence.upper())
    if len(protein_sequence) < 10:
        raise gr.Error("Protein sequence is too short.")

    # Prefer the GPU, but keep the CPU branches below reachable on hosts
    # without CUDA instead of crashing on `.to("cuda")`.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    esm_model.to(device)
    protobind_model.to(device)

    with torch.no_grad():
        batch_converter = alphabet.get_batch_converter()
        _, _, tokens = batch_converter([("protein", protein_sequence)])
        tokens = tokens.to(device)
        # Layer-33 representations with the first and last positions sliced
        # off (presumably the BOS/EOS tokens -- confirm against the ESM docs).
        embedding = esm_model(tokens, repr_layers=[33])["representations"][33][:, 1:-1, :]
        embedding = embedding.float() if device == "cpu" else embedding.bfloat16()

    # int() guards against a float arriving from the API; batches hold 10 samples.
    n_batches = int(num_samples) // 10
    dataset = InferenceDataset(embedding, batch_size=10, n_batches=n_batches)
    # batch_size=None: the dataset already yields whole batches.
    loader = DataLoader(dataset, batch_size=None)

    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        logger=False,
        precision="16-mixed" if device == "cuda" else "32-true",
    )
    predictions_batches = trainer.predict(model=protobind_model, dataloaders=loader)

    # The first element of each prediction batch is iterated as the SMILES list.
    all_smiles = [smi for batch in predictions_batches for smi in batch[0]]
    # dict.fromkeys de-duplicates while keeping first-seen order, so the output
    # is stable for a given set of predictions (a set has arbitrary order).
    unique_smiles = list(dict.fromkeys(all_smiles))
    return ",\n".join(unique_smiles)
def enable_btn(seq: str):
    """Enable the submit button only when the input can pass validation.

    The text is normalised the same way the generation handler does
    (upper-cased, everything outside ``A-Z`` removed) before the length is
    checked, so whitespace, digits or punctuation do not count toward the
    10-letter minimum.

    Args:
        seq: Current content of the sequence textbox; ``None`` is treated as
            empty.

    Returns:
        A ``gr.update`` that sets the button's ``interactive`` flag.
    """
    cleaned = re.sub(r"[^A-Z]", "", (seq or "").upper())
    return gr.update(interactive=len(cleaned) >= 10)
def log_run(client_ip: str, seq: str):
    """Write one usage record as a uniquely named JSON file in ``LOG_FOLDER``.

    The record holds a UTC timestamp (ISO format, second precision, no
    offset suffix), the client address, the submitted sequence and its length.

    Args:
        client_ip: Address of the caller as reported by the request.
        seq: The sequence text exactly as submitted.
    """
    # Local import keeps this block self-contained; the file's top-level
    # import only brings in `datetime` itself.
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # "now" and drop tzinfo so the stored string keeps its existing format.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    rec = {
        "ts": now_utc.isoformat(timespec="seconds"),
        "client": client_ip,
        "seq": seq,
        "seq_len": len(seq),
    }
    payload = json.dumps(rec)

    # Hold the scheduler's lock while writing so a background commit never
    # picks up a partially written file.
    with scheduler.lock:
        (LOG_FOLDER / f"{uuid.uuid4()}.json").write_text(payload)
# Local directory that collects one JSON file per logged run.
LOG_FOLDER = Path("usage_logs")
LOG_FOLDER.mkdir(exist_ok=True)

# Background uploader that commits LOG_FOLDER to the usage dataset repo.
# `every` is the commit interval (minutes, per the huggingface_hub docs).
_usage_repo = {
    "repo_id": "ai-gero/protobind_usage",
    "repo_type": "dataset",
}
scheduler = CommitScheduler(
    folder_path=str(LOG_FOLDER),
    every=1,
    token=os.environ.get("HF_TOKEN"),
    **_usage_repo,
)
# --- GRADIO APP DEFINITION ---
# Load models on app startup
# Everything is first placed on the CPU at import time.
device = "cpu"

# ESM-2 encoder plus its alphabet, switched to eval mode.
esm_model, alphabet = esm.pretrained.load_model_and_alphabet("esm2_t33_650M_UR50D")
esm_model = esm_model.eval().to(device)

# Fetch the tokenizer first, then the checkpoint, from the same Hub repo.
tokenizer_path, ckpt_path = (
    hf_hub_download(repo_id=REPO_ID, filename=hub_file)
    for hub_file in (TOKENIZER_FILENAME, MODEL_FILENAME)
)

# Restore the generator from the checkpoint and put it in eval mode.
_load_options = dict(
    map_location=device,
    tokenizer_path=tokenizer_path,
    seq_embedding_dim=1280,
    load=True,
)
protobind_model = ModelGenerator.load_from_checkpoint(ckpt_path, **_load_options)
protobind_model.eval()
protobind_model.to(device)
# Define the UI
# NOTE(review): indentation was lost in this copy of the file, so the nesting of
# the layout containers below is not visible here -- confirm against the original.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
# Page header shown at the top of the app.
gr.Markdown(
"""
# ProtoBind-Diff: Protein-Conditioned Ligand Generation
This Space demonstrates **ProtoBind-Diff**, a diffusion model for generating novel drug-like molecules (ligands)
conditioned on a target protein sequence. Provide a protein's amino acid sequence to generate potential binding molecules in SMILES format.
"""
)
# Input controls: sequence textbox, attempts slider and the submit button.
with gr.Row():
with gr.Column(scale=2):
protein_sequence = gr.Textbox(
lines=10,
label="Protein Amino Acid Sequence",
placeholder="Enter your protein sequence here (e.g., MGY...)"
)
# Slider moves in steps of 10, matching the batch size of 10 used by the handler.
num_samples = gr.Slider(
minimum=10,
maximum=200,
value=50,
step=10,
label="Generation Attempts",
info=(
"Upper limit on generation attempts. Duplicates and invalid molecules "
"are discarded, so the final count of unique molecules may be lower. "
"More attempts increase runtime but can improve diversity."
)
)
# Starts disabled; enable_btn toggles it when the sequence text changes.
submit_btn = gr.Button("Generate Molecules", variant="primary", interactive=False)
# Output area for the generated SMILES text.
with gr.Column(scale=3):
output_smiles = gr.Textbox(
lines=15,
label="Generated SMILES",
info="A list of unique, valid SMILES strings generated for the target protein.",
interactive=True
)
# Two preset (sequence, attempts) pairs; results are not cached.
gr.Markdown("### Examples")
gr.Examples(
examples=[
["MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPGYRS",
50],
["MDILCEENTSLSSTTNSLMQLNDDTRLYSNDFNSGEANTSDAFNWTVDSENRTNLSCEGCLSPSCLSLLHLQEKNWSALLTAVVIILTIAGNILVIMAVSLEKKLQNATNYFLMSLAIADMLLGFLVMPVSMLTILYGYRWPLPSKLCAVWIYLDVLFSTASIMHLCAISLDRYVAIQNPIHHSRFNSRTKAFLKIIAVWTISVGISMPIPVFGLQDDSKVFKEGSCLLADDNFVLIGSFVSFFIPLTIMVITYFLTIKSLQKEATLCVSDLGTRAKLASFSFLPQSSLSSEKLFQRSIHREPGSYTGRRTMQSISNEQKACKVLGIVFFLFVVMWCPFFITNIMAVICKESCNEDVIGALLNVFVWIGYLSSAVNPLVYTLFNKTYRSAFSRYIQCQYKENKKPLQLILVNTIPALAYKSSQLQMGQKKNSKQDAKTTDNDCSMVALGKQHSEEASKDNSDGVNEKVSCV",
100]
],
inputs=[protein_sequence, num_samples],
cache_examples=False,
)
# Footer with attribution and a link to the source repository.
gr.Markdown(
"""
---
*Model developed by Gero AI. For more details, check out the [original repository](https://github.com/gero-science/ProtoBind-Diff).*
"""
)
# Event wiring: re-evaluate the button state on every text change, and run
# the generation handler when the button is clicked.
protein_sequence.change(enable_btn, inputs=protein_sequence, outputs=submit_btn)
submit_btn.click(
fn=generate_smiles_for_sequence,
inputs=[protein_sequence, num_samples],
outputs=output_smiles
)
# Launch the app
if __name__ == "__main__":
    # Enable the request queue (at most 10 waiting) and start the server,
    # asking for a public share link.
    demo.queue(max_size=10).launch(share=True)
|