# ProtoBind-Diff / app.py
# vladimir.manuylov
# fixed the bug
# e7732b8
# app.py
# --- IMPORTS ---
import os
import re
import esm
import gradio as gr
import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from protobind_diff.model import ModelGenerator
from protobind_diff.data_loader import InferenceDataset
from huggingface_hub import hf_hub_download
import spaces
from pathlib import Path
import uuid, json, hashlib
from huggingface_hub import CommitScheduler
from datetime import datetime
# Hugging Face Hub details
# Repository the model checkpoint and tokenizer are downloaded from at startup.
REPO_ID = "ai-gero/ProtoBind-Diff"
# File names looked up inside REPO_ID via hf_hub_download.
MODEL_FILENAME = "model.ckpt"
TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
# Short alias used to annotate the request parameter of the Gradio handler.
Request = gr.Request
@spaces.GPU(duration=120)
def generate_smiles_for_sequence(protein_sequence: str, num_samples: int, request: Request):
    """Run the full pipeline: log, validate, embed with ESM, sample SMILES.

    Args:
        protein_sequence: Raw text from the UI. It is upper-cased and every
            character outside ``A-Z`` is removed before use.
        num_samples: Requested number of generation attempts. Sampling runs
            in batches of 10, so ``num_samples // 10`` batches are generated.
        request: Gradio request object, used only to record the client
            address in the usage log. It may be ``None`` when the function
            is invoked outside a normal UI event; that case is logged as
            ``"unknown"``.

    Returns:
        The unique generated SMILES strings, in first-seen order, joined by
        a comma followed by a newline.

    Raises:
        gr.Error: If the sequence is empty, or shorter than 10 letters after
            sanitization.
    """
    # Gradio can pass request=None (and client may be missing), so never
    # dereference it blindly just to write a log record.
    client = getattr(request, "client", None)
    client_host = getattr(client, "host", None) or "unknown"
    # The raw (unsanitized) input is logged, before validation, so rejected
    # submissions are recorded too.
    log_run(client_host, protein_sequence or "")

    if not protein_sequence:
        raise gr.Error("Protein sequence cannot be empty.")
    protein_sequence = re.sub(r"[^A-Z]", "", protein_sequence.upper())
    if len(protein_sequence) < 10:
        raise gr.Error("Protein sequence is too short.")

    # Models are loaded on CPU at import time and moved to the GPU only here,
    # inside the @spaces.GPU-decorated call.
    device = "cuda"
    esm_model.to(device)
    protobind_model.to(device)

    with torch.no_grad():
        batch_converter = alphabet.get_batch_converter()
        _, _, tokens = batch_converter([("protein", protein_sequence)])
        tokens = tokens.to(device)
        # Layer-33 representations with the first and last positions cut off
        # (presumably the BOS/EOS tokens added by the ESM alphabet -- confirm).
        embedding = esm_model(tokens, repr_layers=[33])["representations"][33][:, 1:-1, :]
        embedding = embedding.float() if device == "cpu" else embedding.bfloat16()

    # Slider values can arrive as floats; the batch count must be an int.
    n_batches = int(num_samples) // 10
    dataset = InferenceDataset(embedding, batch_size=10, n_batches=n_batches)
    # batch_size=None: the dataset already yields whole batches.
    loader = DataLoader(dataset, batch_size=None)
    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        logger=False,
        precision="16-mixed" if device == "cuda" else "32-true",
    )
    predictions_batches = trainer.predict(model=protobind_model, dataloaders=loader)

    # Each prediction batch carries its SMILES list at index 0.
    all_smiles = [smi for batch in (predictions_batches or []) for smi in batch[0]]
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # output is deterministic (set() ordering varies between runs).
    unique_smiles = list(dict.fromkeys(all_smiles))
    return ",\n".join(unique_smiles)
def enable_btn(seq: str):
    """Enable the submit button only for input that can pass validation.

    The generation handler upper-cases the text, strips every character
    outside ``A-Z`` and rejects results shorter than 10 letters. The same
    clean-up is applied here so the button is not enabled for input (digits,
    whitespace, punctuation) that the handler would reject as too short.

    Args:
        seq: Current textbox content; ``None`` is treated as empty.

    Returns:
        A Gradio update setting the button's ``interactive`` flag.
    """
    cleaned = re.sub(r"[^A-Z]", "", (seq or "").upper())
    return gr.update(interactive=len(cleaned) >= 10)
def log_run(client_ip: str, seq: str):
    """Write one usage record as a uniquely named JSON file in LOG_FOLDER.

    Args:
        client_ip: Client address to record.
        seq: Submitted protein sequence, stored verbatim together with its
            length; ``None`` is recorded as an empty string.

    The record holds a UTC timestamp (``ts``, second resolution, no offset
    suffix), ``client``, ``seq`` and ``seq_len``.
    """
    # Local import: the file-level import only brings in `datetime`.
    from datetime import timezone

    seq = seq or ""
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # "now" and drop tzinfo so the serialized timestamp keeps its format.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    rec = {
        "ts": now_utc.isoformat(timespec="seconds"),
        "client": client_ip,
        "seq": seq,
        "seq_len": len(seq),
    }
    # One file per record (random UUID name), so concurrent requests never
    # write to the same path.
    (LOG_FOLDER / f"{uuid.uuid4()}.json").write_text(json.dumps(rec), encoding="utf-8")
# Local folder that receives one JSON file per logged request.
LOG_FOLDER = Path("usage_logs")
LOG_FOLDER.mkdir(exist_ok=True)

# Background scheduler that commits the folder's contents to a Hub dataset
# repo. The token is read from the HF_TOKEN environment variable (None if unset).
scheduler = CommitScheduler(
    folder_path=str(LOG_FOLDER),
    repo_id="ai-gero/protobind_usage",
    repo_type="dataset",
    token=os.environ.get("HF_TOKEN"),
    every=1,  # commit interval; minutes per huggingface_hub docs -- confirm
)
# --- GRADIO APP DEFINITION ---
# Model loading happens once, at import time. Everything is placed on the CPU
# here; the request handler moves the models to the GPU when it runs.
device = "cpu"

# ESM-2 protein language model plus its alphabet (used for tokenization).
esm_model, alphabet = esm.pretrained.load_model_and_alphabet("esm2_t33_650M_UR50D")
esm_model = esm_model.eval().to(device)

# Fetch the tokenizer first, then the checkpoint, from the Hub repo.
tokenizer_path, ckpt_path = (
    hf_hub_download(repo_id=REPO_ID, filename=artifact)
    for artifact in (TOKENIZER_FILENAME, MODEL_FILENAME)
)

# Restore the ProtoBind-Diff generator from the checkpoint in inference mode.
protobind_model = ModelGenerator.load_from_checkpoint(
    ckpt_path,
    tokenizer_path=tokenizer_path,
    seq_embedding_dim=1280,
    load=True,
    map_location=device,
)
protobind_model.eval().to(device)
# Define the UI
# Static texts and example inputs are kept in constants so the layout code
# below stays short; their contents are passed to Gradio unchanged.
_INTRO_MARKDOWN = """
# ProtoBind-Diff: Protein-Conditioned Ligand Generation
This Space demonstrates **ProtoBind-Diff**, a diffusion model for generating novel drug-like molecules (ligands)
conditioned on a target protein sequence. Provide a protein's amino acid sequence to generate potential binding molecules in SMILES format.
"""

_FOOTER_MARKDOWN = """
---
*Model developed by Gero AI. For more details, check out the [original repository](https://github.com/gero-science/ProtoBind-Diff).*
"""

_ATTEMPTS_INFO = (
    "Upper limit on generation attempts. Duplicates and invalid molecules "
    "are discarded, so the final count of unique molecules may be lower. "
    "More attempts increase runtime but can improve diversity."
)

# Each example row is [protein sequence, number of generation attempts].
_EXAMPLE_ROWS = [
    [
        "MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPGYRS",
        50,
    ],
    [
        "MDILCEENTSLSSTTNSLMQLNDDTRLYSNDFNSGEANTSDAFNWTVDSENRTNLSCEGCLSPSCLSLLHLQEKNWSALLTAVVIILTIAGNILVIMAVSLEKKLQNATNYFLMSLAIADMLLGFLVMPVSMLTILYGYRWPLPSKLCAVWIYLDVLFSTASIMHLCAISLDRYVAIQNPIHHSRFNSRTKAFLKIIAVWTISVGISMPIPVFGLQDDSKVFKEGSCLLADDNFVLIGSFVSFFIPLTIMVITYFLTIKSLQKEATLCVSDLGTRAKLASFSFLPQSSLSSEKLFQRSIHREPGSYTGRRTMQSISNEQKACKVLGIVFFLFVVMWCPFFITNIMAVICKESCNEDVIGALLNVFVWIGYLSSAVNPLVYTLFNKTYRSAFSRYIQCQYKENKKPLQLILVNTIPALAYKSSQLQMGQKKNSKQDAKTTDNDCSMVALGKQHSEEASKDNSDGVNEKVSCV",
        100,
    ],
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column(scale=2):
            protein_sequence = gr.Textbox(
                label="Protein Amino Acid Sequence",
                placeholder="Enter your protein sequence here (e.g., MGY...)",
                lines=10,
            )
            num_samples = gr.Slider(
                label="Generation Attempts",
                info=_ATTEMPTS_INFO,
                minimum=10,
                maximum=200,
                step=10,
                value=50,
            )
            # Starts disabled; enable_btn turns it on once enough text is entered.
            submit_btn = gr.Button("Generate Molecules", variant="primary", interactive=False)

        # Right column: generated output.
        with gr.Column(scale=3):
            output_smiles = gr.Textbox(
                label="Generated SMILES",
                info="A list of unique, valid SMILES strings generated for the target protein.",
                lines=15,
                interactive=True,
            )

    gr.Markdown("### Examples")
    gr.Examples(
        examples=_EXAMPLE_ROWS,
        inputs=[protein_sequence, num_samples],
        cache_examples=False,
    )

    gr.Markdown(_FOOTER_MARKDOWN)

    # Event wiring: re-evaluate the button state on every text change, and
    # run the generation pipeline when the button is clicked.
    protein_sequence.change(enable_btn, inputs=protein_sequence, outputs=submit_btn)
    submit_btn.click(
        fn=generate_smiles_for_sequence,
        inputs=[protein_sequence, num_samples],
        outputs=output_smiles,
    )
# Launch the app
if __name__ == "__main__":
    # queue() returns the Blocks instance, so launch() can be chained onto it.
    # The request queue is capped at 10 pending events.
    demo.queue(max_size=10).launch(share=True)