# app.py
"""Gradio Space for ProtoBind-Diff: protein-conditioned ligand (SMILES) generation.

On import this module loads the ESM-2 embedding model and the ProtoBind-Diff
checkpoint, starts a background ``CommitScheduler`` that uploads usage logs,
and builds the Gradio UI exposed as ``demo``.
"""

# --- IMPORTS ---
import hashlib
import json
import os
import re
import uuid
from datetime import datetime, timezone
from pathlib import Path

import esm
import gradio as gr
import lightning.pytorch as pl
import spaces
import torch
from huggingface_hub import CommitScheduler, hf_hub_download
from torch.utils.data import DataLoader

from protobind_diff.data_loader import InferenceDataset
from protobind_diff.model import ModelGenerator

# Hugging Face Hub details
REPO_ID = "ai-gero/ProtoBind-Diff"
MODEL_FILENAME = "model.ckpt"
TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"

# Shortest cleaned sequence accepted by the UI and by the prediction handler.
MIN_SEQUENCE_LENGTH = 10
# Number of molecules requested from the model per prediction batch.
GENERATION_BATCH_SIZE = 10

# Alias kept so Gradio still recognises the type-hinted ``request`` parameter.
Request = gr.Request


def _clean_sequence(seq) -> str:
    """Upper-case *seq* and drop every character outside ``A``-``Z``.

    ``None`` is treated as an empty string so UI callbacks never raise on it.
    """
    return re.sub(r"[^A-Z]", "", (seq or "").upper())


def _client_host(request) -> str:
    """Return the caller's host from *request*, or ``"unknown"`` if unavailable."""
    client = getattr(request, "client", None)
    return getattr(client, "host", None) or "unknown"


@spaces.GPU(duration=120)
def generate_smiles_for_sequence(protein_sequence: str, num_samples: int, request: Request):
    """Run the full pipeline: log the call, embed the protein, sample SMILES.

    Args:
        protein_sequence: Raw text from the UI; it is upper-cased and stripped
            of non ``A``-``Z`` characters before use.
        num_samples: Requested number of generation attempts; it is floor-divided
            by the batch size (10) to get the number of prediction batches.
        request: Gradio request object, used only to record the client host.

    Returns:
        The unique generated SMILES strings joined by ``",\\n"``, in the order
        in which they were first produced.

    Raises:
        gr.Error: If the sequence is empty, or shorter than 10 residues after
            cleaning.
    """
    # The raw input is logged before validation, so rejected requests are recorded too.
    log_run(_client_host(request), protein_sequence)

    if not protein_sequence:
        raise gr.Error("Protein sequence cannot be empty.")
    protein_sequence = _clean_sequence(protein_sequence)
    if len(protein_sequence) < MIN_SEQUENCE_LENGTH:
        raise gr.Error("Protein sequence is too short.")

    # Use the GPU when one is available; the CPU branches below are otherwise dead code.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    esm_model.to(device)
    protobind_model.to(device)

    with torch.no_grad():
        batch_converter = alphabet.get_batch_converter()
        _, _, tokens = batch_converter([("protein", protein_sequence)])
        tokens = tokens.to(device)
        # Layer-33 representations with the first and last token positions sliced off
        # (presumably the BOS/EOS tokens added by the ESM alphabet - TODO confirm).
        embedding = esm_model(tokens, repr_layers=[33])["representations"][33][:, 1:-1, :]
        embedding = embedding.float() if device == "cpu" else embedding.bfloat16()

    # Coerce to int first: a float slider value would otherwise yield a float count.
    n_batches = int(num_samples) // GENERATION_BATCH_SIZE
    dataset = InferenceDataset(embedding, batch_size=GENERATION_BATCH_SIZE, n_batches=n_batches)
    # batch_size=None: items from the dataset are passed through without collation.
    loader = DataLoader(dataset, batch_size=None)
    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        logger=False,
        precision="16-mixed" if device == "cuda" else "32-true",
    )
    predictions_batches = trainer.predict(model=protobind_model, dataloaders=loader)

    all_smiles = [smi for batch in predictions_batches for smi in batch[0]]
    # dict.fromkeys de-duplicates while keeping first-seen order, so the output
    # is stable for a given set of predictions (set() ordering is arbitrary).
    unique_smiles = list(dict.fromkeys(all_smiles))
    return ",\n".join(unique_smiles)


def enable_btn(seq: str):
    """Enable the submit button only when the cleaned sequence is long enough.

    Uses the same cleaning rule as ``generate_smiles_for_sequence`` so the button
    is never enabled for input that the handler would reject as too short.
    """
    return gr.update(interactive=len(_clean_sequence(seq)) >= MIN_SEQUENCE_LENGTH)


def log_run(client_ip: str, seq: str):
    """Write one usage record as a uniquely named JSON file in ``LOG_FOLDER``.

    Args:
        client_ip: Host string identifying the caller.
        seq: The protein sequence exactly as submitted (may be empty or ``None``).
    """
    seq = seq or ""
    rec = {
        # Naive UTC timestamp; same text format that ``datetime.utcnow()`` produced.
        "ts": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(timespec="seconds"),
        "client": client_ip,
        "seq": seq,
        "seq_len": len(seq),
    }
    # Hold the scheduler lock so a background commit cannot upload a half-written file.
    with scheduler.lock:
        (LOG_FOLDER / f"{uuid.uuid4()}.json").write_text(json.dumps(rec))


LOG_FOLDER = Path("usage_logs")
LOG_FOLDER.mkdir(exist_ok=True)

# Background uploader that periodically commits LOG_FOLDER to the usage dataset repo.
scheduler = CommitScheduler(
    repo_id="ai-gero/protobind_usage",
    repo_type="dataset",
    folder_path=str(LOG_FOLDER),
    every=1,
    token=os.getenv("HF_TOKEN"),
)

# --- GRADIO APP DEFINITION ---

# Load models on app startup (on CPU; the handler moves them to the GPU when available).
device = "cpu"
esm_model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t33_650M_UR50D')
esm_model.eval()
esm_model = esm_model.to(device)

tokenizer_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=TOKENIZER_FILENAME,
)
ckpt_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILENAME,
)
protobind_model = ModelGenerator.load_from_checkpoint(
    ckpt_path,
    map_location=device,
    tokenizer_path=tokenizer_path,
    seq_embedding_dim=1280,
    load=True,
)
protobind_model.eval()
protobind_model.to(device)

# Define the UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ProtoBind-Diff: Protein-Conditioned Ligand Generation
        This Space demonstrates **ProtoBind-Diff**, a diffusion model for generating novel drug-like molecules (ligands) conditioned on a target protein sequence.
        Provide a protein's amino acid sequence to generate potential binding molecules in SMILES format.
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            protein_sequence = gr.Textbox(
                lines=10,
                label="Protein Amino Acid Sequence",
                placeholder="Enter your protein sequence here (e.g., MGY...)"
            )
            num_samples = gr.Slider(
                minimum=10,
                maximum=200,
                value=50,
                step=10,
                label="Generation Attempts",
                info=(
                    "Upper limit on generation attempts. Duplicates and invalid molecules "
                    "are discarded, so the final count of unique molecules may be lower. "
                    "More attempts increase runtime but can improve diversity."
                )
            )
            # Starts disabled; enable_btn switches it on once the input is long enough.
            submit_btn = gr.Button("Generate Molecules", variant="primary", interactive=False)
        with gr.Column(scale=3):
            output_smiles = gr.Textbox(
                lines=15,
                label="Generated SMILES",
                info="A list of unique, valid SMILES strings generated for the target protein.",
                interactive=True
            )

    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            ["MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPGYRS", 50],
            ["MDILCEENTSLSSTTNSLMQLNDDTRLYSNDFNSGEANTSDAFNWTVDSENRTNLSCEGCLSPSCLSLLHLQEKNWSALLTAVVIILTIAGNILVIMAVSLEKKLQNATNYFLMSLAIADMLLGFLVMPVSMLTILYGYRWPLPSKLCAVWIYLDVLFSTASIMHLCAISLDRYVAIQNPIHHSRFNSRTKAFLKIIAVWTISVGISMPIPVFGLQDDSKVFKEGSCLLADDNFVLIGSFVSFFIPLTIMVITYFLTIKSLQKEATLCVSDLGTRAKLASFSFLPQSSLSSEKLFQRSIHREPGSYTGRRTMQSISNEQKACKVLGIVFFLFVVMWCPFFITNIMAVICKESCNEDVIGALLNVFVWIGYLSSAVNPLVYTLFNKTYRSAFSRYIQCQYKENKKPLQLILVNTIPALAYKSSQLQMGQKKNSKQDAKTTDNDCSMVALGKQHSEEASKDNSDGVNEKVSCV", 100]
        ],
        inputs=[protein_sequence, num_samples],
        cache_examples=False,
    )
    gr.Markdown(
        """
        ---
        *Model developed by Gero AI. For more details, check out the [original repository](https://github.com/gero-science/ProtoBind-Diff).*
        """
    )

    # Re-evaluate the button state on every edit of the sequence box.
    protein_sequence.change(enable_btn, inputs=protein_sequence, outputs=submit_btn)
    submit_btn.click(
        fn=generate_smiles_for_sequence,
        inputs=[protein_sequence, num_samples],
        outputs=output_smiles
    )

# Launch the app
if __name__ == "__main__":
    demo.queue(max_size=10)
    demo.launch(share=True)