Spaces:
Sleeping
Sleeping
| # app.py | |
| # --- IMPORTS --- | |
| import os | |
| import re | |
| import esm | |
| import gradio as gr | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import lightning.pytorch as pl | |
| from protobind_diff.model import ModelGenerator | |
| from protobind_diff.data_loader import InferenceDataset | |
| from huggingface_hub import hf_hub_download | |
| import spaces | |
| from pathlib import Path | |
| import uuid, json, hashlib | |
| from huggingface_hub import CommitScheduler | |
| from datetime import datetime | |
| # Hugging Face Hub details | |
| REPO_ID = "ai-gero/ProtoBind-Diff" | |
| MODEL_FILENAME = "model.ckpt" | |
| TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json" | |
| Request = gr.Request | |
def generate_smiles_for_sequence(protein_sequence: str, num_samples: int, request: Request):
    """Generate candidate ligand SMILES strings for a protein sequence.

    Logs the request, sanitizes the sequence, embeds it with the ESM model,
    runs the ProtoBind-Diff model through a Lightning ``Trainer.predict``
    pass and returns the de-duplicated predictions.

    Args:
        protein_sequence: Raw amino-acid sequence as entered by the user. It
            is upper-cased and every character outside ``A-Z`` is removed
            before use.
        num_samples: Requested number of generation attempts. Generation runs
            in batches of 10, so the value is floored to a multiple of 10,
            with a minimum of one batch.
        request: Gradio request object; only ``request.client.host`` is read,
            for usage logging.

    Returns:
        The unique generated strings, in first-seen order, joined by a comma
        followed by a newline.

    Raises:
        gr.Error: If the sequence is empty, or has fewer than 10 letters
            after sanitizing.
    """
    batch_size = 10

    # Usage is logged before validation, so rejected inputs are recorded too.
    # ``request`` / ``request.client`` may be missing when called outside a
    # normal browser session, so read the host defensively.
    client = getattr(request, "client", None)
    log_run(getattr(client, "host", None) or "unknown", protein_sequence or "")

    if not protein_sequence:
        raise gr.Error("Protein sequence cannot be empty.")
    protein_sequence = re.sub(r"[^A-Z]", "", protein_sequence.upper())
    if len(protein_sequence) < 10:
        raise gr.Error("Protein sequence is too short.")

    # Fall back to CPU instead of crashing when no CUDA device is present;
    # this also makes the CPU branches below reachable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    esm_model.to(device)
    protobind_model.to(device)

    with torch.no_grad():
        batch_converter = alphabet.get_batch_converter()
        _, _, tokens = batch_converter([("protein", protein_sequence)])
        tokens = tokens.to(device)
        # Layer-33 representations with the first and last token positions
        # dropped (presumably the BOS/EOS tokens added by the converter —
        # TODO confirm).
        embedding = esm_model(tokens, repr_layers=[33])["representations"][33][:, 1:-1, :]
        embedding = embedding.float() if device == "cpu" else embedding.bfloat16()

    # Slider values may arrive as floats; always run at least one batch so a
    # too-small request does not silently produce nothing.
    n_batches = max(1, int(num_samples) // batch_size)
    dataset = InferenceDataset(embedding, batch_size=batch_size, n_batches=n_batches)
    loader = DataLoader(dataset, batch_size=None)

    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        logger=False,
        precision="16-mixed" if device == "cuda" else "32-true",
    )
    # ``predict`` can return None when nothing was produced.
    predictions_batches = trainer.predict(model=protobind_model, dataloaders=loader) or []

    all_smiles = [smi for batch in predictions_batches for smi in batch[0]]
    # dict.fromkeys de-duplicates while keeping generation order, unlike
    # set(), whose ordering varies between runs.
    unique_smiles = list(dict.fromkeys(all_smiles))
    return ",\n".join(unique_smiles)
def enable_btn(seq: str):
    """Enable the submit button only for sequences that will pass validation.

    The sequence is sanitized the same way as in
    ``generate_smiles_for_sequence`` (upper-cased, non ``A-Z`` characters
    removed) before its length is checked, so whitespace, digits or
    punctuation do not count towards the 10-letter minimum.

    Args:
        seq: Current contents of the protein sequence textbox; ``None`` is
            treated as empty.

    Returns:
        A ``gr.update`` setting the button's ``interactive`` flag.
    """
    cleaned = re.sub(r"[^A-Z]", "", (seq or "").upper())
    return gr.update(interactive=len(cleaned) >= 10)
def log_run(client_ip: str, seq: str):
    """Write one usage record as a uniquely named JSON file in ``LOG_FOLDER``.

    Args:
        client_ip: Identifier of the caller, stored under ``"client"``.
        seq: Submitted sequence, stored verbatim under ``"seq"`` together
            with its length; ``None`` is recorded as an empty string.
    """
    # Local import: the file only imports ``datetime`` from the module.
    from datetime import timezone

    seq = seq or ""
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop tzinfo so the stored format stays "YYYY-MM-DDTHH:MM:SS".
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    rec = {
        "ts": now_utc.isoformat(timespec="seconds"),
        "client": client_ip,
        "seq": seq,
        "seq_len": len(seq),
    }
    # A fresh UUID per record avoids name collisions between requests.
    (LOG_FOLDER / f"{uuid.uuid4()}.json").write_text(json.dumps(rec))
# Local directory that receives one JSON file per request (see log_run).
LOG_FOLDER = Path("usage_logs")
LOG_FOLDER.mkdir(exist_ok=True)

# Background uploader that pushes LOG_FOLDER to a Hub dataset repository.
# ``every`` is expressed in minutes per the huggingface_hub documentation.
scheduler = CommitScheduler(
    repo_id="ai-gero/protobind_usage",
    repo_type="dataset",
    folder_path=str(LOG_FOLDER),
    every=1,
    token=os.getenv("HF_TOKEN"),
)
| # --- GRADIO APP DEFINITION --- | |
| # Load models on app startup | |
# Models are loaded onto the CPU at import time; the request handler moves
# them to its own device when it runs.
device = "cpu"

# ESM-2 protein language model and its alphabet (used for tokenization).
esm_model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t33_650M_UR50D')
esm_model.eval()
esm_model = esm_model.to(device)

# Fetch the tokenizer and checkpoint files from the Hub repository.
tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename=TOKENIZER_FILENAME)
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# Restore the ProtoBind-Diff generator from the checkpoint.
# seq_embedding_dim=1280 presumably matches the ESM embedding width — TODO confirm.
protobind_model = ModelGenerator.load_from_checkpoint(
    ckpt_path,
    map_location=device,
    tokenizer_path=tokenizer_path,
    seq_embedding_dim=1280,
    load=True,
)
protobind_model.eval()
protobind_model.to(device)
| # Define the UI | |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown(
        """
        # ProtoBind-Diff: Protein-Conditioned Ligand Generation
        This Space demonstrates **ProtoBind-Diff**, a diffusion model for generating novel drug-like molecules (ligands)
        conditioned on a target protein sequence. Provide a protein's amino acid sequence to generate potential binding molecules in SMILES format.
        """
    )

    # Help text shown under the "Generation Attempts" slider.
    _attempts_help = (
        "Upper limit on generation attempts. Duplicates and invalid molecules "
        "are discarded, so the final count of unique molecules may be lower. "
        "More attempts increase runtime but can improve diversity."
    )

    # Rows for the examples widget: [protein sequence, generation attempts].
    _example_rows = [
        [
            "MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPGYRS",
            50,
        ],
        [
            "MDILCEENTSLSSTTNSLMQLNDDTRLYSNDFNSGEANTSDAFNWTVDSENRTNLSCEGCLSPSCLSLLHLQEKNWSALLTAVVIILTIAGNILVIMAVSLEKKLQNATNYFLMSLAIADMLLGFLVMPVSMLTILYGYRWPLPSKLCAVWIYLDVLFSTASIMHLCAISLDRYVAIQNPIHHSRFNSRTKAFLKIIAVWTISVGISMPIPVFGLQDDSKVFKEGSCLLADDNFVLIGSFVSFFIPLTIMVITYFLTIKSLQKEATLCVSDLGTRAKLASFSFLPQSSLSSEKLFQRSIHREPGSYTGRRTMQSISNEQKACKVLGIVFFLFVVMWCPFFITNIMAVICKESCNEDVIGALLNVFVWIGYLSSAVNPLVYTLFNKTYRSAFSRYIQCQYKENKKPLQLILVNTIPALAYKSSQLQMGQKKNSKQDAKTTDNDCSMVALGKQHSEEASKDNSDGVNEKVSCV",
            100,
        ],
    ]

    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column(scale=2):
            protein_sequence = gr.Textbox(
                lines=10,
                label="Protein Amino Acid Sequence",
                placeholder="Enter your protein sequence here (e.g., MGY...)",
            )
            num_samples = gr.Slider(
                minimum=10,
                maximum=200,
                value=50,
                step=10,
                label="Generation Attempts",
                info=_attempts_help,
            )
            # Starts disabled; enable_btn toggles it as the sequence changes.
            submit_btn = gr.Button("Generate Molecules", variant="primary", interactive=False)

        # Right column: generated output.
        with gr.Column(scale=3):
            output_smiles = gr.Textbox(
                lines=15,
                label="Generated SMILES",
                info="A list of unique, valid SMILES strings generated for the target protein.",
                interactive=True,
            )

    gr.Markdown("### Examples")
    gr.Examples(
        examples=_example_rows,
        inputs=[protein_sequence, num_samples],
        cache_examples=False,
    )

    # Page footer.
    gr.Markdown(
        """
        ---
        *Model developed by Gero AI. For more details, check out the [original repository](https://github.com/gero-science/ProtoBind-Diff).*
        """
    )

    # Event wiring: typing updates the button state, clicking runs generation.
    protein_sequence.change(fn=enable_btn, inputs=protein_sequence, outputs=submit_btn)
    submit_btn.click(
        fn=generate_smiles_for_sequence,
        inputs=[protein_sequence, num_samples],
        outputs=output_smiles,
    )
| # Launch the app | |
if __name__ == "__main__":
    # Cap the request queue at 10 pending jobs, then start the server
    # with a public share link requested.
    demo.queue(max_size=10).launch(share=True)