Spaces:
Sleeping
Sleeping
vladimir.manuylov
committed on
Commit
·
bd082dc
1
Parent(s):
8e21d42
initial commit
Browse files- LICENSE +21 -0
- README.md +29 -2
- app.py +146 -0
- protobind_diff/__init__.py +0 -0
- protobind_diff/data_loader.py +761 -0
- protobind_diff/decoder_rope.py +769 -0
- protobind_diff/esm_inference.py +175 -0
- protobind_diff/ligands/__init__.py +0 -0
- protobind_diff/ligands/rdkit_utils.py +209 -0
- protobind_diff/ligands/smiles_tokenizer.py +135 -0
- protobind_diff/model.py +411 -0
- protobind_diff/noise_schedule.py +185 -0
- pyproject.toml +44 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
---
|
| 2 |
title: ProtoBind Diff
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: gray
|
| 6 |
sdk: gradio
|
|
@@ -11,4 +11,31 @@ license: mit
|
|
| 11 |
short_description: Structure-free target-specific molecule generation
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: ProtoBind Diff
|
| 3 |
+
emoji: 💊
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: gray
|
| 6 |
sdk: gradio
|
|
|
|
| 11 |
short_description: Structure-free target-specific molecule generation
|
| 12 |
---
|
| 13 |
|
| 14 |
+
## A Structure-Free Diffusion Language Model for Protein Sequence-Conditioned Ligand Design
|
| 15 |
+
|
| 16 |
+
<a href="https://www.biorxiv.org/content/10.1101/2025.06.16.659955v1">
|
| 17 |
+
<img
|
| 18 |
+
src="https://img.shields.io/badge/bioRxiv-paper-blue?logo=biorxiv&logoColor=white"
|
| 19 |
+
alt="Paper on bioRxiv"
|
| 20 |
+
/>
|
| 21 |
+
</a>
|
| 22 |
+
<a href="https://github.com/gero-science/ProtoBind-Diff">
|
| 23 |
+
<img
|
| 24 |
+
src="https://img.shields.io/badge/GitHub-code-black?logo=github&logoColor=white"
|
| 25 |
+
alt="View on GitHub"
|
| 26 |
+
/>
|
| 27 |
+
</a>
|
| 28 |
+
|
| 29 |
+
## Citation
|
| 30 |
+
|
| 31 |
+
```bibtex
|
| 32 |
+
@article {Mistryukova2025.06.16.659955,
|
| 33 |
+
author = {Mistryukova, Lukia and Manuilov, Vladimir and Avchaciov, Konstantin and Fedichev, Peter O.},
|
| 34 |
+
title = {ProtoBind-Diff: A Structure-Free Diffusion Language Model for Protein Sequence-Conditioned Ligand Design},
|
| 35 |
+
year = {2025},
|
| 36 |
+
journal = {bioRxiv}
|
| 37 |
+
}
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## License
|
| 41 |
+
The code and model weights are released under MIT license. See the [LICENSE](LICENSE) file for details.
|
app.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
# --- IMPORTS ---
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
import lightning.pytorch as pl
|
| 10 |
+
from protobind_diff.esm_inference import get_esm_embedding
|
| 11 |
+
from protobind_diff.model import ModelGenerator
|
| 12 |
+
from protobind_diff.data_loader import InferenceDataset
|
| 13 |
+
from huggingface_hub import hf_hub_download
|
| 14 |
+
|
| 15 |
+
# Hugging Face Hub details
|
| 16 |
+
REPO_ID = "ai-gero/ProtoBind-Diff"
|
| 17 |
+
MODEL_FILENAME = "model.ckpt"
|
| 18 |
+
TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
    """Run the full ligand-generation pipeline for one protein sequence.

    Args:
        protein_sequence: Raw amino-acid sequence; every non-letter character
            is stripped and the string is upper-cased before use.
        num_samples: Upper bound on the number of generation attempts
            (processed in batches of 10).

    Returns:
        str: The unique generated SMILES strings joined by ",\n".

    Raises:
        gr.Error: If the sequence is empty or shorter than 10 residues after
            cleaning.
    """
    if not protein_sequence:
        raise gr.Error("Protein sequence cannot be empty.")
    # Keep only upper-case letters (amino-acid one-letter codes).
    protein_sequence = re.sub(r"[^A-Z]", "", protein_sequence.upper())
    if len(protein_sequence) < 10:
        raise gr.Error("Protein sequence is too short.")

    # NOTE(review): relies on the module-level globals `device` and
    # `protobind_model`, which are initialized at app startup.
    embedding = get_esm_embedding(
        protein_sequence,
        'esm2_t33_650M_UR50D',
        device
    ).to(dtype=torch.bfloat16)
    # Fix: guard against n_batches == 0 (num_samples < 10 previously produced
    # no batches at all); also coerce Gradio's numeric input to int.
    n_batches = max(1, int(num_samples) // 10)
    dataset = InferenceDataset(embedding, batch_size=10, n_batches=n_batches)
    loader = DataLoader(dataset, batch_size=None)

    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        logger=False,
        precision="16-mixed" if device == "cuda" else "32-true"
    )

    predictions_batches = trainer.predict(model=protobind_model, dataloaders=loader)

    # Each predict step returns a tuple whose first element is a list of SMILES.
    all_smiles = [smi for batch in predictions_batches for smi in batch[0]]
    # Fix: deduplicate with dict.fromkeys to keep a deterministic, first-seen
    # order (set() ordering varies between runs).
    unique_smiles = list(dict.fromkeys(all_smiles))

    return ",\n".join(unique_smiles)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# --- GRADIO APP DEFINITION ---

# Load models on app startup
# Pick GPU when available; `device` is also read by generate_smiles_for_sequence.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fetch the SMILES tokenizer and the model checkpoint from the Hugging Face Hub
# (both are cached locally after the first download).
tokenizer_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=TOKENIZER_FILENAME,
)
ckpt_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILENAME,
)
# Restore the generator model from the checkpoint.
# NOTE(review): seq_embedding_dim=1280 presumably matches the ESM-2 650M
# embedding width used by get_esm_embedding — confirm against ModelGenerator.
protobind_model = ModelGenerator.load_from_checkpoint(
    ckpt_path,
    map_location=device,
    tokenizer_path=tokenizer_path,
    seq_embedding_dim=1280,
    load=True,
)
protobind_model.eval()  # inference mode: disables dropout etc.
protobind_model.to(device)
|
| 76 |
+
# Define the UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header / description.
    gr.Markdown(
        """
        # ProtoBind-Diff: Protein-Conditioned Ligand Generation
        This Space demonstrates **ProtoBind-Diff**, a diffusion model for generating novel drug-like molecules (ligands)
        conditioned on a target protein sequence. Provide a protein's amino acid sequence to generate potential binding molecules in SMILES format.
        """
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=2):
            protein_sequence = gr.Textbox(
                lines=10,
                label="Protein Amino Acid Sequence",
                placeholder="Enter your protein sequence here (e.g., MGY...)"
            )
            # Slider step of 10 matches the batch size used by the pipeline.
            num_samples = gr.Slider(
                minimum=10,
                maximum=200,
                value=50,
                step=10,
                label="Generation Attempts",
                info=(
                    "Upper limit on generation attempts. Duplicates and invalid molecules "
                    "are discarded, so the final count of unique molecules may be lower. "
                    "More attempts increase runtime but can improve diversity."
                )
            )
            submit_btn = gr.Button("Generate Molecules", variant="primary")

        # Right column: generated output.
        with gr.Column(scale=3):
            output_smiles = gr.Textbox(
                lines=15,
                label="Generated SMILES",
                info="A list of unique, valid SMILES strings generated for the target protein.",
                interactive=True
            )

    gr.Markdown("### Examples")
    # Two example targets (a kinase-like and a GPCR-like sequence) with attempt counts.
    gr.Examples(
        examples=[
            ["MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPGYRS",
             50],
            ["MDILCEENTSLSSTTNSLMQLNDDTRLYSNDFNSGEANTSDAFNWTVDSENRTNLSCEGCLSPSCLSLLHLQEKNWSALLTAVVIILTIAGNILVIMAVSLEKKLQNATNYFLMSLAIADMLLGFLVMPVSMLTILYGYRWPLPSKLCAVWIYLDVLFSTASIMHLCAISLDRYVAIQNPIHHSRFNSRTKAFLKIIAVWTISVGISMPIPVFGLQDDSKVFKEGSCLLADDNFVLIGSFVSFFIPLTIMVITYFLTIKSLQKEATLCVSDLGTRAKLASFSFLPQSSLSSEKLFQRSIHREPGSYTGRRTMQSISNEQKACKVLGIVFFLFVVMWCPFFITNIMAVICKESCNEDVIGALLNVFVWIGYLSSAVNPLVYTLFNKTYRSAFSRYIQCQYKENKKPLQLILVNTIPALAYKSSQLQMGQKKNSKQDAKTTDNDCSMVALGKQHSEEASKDNSDGVNEKVSCV",
             100]
        ],
        inputs=[protein_sequence, num_samples],
        outputs=output_smiles,
        fn=generate_smiles_for_sequence,
        cache_examples=False,
    )

    # Footer / attribution.
    gr.Markdown(
        """
        ---
        *Model developed by Gero AI. For more details, check out the [original repository](https://github.com/gero-science/ProtoBind-Diff).*
        """
    )

    # Wire the button to the prediction function.
    submit_btn.click(
        fn=generate_smiles_for_sequence,
        inputs=[protein_sequence, num_samples],
        outputs=output_smiles
    )

# Launch the app
if __name__ == "__main__":
    # NOTE(review): share=True is ignored (with a warning) when running inside
    # a Hugging Face Space — confirm whether it is needed for local runs only.
    demo.launch(share=True)
|
| 145 |
+
|
| 146 |
+
|
protobind_diff/__init__.py
ADDED
|
File without changes
|
protobind_diff/data_loader.py
ADDED
|
@@ -0,0 +1,761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data loader for the protobind-diff.
|
| 2 |
+
# This version only supports ProtobindMaskedDiffusion with SMILES and ESM-2 protein encodings.
|
| 3 |
+
import os.path
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from typing import Dict, List, Tuple, Optional, Union
|
| 9 |
+
from zipfile import ZipFile
|
| 10 |
+
|
| 11 |
+
import lightning.pytorch as pl
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader
|
| 16 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 17 |
+
from tqdm.auto import tqdm
|
| 18 |
+
|
| 19 |
+
from .ligands.smiles_tokenizer import ChemformerTokenizer
|
| 20 |
+
from .ligands.rdkit_utils import randomize_smiles_rotated, cluster_fpsim2
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger("lightning")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SplittingMethod(Enum):
|
| 26 |
+
# enum that describes various train/val/test splitting methods.
|
| 27 |
+
RANDOM = 1
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def split_at_random(df: pd.DataFrame, valid_fraction=0.1, test_fraction=0.1, seed=777):
    """Shuffle a DataFrame and carve it into train/validation/test partitions.

    Args:
        df (pd.DataFrame): The DataFrame to split.
        valid_fraction (float): Fraction of rows assigned to the validation set.
        test_fraction (float): Fraction of rows assigned to the test set.
        seed (int): Random seed for the shuffle, for reproducibility.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: The training,
        validation, and test DataFrames, in that order.
    """
    # Shuffle once, then slice the shuffled frame positionally.
    shuffled = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    n_total = len(shuffled)
    n_valid = int(n_total * valid_fraction)
    n_test = int(n_total * test_fraction)
    n_train = n_total - n_valid - n_test
    train_df = shuffled.iloc[:n_train]
    valid_df = shuffled.iloc[n_train:n_train + n_valid]
    test_df = shuffled.iloc[n_train + n_valid:]
    return train_df, valid_df, test_df
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class RandomizedSmilesDataset(object):
    """Creates a dataset of tokenized SMILES strings, with an option for on-the-fly randomization.

    This dataset maps integer indices to SMILES strings and provides tokenized
    representations. It can randomize SMILES strings during data retrieval to
    augment the training data.

    Attributes:
        smiles (pd.Series): A series of SMILES strings indexed by integers.
        tokenizer (ChemformerTokenizer): The tokenizer for converting SMILES to tokens.
        randomize (bool): If True, applies SMILES randomization at retrieval time.
    """
    def __init__(self, smiles: dict, tokenizer: ChemformerTokenizer,
                 randomize: bool = True):
        # `smiles` maps SMILES string -> integer id; invert it into a Series
        # indexed by id so that __getitem__ can look up by integer position.
        self.smiles = pd.Series(data=smiles.keys(), index=smiles.values()).sort_index()
        # The ids must form a contiguous range 0..N-1, otherwise integer
        # indexing in __getitem__ would silently skip or miss entries.
        assert len(self.smiles) == self.smiles.index[-1] + 1, (f"{len(self.smiles)}"
                                                              f" {self.smiles.index[:5]} {self.smiles.index[-5:]}")
        self.tokenizer = tokenizer
        self.randomize = randomize
        logger.info(f"Molecular dataset initialized: RandomizedSmilesDataset {type(self.tokenizer)}"
                    f" random: {self.randomize}")

    def __len__(self):
        # Number of distinct SMILES entries.
        return len(self.smiles)

    def __getitem__(self, item):
        """Return the token sequence for entry `item`, optionally randomized."""
        smi = self.smiles[item]
        if self.randomize:
            # Data augmentation: produce an equivalent but differently-written
            # SMILES for the same molecule on every retrieval.
            smi = randomize_smiles_rotated(smi)
        # encode() returns a batch; take the single encoded sequence.
        mol = self.tokenizer.encode(smi)[0]
        return mol

    @classmethod
    def from_json(cls, path, **kwargs):
        """Alternate constructor: build the dataset from a categorical-mappings JSON file.

        The file is expected to contain a 'smiles' mapping of
        SMILES string -> integer id (see __init__).
        """
        with open(path) as f:
            categorical_mappings = json.load(f)
        smiles = categorical_mappings['smiles']
        loaded = cls(smiles, **kwargs)
        return loaded
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class RandomizedBatchSampler(torch.utils.data.Sampler):
    """A batch sampler that minimizes padding while maximizing batch randomness.

    To achieve this, the sampler employs a two-level shuffling strategy:
    1. The data is first sorted by sequence length and grouped into buckets.
    2. Within each bucket, the sample indices are shuffled.
    3. Batches are created by slicing across the globally sorted list of indices,
       which keeps sequence lengths within a batch similar.
    4. The order of these batches is then shuffled to ensure randomness across epochs.

    This approach balances the trade-off between minimizing padding (by batching
    similar-length sequences) and maintaining randomness required for effective training.
    """

    def __init__(self, sequence_length: np.ndarray, shuffle: bool, batch_volume: int,
                 generator: torch.Generator = None, num_ranges: int = 150, batch_size: int = 128):
        """Initializes the RandomizedBatchSampler.

        Args:
            sequence_length (np.ndarray): An array of sequence lengths for each item in the dataset.
            shuffle (bool): If True, shuffle batches and indices within length buckets.
            batch_volume (int): The maximum total number of elements (seq_len^2) per batch.
            generator (torch.Generator, optional): PyTorch random number generator. Defaults to None.
            num_ranges (int): The number of buckets to partition the sequence lengths into.
            batch_size (int): The maximum number of samples per batch.
        """
        self.shuffle = shuffle
        # For val/test (i.e. when we don't shuffle) we can fit more batches in memory as we don't need grads.
        batch_volume_factor = 1 if shuffle else 2
        self.batch_volume = batch_volume * batch_volume_factor
        # Even the single longest sequence must fit inside one batch volume.
        assert max(sequence_length) ** 2 < self.batch_volume, \
            f"Cannot fit sequence {max(sequence_length)=} to {batch_volume=}"

        if generator is None:
            self.generator = self._init_generator()
        else:
            self.generator = generator
        self.num_ranges = num_ranges
        self.sequence_length = sequence_length
        # Per-item "volume" cost: quadratic in sequence length
        # (presumably attention memory scales as seq_len^2 — see batch_volume docs).
        self.sequence_length_2 = self.sequence_length ** 2
        self.batch_size = batch_size

        # Partition the length range into `num_ranges` equal-width bins; each
        # bucket holds the indices of items falling into one bin.
        bins = np.linspace(np.min(sequence_length), np.max(sequence_length) + 1, num_ranges)
        digit_bins = np.digitize(sequence_length, bins=bins, right=True)
        self.sequence_length_buckets = [torch.tensor(np.where(digit_bins == i)[0],
                                                     dtype=torch.int32) for i in range(num_ranges)]
        # Cache of materialized batches, shared between __len__ and __iter__.
        self._prepared_batches = None

    def _get_sliced_batches(self):
        """Yield batches of indices, filling each batch up to the volume/size caps."""
        if self.shuffle:
            # reshuffle the sequence length buckets.
            for i in range(len(self.sequence_length_buckets)):
                self.sequence_length_buckets[i] = self.sequence_length_buckets[i][torch.randperm(
                    len(self.sequence_length_buckets[i]), generator=self.generator)]

        current_batch = []
        current_batch_volume = 0
        current_batch_size = 0
        # Walk buckets from shortest to longest so batch members have similar lengths.
        for i in range(self.num_ranges):
            for idx in self.sequence_length_buckets[i]:
                # Flush the current batch before it would exceed either cap.
                if (current_batch_volume + self.sequence_length_2[idx] >= self.batch_volume
                        or current_batch_size >= self.batch_size):
                    yield current_batch
                    current_batch = []
                    current_batch_volume = 0
                    current_batch_size = 0
                current_batch.append(idx.item())
                current_batch_volume += self.sequence_length_2[idx]
                current_batch_size += 1
        # Emit the trailing partial batch, if any.
        if len(current_batch) > 0:
            yield current_batch

    @staticmethod
    def _init_generator():
        # Seed a fresh generator from torch's global entropy source.
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        generator = torch.Generator()
        generator.manual_seed(seed)
        return generator

    @property
    def _length(self):
        # Materializes the batches lazily; __iter__ reuses the same cache so the
        # count reported by __len__ matches the batches actually served.
        if self._prepared_batches is None:
            self._prepared_batches = list(self._get_sliced_batches())
        return len(self._prepared_batches)

    def __len__(self):
        return self._length

    def __iter__(self):
        if self.shuffle:
            # Then get the batches and serve them in random order
            if self._prepared_batches is None:
                self._prepared_batches = list(self._get_sliced_batches())
            for batch_idx in torch.randperm(self._length, generator=self.generator):
                yield self._prepared_batches[batch_idx]
            self._prepared_batches = None  # Destroy _prepared_batches to recreate it again in __len__
        else:
            # Deterministic order for validation/test.
            for batch in self._get_sliced_batches():
                yield batch
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ProtobindDataModule(pl.LightningDataModule):
|
| 196 |
+
"""PyTorch Lightning DataModule for Protobind-diffusion datasets.
|
| 197 |
+
|
| 198 |
+
This module handles the loading, processing, and batching of protein-ligand
|
| 199 |
+
data. It is designed to work with ESM-2 protein embeddings and tokenized
|
| 200 |
+
SMILES representations for ligands. The module manages data splitting,
|
| 201 |
+
feature loading, and provides DataLoaders with an efficient batching
|
| 202 |
+
strategy to minimize padding.
|
| 203 |
+
|
| 204 |
+
Key Features:
|
| 205 |
+
- Loads pre-computed ESM-2 protein embeddings.
|
| 206 |
+
- Utilizes tokenized SMILES for ligands via `ChemformerTokenizer`.
|
| 207 |
+
- Implements a `RandomizedBatchSampler` to create efficient, low-padding batches.
|
| 208 |
+
- Handles dataset splitting into training, validation, and test sets.
|
| 209 |
+
"""
|
| 210 |
+
MASK_VALUE = 0
|
| 211 |
+
|
| 212 |
+
def __init__(self, *,
|
| 213 |
+
data_dir: Path,
|
| 214 |
+
exp_dir: Path,
|
| 215 |
+
splitting_method: SplittingMethod,
|
| 216 |
+
batch_volume: int,
|
| 217 |
+
num_workers: int,
|
| 218 |
+
sequence_type: str = 'esm_zip',
|
| 219 |
+
esm_model_name: str = "esm2_t33_650M_UR50D",
|
| 220 |
+
max_size_batch: int = 16,
|
| 221 |
+
dataset_params: Optional[dict] = None,
|
| 222 |
+
float_type: str = 'float32'):
|
| 223 |
+
super().__init__()
|
| 224 |
+
"""Initializes the ProtobindDataModule.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
data_dir (Path): The directory containing the raw dataset files (e.g., data.csv, embeddings).
|
| 228 |
+
exp_dir (Path): The directory to save experiment artifacts, including data splits.
|
| 229 |
+
splitting_method (SplittingMethod): The method for splitting data (e.g., RANDOM).
|
| 230 |
+
batch_volume (int): The target batch volume for the RandomizedBatchSampler.
|
| 231 |
+
num_workers (int): The number of workers for the DataLoader.
|
| 232 |
+
sequence_type (str): The type of protein sequence data. Must be 'esm_zip'.
|
| 233 |
+
esm_model_name (str): The specific ESM model name for embeddings.
|
| 234 |
+
max_size_batch (int): The maximum number of samples in a batch.
|
| 235 |
+
dataset_params (Optional[dict]): Parameters for the underlying molecular dataset.
|
| 236 |
+
float_type (str): The floating-point precision to use.
|
| 237 |
+
"""
|
| 238 |
+
self.csv_path = data_dir / "data.csv"
|
| 239 |
+
self.categorical_mappings_path = data_dir / "categorical_mappings.json"
|
| 240 |
+
|
| 241 |
+
# Validate sequence type - only allow ESM variants
|
| 242 |
+
if sequence_type not in ['esm_zip']:
|
| 243 |
+
raise ValueError(f"DataModule only supports only 'esm_zip' sequence type, got: {sequence_type}")
|
| 244 |
+
|
| 245 |
+
# directory structure:
|
| 246 |
+
# output_dir / split / exp_dir_prefix
|
| 247 |
+
self.exp_dir: Path = Path(exp_dir)
|
| 248 |
+
self.split_dir: Path = self.exp_dir.parent
|
| 249 |
+
self.exp_data_dir: Path = self.split_dir.parent
|
| 250 |
+
self.data_dir = data_dir
|
| 251 |
+
|
| 252 |
+
if dataset_params is None:
|
| 253 |
+
dataset_params = {}
|
| 254 |
+
|
| 255 |
+
# Create simplified SMILES dataloader
|
| 256 |
+
self.molecular_dataloader = MolecularDataloaderSMILES(
|
| 257 |
+
data_dir=data_dir,
|
| 258 |
+
dataset_options=dataset_params,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
self.float_type = float_type
|
| 262 |
+
self.batch_volume = batch_volume
|
| 263 |
+
self.max_size_batch = max_size_batch
|
| 264 |
+
self.num_workers = num_workers
|
| 265 |
+
self.splitting_method = splitting_method
|
| 266 |
+
self.esm_model_name = esm_model_name
|
| 267 |
+
|
| 268 |
+
# Only support ESM embeddings (float type data)
|
| 269 |
+
self.sequence_dtype = getattr(torch, self.float_type)
|
| 270 |
+
|
| 271 |
+
# Will be initialized in setup()
|
| 272 |
+
self.train_dataset: Optional[torch.utils.data.Dataset] = None
|
| 273 |
+
self.val_dataset: Optional[torch.utils.data.Dataset] = None
|
| 274 |
+
self.test_dataset: Optional[torch.utils.data.Dataset] = None
|
| 275 |
+
|
| 276 |
+
self.datasets: Dict[str, pd.DataFrame] = {}
|
| 277 |
+
self.torch_datasets: Dict[str, torch.utils.data.Dataset] = {}
|
| 278 |
+
|
| 279 |
+
@staticmethod
|
| 280 |
+
def _read_df(csv_path: Path) -> pd.DataFrame:
|
| 281 |
+
_use_columns = ['smiles', 'sequence', 'log_IC50', 'log_Ki', 'log_Kd', 'log_EC50', 'label', 'split',
|
| 282 |
+
'cluster_smi']
|
| 283 |
+
df = pd.read_csv(csv_path, nrows=1)
|
| 284 |
+
_use_columns = df.columns.intersection(_use_columns)
|
| 285 |
+
|
| 286 |
+
dtypes = {"smiles": int, "sequence": int, "log_IC50": float,
|
| 287 |
+
"log_Ki": float, "log_Kd": float, "log_EC50": float,
|
| 288 |
+
"label": float, "split": str, "cluster_smi": str}
|
| 289 |
+
|
| 290 |
+
df = pd.read_csv(csv_path, dtype=dtypes, usecols=_use_columns)
|
| 291 |
+
return df
|
| 292 |
+
|
| 293 |
+
@staticmethod
|
| 294 |
+
def _read_df_and_compute_sequence_lengths(csv_path: Path, length_dict: dict) -> pd.DataFrame:
|
| 295 |
+
# to reduce RAM load only necessary columns
|
| 296 |
+
df = ProtobindDataModule._read_df(csv_path)
|
| 297 |
+
df['sequence_length'] = df["sequence"].map(length_dict)
|
| 298 |
+
|
| 299 |
+
# sort by sequence length to increase the batching efficiency.
|
| 300 |
+
df.sort_values(by="sequence_length", inplace=True)
|
| 301 |
+
return df
|
| 302 |
+
|
| 303 |
+
def check_splits_exist(self):
|
| 304 |
+
""" Tries to find that train-test split exist """
|
| 305 |
+
if (self.split_dir / "train.csv").exists():
|
| 306 |
+
assert (self.split_dir / "valid.csv").exists()
|
| 307 |
+
assert (self.split_dir / "test.csv").exists()
|
| 308 |
+
logger.info(f"train.csv/valid.csv/test.csv exist, "
|
| 309 |
+
f"no new splits will be created for {self.splitting_method}")
|
| 310 |
+
return True
|
| 311 |
+
|
| 312 |
+
return False
|
| 313 |
+
|
| 314 |
+
    def prepare_data_split(self, seed=777, valid_fraction=0.1, test_fraction=0.1):
        """Create train.csv, valid.csv and test.csv in the split directory.

        No-op if the split files already exist (see check_splits_exist).

        Args:
            seed (int): Random seed for the split shuffle.
            valid_fraction (float): Fraction of rows for the validation set.
            test_fraction (float): Fraction of rows for the test set.

        Raises:
            FileNotFoundError: If data.csv, categorical_mappings.json, or the
                SMILES distance matrix (all_smiles_sparse_*.npz) is missing.
            NotImplementedError: If the configured splitting method is not RANDOM.
        """

        if self.check_splits_exist():
            return

        # Check that data exists
        for path in [self.csv_path, self.categorical_mappings_path]:
            if not path.exists():
                raise FileNotFoundError(
                    f"Could not find {path}. Please download the data.")

        # load label data
        data_df = pd.read_csv(self.csv_path)

        # add clusters
        # The sparse distance matrix over all SMILES is required to assign a
        # fingerprint-similarity cluster id to every ligand row.
        distance_data = list(self.csv_path.parent.glob('all_smiles_sparse_*.npz'))
        if len(distance_data) > 0:
            logger.info(f"Calculating clusters for SMILES and distance data {distance_data[0]}")
            clusters_smi = cluster_fpsim2(distance_data[0])
            len_ = len(data_df)
            # Join cluster ids on the integer SMILES id; the merge must not
            # drop any rows, which the assert below verifies.
            data_df = data_df.merge(pd.Series(clusters_smi, name='cluster_smi'), left_on='smiles', right_index=True)
            assert data_df.shape[0] == len_, (f"Failed to merge clusters, {len_=} {data_df.shape=}"
                                              f" {clusters_smi.min()} {clusters_smi.max()}")
        else:
            raise FileNotFoundError(f'Could not find any all_smiles_sparse_*.npz in {str(self.csv_path.parent)}')

        # Create splits
        if self.splitting_method == SplittingMethod.RANDOM:
            train, valid, test = split_at_random(data_df, valid_fraction=valid_fraction,
                                                 test_fraction=test_fraction, seed=seed)
        else:
            raise NotImplementedError(
                f"Splitting method {self.splitting_method} is not implemented in simplified version.")

        train.to_csv(self.split_dir / "train.csv", index=False)
        valid.to_csv(self.split_dir / "valid.csv", index=False)
        test.to_csv(self.split_dir / "test.csv", index=False)
|
| 352 |
+
|
| 353 |
+
def prepare_data(self, **kwargs):
|
| 354 |
+
|
| 355 |
+
if kwargs.get('load', False):
|
| 356 |
+
return
|
| 357 |
+
|
| 358 |
+
if self.exp_dir.exists():
|
| 359 |
+
logger.info(f"Experiment directory {self.exp_dir} already exists. All existing files "
|
| 360 |
+
f" will be kept. To create new data/split remove {self.exp_data_dir} or {self.split_dir}")
|
| 361 |
+
self.exp_dir.mkdir(parents=True, exist_ok=True)
|
| 362 |
+
|
| 363 |
+
# Make train-test split
|
| 364 |
+
default_split_kwargs = {'seed': 777,
|
| 365 |
+
'valid_fraction': 0.1,
|
| 366 |
+
'test_fraction': 0.1,
|
| 367 |
+
}
|
| 368 |
+
# update from kwargs
|
| 369 |
+
for key in default_split_kwargs.keys():
|
| 370 |
+
if key in kwargs:
|
| 371 |
+
default_split_kwargs[key] = kwargs[key]
|
| 372 |
+
# Create new split or skip if exist
|
| 373 |
+
self.prepare_data_split(**default_split_kwargs)
|
| 374 |
+
|
| 375 |
+
# Prepare smiles (simplified - only tokenized smiles)
|
| 376 |
+
self.molecular_dataloader.prepare_molecular_features()
|
| 377 |
+
|
| 378 |
+
def setup(self, stage=None):
|
| 379 |
+
"""Loads and prepares the datasets for a given stage.
|
| 380 |
+
|
| 381 |
+
This method is called by PyTorch Lightning. It performs the following steps:
|
| 382 |
+
1. Loads molecular features (tokenized SMILES).
|
| 383 |
+
2. Loads protein features (pre-computed ESM embeddings).
|
| 384 |
+
3. Loads data splits (train/val/test) from CSV files.
|
| 385 |
+
4. Initializes the PyTorch Datasets for each split.
|
| 386 |
+
|
| 387 |
+
Args:
|
| 388 |
+
stage (str, optional): The stage to setup ('fit', 'validate', 'test', 'predict').
|
| 389 |
+
"""
|
| 390 |
+
logger.info("Loading molecular features")
|
| 391 |
+
|
| 392 |
+
# Load molecular features (simplified - only SMILES)
|
| 393 |
+
self.molecular_dataloader.load_molecular_features()
|
| 394 |
+
|
| 395 |
+
# Load protein features (only ESM embeddings)
|
| 396 |
+
logger.info(f"Loading protein features {self.esm_model_name}")
|
| 397 |
+
prot_embbeding_pt = self.data_dir / f'all_prots_{self.esm_model_name}.pt'
|
| 398 |
+
|
| 399 |
+
if prot_embbeding_pt.exists():
|
| 400 |
+
self.idx_to_sequence_data = torch.load(prot_embbeding_pt, map_location='cpu', weights_only=False)
|
| 401 |
+
length_dict = {idx: emb.shape[0] for idx, emb in self.idx_to_sequence_data.items()}
|
| 402 |
+
self.sequence_embedding_dim = next(iter(self.idx_to_sequence_data.values())).shape[1]
|
| 403 |
+
else:
|
| 404 |
+
raise FileNotFoundError(
|
| 405 |
+
f"Packed proteins `all_prots_{self.esm_model_name}.pt` is not found in {self.data_dir}")
|
| 406 |
+
|
| 407 |
+
# load data. Use integer dtypes for categorical features and float for labels.
|
| 408 |
+
logger.info("Loading activity table")
|
| 409 |
+
|
| 410 |
+
self.datasets = dict(zip(["train", "val", "test"],
|
| 411 |
+
[self._read_df_and_compute_sequence_lengths(self.split_dir / f"{split}.csv",
|
| 412 |
+
length_dict)
|
| 413 |
+
for split in ["train", "valid", "test"]]))
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
# initialise self.train_dataset, self.val_dataset, self.test_dataset
|
| 417 |
+
for ds in ['train', 'val', 'test']:
|
| 418 |
+
df_ds = self.datasets[ds]
|
| 419 |
+
assert len(ds) > 0, f"{ds=} is empty"
|
| 420 |
+
ds_proto = self.create_dataset(df_ds)
|
| 421 |
+
ds_proto._is_train = (ds == 'train')
|
| 422 |
+
self.torch_datasets[ds] = ds_proto
|
| 423 |
+
|
| 424 |
+
def create_dataset(self, df, **kwargs):
|
| 425 |
+
dataset_kwargs = self.molecular_dataloader.dataset_kwargs
|
| 426 |
+
dataset_class = DatasetMolecularEmbeddings
|
| 427 |
+
|
| 428 |
+
cluster_smi = None
|
| 429 |
+
sample_smiles = dataset_kwargs.get('sample_smiles', False)
|
| 430 |
+
if sample_smiles:
|
| 431 |
+
cluster_smi = df['cluster_smi'].values
|
| 432 |
+
|
| 433 |
+
logger.info(f"Creating dataset: using {dataset_class=}")
|
| 434 |
+
ds_proto = dataset_class(
|
| 435 |
+
sequence_embedding=(self.idx_to_sequence_data),
|
| 436 |
+
smiles_embeddings=self.molecular_dataloader.get_features(),
|
| 437 |
+
sequences=df['sequence'].values,
|
| 438 |
+
sequences_length=df['sequence_length'].values,
|
| 439 |
+
smiles=df['smiles'].values,
|
| 440 |
+
dtype=self.float_type,
|
| 441 |
+
cluster_smi=cluster_smi,
|
| 442 |
+
**dataset_kwargs,
|
| 443 |
+
**kwargs,
|
| 444 |
+
)
|
| 445 |
+
return ds_proto
|
| 446 |
+
|
| 447 |
+
def get_dataloader(self, dataset, shuffle, use_sampler=True, pin_memory=True):
|
| 448 |
+
if use_sampler:
|
| 449 |
+
sampler = RandomizedBatchSampler(sequence_length=dataset.sequences_length,
|
| 450 |
+
shuffle=shuffle,
|
| 451 |
+
batch_volume=self.batch_volume,
|
| 452 |
+
batch_size=self.max_size_batch)
|
| 453 |
+
return DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.collate_fn,
|
| 454 |
+
num_workers=self.num_workers, pin_memory=pin_memory)
|
| 455 |
+
else:
|
| 456 |
+
return DataLoader(dataset=dataset, collate_fn=dataset.collate_fn, batch_size=self.max_size_batch,
|
| 457 |
+
num_workers=self.num_workers, pin_memory=pin_memory, shuffle=shuffle)
|
| 458 |
+
|
| 459 |
+
def train_dataloader(self, use_sampler=True, shuffle=True):
|
| 460 |
+
return self.get_dataloader(self.torch_datasets['train'], shuffle=shuffle, use_sampler=use_sampler)
|
| 461 |
+
|
| 462 |
+
def val_dataloader(self, use_sampler=True, shuffle=False):
|
| 463 |
+
return self.get_dataloader(self.torch_datasets['val'], shuffle=shuffle, use_sampler=use_sampler)
|
| 464 |
+
|
| 465 |
+
def test_dataloader(self, use_sampler=True, shuffle=False):
|
| 466 |
+
return self.get_dataloader(self.torch_datasets['test'], shuffle=shuffle, use_sampler=use_sampler)
|
| 467 |
+
|
| 468 |
+
def predict_dataloader(self, dataset='test', use_sampler=False, shuffle=False):
|
| 469 |
+
return self.get_dataloader(self.torch_datasets[dataset], shuffle=shuffle, use_sampler=use_sampler)
|
| 470 |
+
|
| 471 |
+
def get_smiles_embedding_dim(self):
|
| 472 |
+
return self.molecular_dataloader.embedding_size
|
| 473 |
+
|
| 474 |
+
def get_sequence_embedding_dim(self):
|
| 475 |
+
return self.sequence_embedding_dim
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class DatasetNumpy(Dataset):
    """ Dataset for feeding model with sequences and ligands embeddings """

    def __init__(self, *, sequence_embedding: Tuple[np.array, np.array],
                 smiles_embeddings: np.ndarray,
                 sequences: np.ndarray,
                 sequences_length: np.ndarray,
                 smiles: np.ndarray,
                 dtype='float16',
                 **kwargs,
                 ):
        """
        Args:
            sequence_embedding: embedding for sequences - 1 per each sequence
            smiles_embeddings: embedding for smiles - 1 per each smile
            sequences: sequence label in the dataset - 1 per sample
            sequences_length: sequence length in the dataset - 1 per sample
            smiles: smile label in the dataset - 1 per sample
            dtype: name of a torch floating-point dtype (e.g. 'float16') used for collation

        Keyword Args:
            sample_smiles: when truthy, similar SMILES are grouped by cluster and
                __getitem__ samples a random member of the cluster instead of
                the stored id
            cluster_smi: per-sample SMILES cluster ids; used when sample_smiles is set
        """
        assert len(sequences) == len(sequences_length), f"{len(sequences)=} {len(sequences_length)=}"
        assert len(sequences) == len(smiles), f"{len(sequences)=} {len(smiles)=}"

        self.data_sequence = sequence_embedding
        # Subclass hook: may validate/adapt the SMILES feature container.
        self.smiles_embeddings = self.init_smiles_embeddings(smiles_embeddings)
        self.sequences_length = sequences_length
        self.sequences = sequences
        self.smiles = smiles
        self.float_type = getattr(torch, dtype)

        # Only support ESM embeddings (float type)
        self.sequence_dtype = self.float_type
        self._is_train = False  # this flag is assigned during the DataModule's setup()

        # SMILES SAMPLER: bind the id-lookup strategy once at construction time.
        sample_smiles = kwargs.get('sample_smiles', False)
        self.cluster_smiles = kwargs.get('cluster_smi', None)
        self.smiles_to_cluster = None
        if sample_smiles:
            self.group_smiles(self.cluster_smiles)
            self.get_smiles_id = self._smiles_id_sample
        else:
            self.get_smiles_id = self._smiles_id_as_ind

    def init_smiles_embeddings(self, smiles_embeddings):
        # Base implementation stores the container unchanged; subclasses override.
        return smiles_embeddings

    def group_smiles(self, clusters):
        """ for each sequence group similar smiles to list for random sampling during training """

        len_ = len(self.sequences)
        # One row per (cluster, sequence, length); the 'smiles' column becomes the
        # list of member ids to sample from at __getitem__ time.
        df = pd.DataFrame(data={'smiles': self.smiles, 'sequence': self.sequences, 'cluster_smi': clusters,
                                'sequences_length': self.sequences_length}
                          ).groupby(['cluster_smi', 'sequence', 'sequences_length'], as_index=False).agg(list)
        self.smiles_to_cluster = df['smiles'].values
        self.sequences = df['sequence'].values
        self.cluster_smiles = df['cluster_smi'].values
        self.sequences_length = df['sequences_length'].values
        logger.info(f"Sampling from similar smiles is ON, dataset size reduced from {len_} to {len(self.sequences)}")

    def _smiles_id_as_ind(self, idx: int) -> int:
        """ Get smiles id from array self.smiles """
        return self.smiles[idx]

    def _smiles_id_sample(self, idx) -> int:
        """ Sample smile id from grouped SMILES from same cluster"""
        return np.random.choice(self.smiles_to_cluster[idx])

    def __len__(self) -> int:
        # the number of entries in the dataset
        return len(self.sequences)

    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray, int]:
        # Resolve ids first: the SMILES id is either the stored one or a sampled
        # cluster member, depending on the strategy bound in __init__.
        seq_id = self.sequences[idx]
        smi_id = self.get_smiles_id(idx)

        return (self.parametrize_sequence(seq_id),
                self.parametrize_smiles(smi_id),
                self.sequences_length[idx])

    def parametrize_smiles(self, smiles_id: int) -> np.array:
        # Look up the pre-computed embedding for this SMILES id.
        return self.smiles_embeddings[smiles_id]

    def parametrize_sequence(self, sequence_id: int) -> np.array:
        # Look up the pre-computed (ESM) embedding for this sequence id.
        return self.data_sequence[sequence_id]

    @staticmethod
    def _collate_fn_pack(batch):
        """ Pack dataset samples to sequences of sequences, smiles, sequence_lengths """
        return zip(*batch)

    def _pad_sequence(self, sequences: List[np.ndarray]) -> torch.Tensor:
        # Right-pad variable-length protein embeddings with the module-wide mask value.
        return pad_sequence([torch.tensor(s, dtype=self.sequence_dtype) for s in sequences], batch_first=True,
                            padding_value=ProtobindDataModule.MASK_VALUE)

    def collate_fn(self, batch: Tuple[np.ndarray, np.ndarray, int ]) -> Tuple[
        Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Collates samples into a single batch, padding sequences to the same length.

        Args:
            batch : A tuple of samples, where each sample is the output of `__getitem__`.

        Returns:
            Tuple: A tuple containing batched tensors:
                - ((torch.Tensor, torch.Tensor)): A tuple of padded protein sequences
                  and a tensor of their original lengths.
                - (torch.Tensor): A batch of SMILES embeddings.
        """

        sequences, smiles, sequence_lengths = self._collate_fn_pack(batch)

        padded_sequences = self._pad_sequence(sequences)

        return ((padded_sequences, torch.tensor(sequence_lengths, dtype=torch.int32)),
                torch.tensor(np.array(smiles), dtype=self.float_type))
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
class DatasetMolecularEmbeddings(DatasetNumpy):
    """A dataset for masked diffusion models using protein embeddings and tokenized SMILES.

    This class extends `DatasetNumpy` to handle variable-length, tokenized SMILES
    representations from a `RandomizedSmilesDataset`. It overrides methods for
    SMILES parameterization and batch collation to support this token-based approach,
    which is required for diffusion models.
    """

    def parametrize_smiles(self, smiles_id: int) -> Tuple[np.array, int]:
        # Tokenized SMILES are variable-length, so the length is returned alongside.
        mol = self.smiles_embeddings[smiles_id]
        return mol, len(mol)

    def __getitem__(self, idx) -> Tuple[np.ndarray, np.array, int, int, int, int]:
        """Retrieves a single data sample with tokenized SMILES.

        Unlike the parent class, this method returns the tokenized SMILES
        and its length instead of a fixed-size embedding.
        """
        seq_id = self.sequences[idx]
        # NOTE(review): the stored SMILES id is used directly here, so the parent's
        # bound lookup strategy (self.get_smiles_id / cluster sampling) is bypassed
        # in this subclass — confirm this is intended.
        smi_id = self.smiles[idx]
        return (self.parametrize_sequence(seq_id),) + self.parametrize_smiles(smi_id) + (
            self.sequences_length[idx], seq_id, smi_id)

    def init_smiles_embeddings(self, smiles_embeddings):
        # This subclass only accepts the token-producing RandomizedSmilesDataset.
        if isinstance(smiles_embeddings, RandomizedSmilesDataset):
            return smiles_embeddings
        else:
            raise ValueError("version only supports RandomizedSmilesDataset")

    def collate_fn(self, batch: List[Tuple[np.ndarray, np.array, int, int, int, int]]) -> Tuple[
        Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
        torch.Tensor, torch.Tensor]:

        """Collates samples into a batch, padding both protein and SMILES sequences.

        Args:
            batch (list): A list of samples, where each sample is the output of __getitem__.

        Returns:
            Tuple: A tuple containing the final batched tensors for the model:
                - ((torch.Tensor, torch.Tensor)): Padded protein sequences and their lengths.
                - ((torch.Tensor, torch.Tensor)): Padded tokenized SMILES and their lengths.
                - (torch.Tensor): A batch of sequence IDs.
                - (torch.Tensor): A batch of SMILES IDs.
        """

        sequences, atom, atom_lengths, sequence_lengths, seq_id, smi_id \
            = self._collate_fn_pack(batch)

        padded_sequences = self._pad_sequence(sequences)  # padding proteins sequences
        # Token sequences are zero-padded (pad_sequence's default padding_value).
        padded_atom = pad_sequence([s.to(dtype=self.float_type) for s in atom], batch_first=True)
        atom_lengths = torch.tensor(atom_lengths, dtype=torch.int32)

        return ((padded_sequences, torch.tensor(sequence_lengths, dtype=torch.int32)),
                (padded_atom, atom_lengths),
                torch.tensor(seq_id, dtype=torch.int32),
                torch.tensor(smi_id, dtype=torch.int32),
                )
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
class MolecularDataloaderSMILES(object):
    """
    molecular dataloader that only supports tokenized SMILES
    with ChemformerTokenizer for masked diffusion models.
    """

    def __init__(self, *,
                 data_dir: Path,
                 dataset_options: Optional[dict] = None):
        """
        Args:
            data_dir: path to data folder containing tokenizer files and dict with all smiles and fasta sequences
            dataset_options: dictionary with additional parameters used to create pytorch Dataset;
                'tokenizer_json_name' defaults to 'tokenizer_smiles_diffusion' when absent
        """
        self.data_dir = data_dir
        if dataset_options is None:
            dataset_options = {}
        # BUGFIX: previously a partial options dict (e.g. {'randomize': True})
        # raised KeyError because the tokenizer name was only defaulted when
        # dataset_options was None. Default it whenever the key is missing.
        if 'tokenizer_json_name' not in dataset_options:
            logger.info('Setting tokenizer file name to tokenizer_smiles_diffusion.json')
            dataset_options = {**dataset_options, 'tokenizer_json_name': 'tokenizer_smiles_diffusion'}
        self.dataset_options = dataset_options

        self.tokenizer_path = self.data_dir / f"{dataset_options['tokenizer_json_name']}.json"
        self.tokenizer = ChemformerTokenizer(filename=str(self.tokenizer_path))
        self.randomize = dataset_options.get('randomize', False)
        self.smiles_embedding_dim = 1  # For tokenized SMILES, embedding dim is 1
        self.baseline_dim = 0  # this version doesn't support baseline features
        self.smiles_dataset = None  # populated by load_molecular_features()

    def prepare_molecular_features(self):
        """Prepare molecular features: verify the tokenizer file exists."""
        if not self.tokenizer_path.exists():
            raise FileNotFoundError(
                f"Could not find tokenizer at {self.tokenizer_path}. Please ensure the tokenizer file exists.")
        logger.info(f"Found ChemformerTokenizer at {self.tokenizer_path}")

    def load_molecular_features(self):
        """Load molecular features - loads SMILES mappings"""
        categorical_mappings_path = self.data_dir / 'categorical_mappings.json'
        if not categorical_mappings_path.exists():
            raise FileNotFoundError(f"categorical_mappings.json not found in data_dir: {self.data_dir}")

        self.smiles_dataset = RandomizedSmilesDataset.from_json(
            categorical_mappings_path,
            tokenizer=self.tokenizer,
            randomize=self.randomize
        )

    def get_features(self):
        """Get the SMILES dataset for tokenized molecular features.

        Raises:
            RuntimeError: if called before load_molecular_features(); previously
                this surfaced as an uninformative AttributeError.
        """
        if self.smiles_dataset is None:
            raise RuntimeError("load_molecular_features() must be called before get_features()")
        return self.smiles_dataset

    @property
    def dataset_kwargs(self):
        """Return dataset options for creating pytorch datasets"""
        return self.dataset_options

    @property
    def embedding_size(self):
        """Get embedding size for tokenized SMILES"""
        return self.smiles_embedding_dim
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
class InferenceDataset(Dataset):
    """Yields identical, pre-batched samples built from a single protein embedding.

    Intended for generating many ligand samples for one protein target: every
    item is the same batch, formed by expanding one stored embedding
    ``batch_size`` times, so no conventional dataset structure is needed.
    """

    def __init__(self, embedding: torch.Tensor, batch_size: int, n_batches: int):
        """Initializes the inference dataset.

        Args:
            embedding (torch.Tensor): The single protein embedding; dim 0 must be
                expandable to batch_size (typically shape (1, seq_len, dim)).
            batch_size (int): The number of times to repeat the embedding in each batch.
            n_batches (int): The total number of identical batches to yield.
        """
        self.embedding_single = embedding
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.seq_len = embedding.shape[1]

    def __len__(self) -> int:
        return self.n_batches

    def __getitem__(self, idx: int) -> Tuple:
        """Return one pre-batched tuple; `idx` is deliberately ignored.

        The tuple mirrors the training collate output:
        ((embeddings, lengths), SMILES placeholder, sequence ids, SMILES-id
        placeholder). NaN scalars stand in for the SMILES fields and -1 marks
        sequence ids that do not exist for new sequences.
        """
        batch_embeddings = self.embedding_single.expand(self.batch_size, -1, -1).contiguous()
        batch_lengths = torch.full((self.batch_size,), self.seq_len, dtype=torch.int32)
        placeholder_ids = torch.full((self.batch_size,), -1, dtype=torch.int32)
        return (
            (batch_embeddings, batch_lengths),
            torch.tensor(float('nan')),
            placeholder_ids,
            torch.tensor(float('nan')),
        )
|
protobind_diff/decoder_rope.py
ADDED
|
@@ -0,0 +1,769 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from math import pi
|
| 3 |
+
import typing
|
| 4 |
+
from typing import Tuple, Optional, Literal
|
| 5 |
+
|
| 6 |
+
from einops import rearrange, repeat
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.amp import autocast
|
| 12 |
+
from torch.nn import Module, ModuleList
|
| 13 |
+
from torch import nn, einsum, broadcast_tensors, Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
#################################################################################
|
| 18 |
+
# Rotary Encoding #
|
| 19 |
+
#################################################################################
|
| 20 |
+
|
| 21 |
+
# helper functions
|
| 22 |
+
|
| 23 |
+
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def default(val, d):
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
|
| 29 |
+
|
| 30 |
+
def slice_at_dim(t, dim_slice: slice, *, dim):
    """Apply *dim_slice* along dimension *dim* of tensor *t* (negative dims allowed)."""
    if dim < 0:
        dim = t.ndim + dim
    # Full slices everywhere except the target dimension.
    index = [slice(None) for _ in range(t.ndim)]
    index[dim] = dim_slice
    return t[tuple(index)]
|
| 35 |
+
|
| 36 |
+
# rotary embedding helper functions
|
| 37 |
+
|
| 38 |
+
def rotate_half(x):
    """Pairwise rotation along the last dim: each adjacent pair (a, b) becomes (-b, a)."""
    # View the last dimension as (d, 2) pairs without copying semantics changes.
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    # Re-interleave back to the original shape.
    return rotated.reshape(*x.shape)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@autocast('cuda', enabled=False)
def apply_rotary_emb(
        freqs,
        t,
        start_index=0,
        scale=1.,
        seq_dim=-2,
        freqs_seq_dim=None
):
    """Applies rotary positional embeddings to a given tensor.

    Args:
        freqs (torch.Tensor): The rotary frequencies.
        t (torch.Tensor): The tensor to apply embeddings to (e.g., queries or keys).
        start_index (int): The feature dimension index to start applying rotations from.
        scale (float): A scaling factor, used for xPos.
        seq_dim (int): The sequence dimension of the input tensor `t`.
        freqs_seq_dim (Optional[int]): The sequence dimension of the freqs tensor.

    Returns:
        torch.Tensor: `t` with rotations applied to features
        [start_index, start_index + rot_dim); other features pass through unchanged.
    """
    dtype = t.dtype

    # Infer where the sequence dimension lives in `freqs` when not given:
    # 2-D freqs (seq, dim) or a 3-D (headless) `t` imply it is dimension 0.
    if not exists(freqs_seq_dim):
        if freqs.ndim == 2 or t.ndim == 3:
            freqs_seq_dim = 0

    # Align freqs with the *last* seq_len positions of `t`, which supports
    # offset/cached decoding where `t` is a suffix of the full sequence.
    if t.ndim == 3 or exists(freqs_seq_dim):
        seq_len = t.shape[seq_dim]
        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert rot_dim <= t.shape[
        -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'

    # Split t into three parts: left, middle (to be transformed), and right
    t_left = t[..., :start_index]
    t_middle = t[..., start_index:end_index]
    t_right = t[..., end_index:]

    # Apply rotary embeddings without modifying t in place
    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)

    out = torch.cat((t_left, t_transformed, t_right), dim=-1)

    # Cast back to the input dtype: the math above ran with autocast disabled.
    return out.type(dtype)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# learned rotation helpers
|
| 95 |
+
|
| 96 |
+
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """Apply learned per-position rotation angles to `t` via rotary embedding.

    Args:
        rotations: learned rotation angles.
        t: tensor to rotate (queries or keys).
        start_index: feature index at which rotation begins.
        freq_ranges: optional per-frequency multipliers; when given, each angle
            is outer-multiplied across the frequency axis before use.
    """
    if exists(freq_ranges):
        # Outer product of each angle with the frequency range, then flatten
        # the (rotation, frequency) axes into one feature axis.
        rotations = einsum('..., f -> ... f', rotations, freq_ranges)
        rotations = rearrange(rotations, '... r f -> ... (r f)')

    # Duplicate each angle so it covers the interleaved (cos, sin) feature pairs.
    rotations = repeat(rotations, '... n -> ... (n r)', r=2)
    return apply_rotary_emb(rotations, t, start_index=start_index)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# classes
|
| 106 |
+
|
| 107 |
+
class RotaryEmbedding(Module):
|
| 108 |
+
"""
|
| 109 |
+
original paper: https://arxiv.org/abs/2104.09864
|
| 110 |
+
rescale rotary embeddings to longer sequence length without fine-tuning
|
| 111 |
+
code source: https://github.com/lucidrains/rotary-embedding-torch
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def __init__(
|
| 115 |
+
self,
|
| 116 |
+
dim,
|
| 117 |
+
custom_freqs: Tensor | None = None,
|
| 118 |
+
freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
|
| 119 |
+
theta=10000,
|
| 120 |
+
max_freq=10,
|
| 121 |
+
num_freqs=1,
|
| 122 |
+
learned_freq=False,
|
| 123 |
+
use_xpos=False,
|
| 124 |
+
xpos_scale_base=512,
|
| 125 |
+
interpolate_factor=1.,
|
| 126 |
+
theta_rescale_factor=1.,
|
| 127 |
+
seq_before_head_dim=False,
|
| 128 |
+
cache_if_possible=True,
|
| 129 |
+
cache_max_seq_len=8192
|
| 130 |
+
):
|
| 131 |
+
super().__init__()
|
| 132 |
+
"""Initializes the RotaryEmbedding module.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
dim (int): The feature dimension to apply rotary embeddings to.
|
| 136 |
+
custom_freqs ([Tensor]): An optional tensor of custom frequencies.
|
| 137 |
+
freqs_for : The method for generating
|
| 138 |
+
frequencies. 'lang' is standard for transformers.
|
| 139 |
+
theta (int): A core hyperparameter for frequency calculation.
|
| 140 |
+
learned_freq (bool): If True, the frequencies are trainable parameters.
|
| 141 |
+
use_xpos (bool): If True, enables the xPos (extrapolatable) variant.
|
| 142 |
+
interpolate_factor (float): A factor for positional interpolation, which
|
| 143 |
+
can help with length generalization.
|
| 144 |
+
cache_if_possible (bool): If True, caches calculated frequencies for efficiency.
|
| 145 |
+
"""
|
| 146 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
| 147 |
+
|
| 148 |
+
self.freqs_for = freqs_for
|
| 149 |
+
|
| 150 |
+
if exists(custom_freqs):
|
| 151 |
+
freqs = custom_freqs
|
| 152 |
+
elif freqs_for == 'lang':
|
| 153 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 154 |
+
elif freqs_for == 'pixel':
|
| 155 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
| 156 |
+
elif freqs_for == 'constant':
|
| 157 |
+
freqs = torch.ones(num_freqs).float()
|
| 158 |
+
|
| 159 |
+
self.cache_if_possible = cache_if_possible
|
| 160 |
+
self.cache_max_seq_len = cache_max_seq_len
|
| 161 |
+
|
| 162 |
+
self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent=False)
|
| 163 |
+
self.cached_freqs_seq_len = 0
|
| 164 |
+
|
| 165 |
+
self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
|
| 166 |
+
|
| 167 |
+
self.learned_freq = learned_freq
|
| 168 |
+
|
| 169 |
+
# dummy for device
|
| 170 |
+
|
| 171 |
+
self.register_buffer('dummy', torch.tensor(0), persistent=False)
|
| 172 |
+
|
| 173 |
+
# default sequence dimension
|
| 174 |
+
|
| 175 |
+
self.seq_before_head_dim = seq_before_head_dim
|
| 176 |
+
self.default_seq_dim = -3 if seq_before_head_dim else -2
|
| 177 |
+
|
| 178 |
+
# interpolation factors
|
| 179 |
+
|
| 180 |
+
assert interpolate_factor >= 1.
|
| 181 |
+
self.interpolate_factor = interpolate_factor
|
| 182 |
+
|
| 183 |
+
# xpos
|
| 184 |
+
|
| 185 |
+
self.use_xpos = use_xpos
|
| 186 |
+
|
| 187 |
+
if not use_xpos:
|
| 188 |
+
return
|
| 189 |
+
|
| 190 |
+
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
| 191 |
+
self.scale_base = xpos_scale_base
|
| 192 |
+
|
| 193 |
+
self.register_buffer('scale', scale, persistent=False)
|
| 194 |
+
self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent=False)
|
| 195 |
+
self.cached_scales_seq_len = 0
|
| 196 |
+
|
| 197 |
+
# add apply_rotary_emb as static method
|
| 198 |
+
|
| 199 |
+
self.apply_rotary_emb = staticmethod(apply_rotary_emb)
|
| 200 |
+
|
| 201 |
+
    @property
    def device(self):
        # The module's device, derived from the non-persistent `dummy` buffer,
        # which follows `.to(device)` moves automatically.
        return self.dummy.device
|
| 204 |
+
|
| 205 |
+
    def get_seq_pos(self, seq_len, device, dtype, offset=0):
        """Return position indices 0..seq_len-1 (shifted by `offset`).

        Positions are divided by `interpolate_factor` to support position
        interpolation for sequences longer than the training length.
        """
        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
|
| 207 |
+
|
| 208 |
+
    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
        """Applies rotary embeddings to a single tensor (queries or keys).

        Args:
            t (torch.Tensor): The input tensor (queries or keys).
            seq_dim: The sequence dimension of the tensor.
            offset (int): An offset for the position sequence, used for caching.
            scale (Optional[float]): A scaling factor, required if using xPos.

        Returns:
            torch.Tensor: The tensor with rotary embeddings applied.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        # With xPos enabled, q and k must be rotated together (with inverse
        # scales) — a lone call is only valid when an explicit scale is given.
        assert not self.use_xpos or exists(
            scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)

        freqs = self.forward(seq, seq_len=seq_len, offset=offset)

        # Layout (..., seq, heads, dim): add a singleton head axis so the
        # per-position frequencies broadcast across heads.
        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')

        return apply_rotary_emb(freqs, t, scale=default(scale, 1.), seq_dim=seq_dim)
|
| 235 |
+
|
| 236 |
+
    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        """Rotates queries and keys when keys may be longer (e.g. KV-cache decoding).

        The queries are assumed to be the trailing `q_len` positions of the key
        sequence, so their rotation offset starts at `k_len - q_len`.
        """
        dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        q_scale = k_scale = 1.

        if self.use_xpos:
            seq = self.get_seq_pos(k_len, dtype=dtype, device=device)

            # Queries use the scales of the last q_len key positions; keys are
            # later rotated with the inverse scale (xPos convention).
            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
            k_scale = self.get_scale(seq).type(dtype)

        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale ** -1)

        # Cast back in case rotation was computed in a higher precision.
        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k
|
| 257 |
+
|
| 258 |
+
    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        """Rotates same-length queries and keys together with xPos scaling.

        Only valid when `use_xpos` is enabled: queries are scaled by `scale`
        and keys by `scale ** -1` for length-extrapolatable attention.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        # Layout (..., seq, heads, dim): broadcast freqs/scales across heads.
        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')
            scale = rearrange(scale, 'n d -> n 1 d')

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale ** -1, seq_dim=seq_dim)

        # Cast back in case rotation was computed in a higher precision.
        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k
|
| 280 |
+
|
| 281 |
+
    def get_scale(
        self,
        t: Tensor,
        seq_len = None,
        offset=0
    ):
        """Computes xPos scales for positions `t`, with optional caching.

        Args:
            t (torch.Tensor): 1-D tensor of position indices.
            seq_len: Total sequence length; required to enable caching.
            offset (int): Starting position offset into the cache.

        Returns:
            torch.Tensor: Per-position scales (shape (len(t), dim)), or the
            cached slice when available.
        """
        assert self.use_xpos

        # Only cache when the requested window fits the preallocated buffer.
        should_cache = (
            self.cache_if_possible and
            exists(seq_len) and
            (offset + seq_len) <= self.cache_max_seq_len
        )

        # Serve from cache when the window was already computed.
        if (
            should_cache and \
            exists(self.cached_scales) and \
            (seq_len + offset) <= self.cached_scales_seq_len
        ):
            return self.cached_scales[offset:(offset + seq_len)]

        scale = 1.
        # NOTE: always true here (asserted above); kept for structural symmetry.
        if self.use_xpos:
            # Power is centered on the middle of the sequence.
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, 'n -> n 1')
            # Duplicate each scale so channel pairs share a value, matching the
            # (n r) interleaved frequency layout with r=2.
            scale = repeat(scale, 'n d -> n (d r)', r=2)

        # Cache is only filled for offset 0 so slices stay position-aligned.
        if should_cache and offset == 0:
            self.cached_scales[:seq_len] = scale.detach()
            self.cached_scales_seq_len = seq_len

        return scale
|
| 313 |
+
|
| 314 |
+
    def get_axial_freqs(self, *dims):
        """Builds rotary frequencies for multi-axis (axial) positions.

        Each axis in `dims` gets its own frequency grid; the grids are
        broadcast against each other and concatenated on the last dimension.
        """
        Colon = slice(None)
        all_freqs = []

        for ind, dim in enumerate(dims):
            if self.freqs_for == 'pixel':
                # Pixel mode uses continuous coordinates normalized to [-1, 1].
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            freqs = self.forward(pos, seq_len=dim)

            # Keep this axis's frequencies along its own broadcast dimension;
            # every other axis gets a singleton (None) dimension.
            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)
|
| 334 |
+
|
| 335 |
+
    # Rotary angles must be computed in full precision; disable autocast.
    @autocast('cuda', enabled=False)
    def forward(
        self,
        t: Tensor,
        seq_len = None,
        offset=0
    ):
        """Calculates the rotary frequencies for a given sequence of positions.

        Args:
            t (torch.Tensor): A tensor of position indices.
            seq_len (int): The total sequence length, used for caching.
            offset (int): The starting position offset.

        Returns:
            torch.Tensor: A tensor of calculated rotation frequencies.
        """
        # Caching is skipped for learned frequencies (they change during
        # training) and for 'pixel' mode (continuous, non-integer positions).
        should_cache = (
            self.cache_if_possible and
            not self.learned_freq and
            exists(seq_len) and
            self.freqs_for != 'pixel' and
            (offset + seq_len) <= self.cache_max_seq_len
        )

        # Serve from cache when the requested window was already computed.
        if (
            should_cache and \
            exists(self.cached_freqs) and \
            (offset + seq_len) <= self.cached_freqs_seq_len
        ):
            return self.cached_freqs[offset:(offset + seq_len)].detach()

        freqs = self.freqs

        # Outer product of positions and base frequencies, then duplicate each
        # frequency so adjacent channel pairs share one rotation angle.
        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)

        # Cache is only filled for offset 0 so slices stay position-aligned.
        if should_cache and offset == 0:
            self.cached_freqs[:seq_len] = freqs.detach()
            self.cached_freqs_seq_len = seq_len

        return freqs
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
#################################################################################
|
| 380 |
+
# Multi Head Attention #
|
| 381 |
+
#################################################################################
|
| 382 |
+
|
| 383 |
+
class LayerNorm(nn.Module):
    """Layer normalization over the last (feature) dimension.

    Each feature vector is normalized to zero mean and unit (biased) variance,
    then transformed by a learned affine map (gamma, beta).
    """

    def __init__(self, d_model, eps=1e-12):
        """Initializes the LayerNorm module.

        Args:
            d_model (int): Size of the feature dimension.
            eps (float): Stability constant added to the variance.
        """
        super(LayerNorm, self).__init__()
        # Learnable scale/shift, initialized to the identity transform.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalizes `x` along its last dimension and applies gamma/beta.

        Args:
            x (torch.Tensor): Input of shape (..., d_model).

        Returns:
            torch.Tensor: Normalized tensor of the same shape.
        """
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, the conventional LayerNorm choice.
        sigma_sq = x.var(dim=-1, unbiased=False, keepdim=True)
        normalized = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return self.gamma * normalized + self.beta
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class PositionwiseFeedForward(nn.Module):
    """Transformer position-wise feed-forward block.

    Computes Linear -> ReLU -> Dropout -> Linear independently per position.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        """Initializes the PositionwiseFeedForward module.

        Args:
            d_model (int): Input and output feature dimension.
            hidden (int): Dimension of the inner (expanded) layer.
            drop_prob (float): Dropout probability applied after the activation.
        """
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Applies the feed-forward transformation.

        Args:
            x (torch.Tensor): Input of shape (..., d_model).

        Returns:
            torch.Tensor: Output of shape (..., d_model).
        """
        expanded = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(expanded)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class ScaleDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        """Performs the Scaled Dot-Product Attention calculation.

        Args:
            q (torch.Tensor): Queries of shape (batch, head, q_len, d).
            k (torch.Tensor): Keys of shape (batch, head, k_len, d).
            v (torch.Tensor): Values of shape (batch, head, k_len, d_v).
            mask (torch.Tensor, optional): Positions where mask == 0 are
                excluded from attention. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The attention output and the
            attention weights.
        """
        # Per-head feature size, used for the 1/sqrt(d) scaling.
        d_per_head = k.size(-1)
        scores = (q @ k.transpose(2, 3)) / math.sqrt(d_per_head)
        if mask is not None:
            # Large negative fill ~ -inf so masked positions get ~0 weight.
            scores = scores.masked_fill(mask == 0, -10000)
        weights = self.softmax(scores)
        return weights @ v, weights
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional Rotary Position Embeddings (RoPE)."""

    def __init__(self, d_model, n_head):
        """Initializes the MultiHeadAttention layer.

        Args:
            d_model (int): Total model dimension; must be divisible by n_head.
            n_head (int): Number of attention heads.
        """
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

        # One rotary embedding shared by all heads (per-head dimension).
        self.rotary_emb = RotaryEmbedding(dim=d_model // n_head)

    def forward(self, q, k, v, mask=None, apply_rotary=False):
        """Computes multi-head attention.

        Args:
            q (torch.Tensor): Query tensor of shape (batch, len, d_model).
            k (torch.Tensor): Key tensor of shape (batch, len, d_model).
            v (torch.Tensor): Value tensor of shape (batch, len, d_model).
            mask (torch.Tensor, optional): Attention mask. Defaults to None.
            apply_rotary (bool): If True, applies RoPE
                (https://arxiv.org/abs/2104.09864) to Q and K before attention.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The output tensor and the
            attention weights.
        """
        q = self.split(self.w_q(q))
        k = self.split(self.w_k(k))
        v = self.split(self.w_v(v))

        if apply_rotary:
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        context, weights = self.attention(q, k, v, mask=mask)
        return self.w_concat(self.concat(context)), weights

    def split(self, tensor):
        """Reshapes (batch, len, d_model) into (batch, n_head, len, d_head)."""
        batch, length, d_model = tensor.size()
        d_head = d_model // self.n_head
        return tensor.view(batch, length, self.n_head, d_head).transpose(1, 2)

    def concat(self, tensor):
        """Inverse of split: (batch, n_head, len, d_head) -> (batch, len, d_model)."""
        batch, heads, length, d_head = tensor.size()
        return tensor.transpose(1, 2).contiguous().view(batch, length, heads * d_head)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
#################################################################################
|
| 541 |
+
# Embedding Layers #
|
| 542 |
+
#################################################################################
|
| 543 |
+
|
| 544 |
+
class EmbeddingLayer(nn.Module):
    """Lookup-table embedding with Kaiming-uniform initialized weights."""

    def __init__(self, dim, vocab_dim):
        """
        Args:
            dim (int): Embedding dimension.
            vocab_dim (int): Vocabulary size (number of rows).
        """
        super().__init__()
        self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
        # Same init scheme nn.Linear uses for its weight matrix.
        torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

    def forward(self, x):
        """Looks up embedding vectors by index.

        Args:
            x (torch.Tensor): Integer index tensor.

        Returns:
            torch.Tensor: Embeddings of shape x.shape + (dim,).
        """
        return self.embedding[x]
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps via sinusoidal features plus an MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        """Initializes the TimestepEmbedder.

        Args:
            hidden_size (int): Final dimension of the timestep embedding.
            frequency_embedding_size (int): Width of the sinusoidal feature
                vector fed into the MLP.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True))
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder from 1 down to 1/max_period.
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Zero-pad one column so the output width equals `dim` exactly.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Maps timesteps `t` (shape (N,)) to embeddings of shape (N, hidden_size)."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
#################################################################################
|
| 612 |
+
# Decoder #
|
| 613 |
+
#################################################################################
|
| 614 |
+
|
| 615 |
+
class DecoderLayer(nn.Module):
    """One Transformer decoder layer: RoPE self-attention, cross-attention, FFN.

    code source: https://github.com/hyunwoongko/transformer
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        """Initializes the DecoderLayer.

        Args:
            d_model (int): Model dimension.
            ffn_hidden (int): Hidden dimension of the feed-forward network.
            n_head (int): Number of attention heads.
            drop_prob (float): Dropout probability.
        """
        super(DecoderLayer, self).__init__()

        self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNorm(d_model=d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, dec, enc, trg_mask, src_mask, return_attention=False):
        """Runs one decoder layer (post-norm residual blocks).

        Args:
            dec (torch.Tensor): Input from the previous decoder layer.
            enc (torch.Tensor): Encoder output for cross-attention conditioning;
                the cross-attention sub-block is skipped when None.
            trg_mask (torch.Tensor): Mask for the self-attention.
            src_mask (torch.Tensor): Mask for the cross-attention.
            return_attention (bool): If True, also returns the cross-attention
                weights; otherwise the second return value is None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (output, cross-attention or None).
        """
        attention = None

        # Self-attention with rotary embeddings, residual + post-norm.
        residual = dec
        out, _ = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask, apply_rotary=True)
        out = self.norm1(self.dropout1(out) + residual)

        # Cross-attention over the encoder context, residual + post-norm.
        if enc is not None:
            residual = out
            if return_attention:
                out, attention = self.enc_dec_attention(q=out, k=enc, v=enc, mask=src_mask)
            else:
                out, _ = self.enc_dec_attention(q=out, k=enc, v=enc, mask=src_mask)
            out = self.norm2(self.dropout2(out) + residual)

        # Position-wise feed-forward, residual + post-norm.
        residual = out
        out = self.norm3(self.dropout3(self.ffn(out)) + residual)
        return out, attention
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
class Decoder_RoPE(nn.Module):
    """A decoder that uses Rotary Position Embeddings (RoPE).

    This model is designed for a diffusion task, taking a ligand sequence, a
    conditioning protein sequence, and a diffusion timestep (sigma) as input
    to predict the output logits for the ligand.
    """
    def __init__(self,
                 vocab_size,
                 seq_emb_dim,
                 hidden_size: int=640,
                 nhead: int=8,
                 n_layers: int=4,
                 expand_feedforward: int=3,
                 dropout: float=0.1):

        """Args:
            vocab_size (int): The size of the output vocabulary (e.g., ligand tokens).
            seq_emb_dim (int): The dimension of the input sequence embeddings.
            hidden_size (int): The main hidden dimension of the Transformer model.
            nhead (int): The number of attention heads in each DecoderLayer.
            n_layers (int): The number of DecoderLayers to stack.
            expand_feedforward (int): The expansion factor for the feed-forward
                network's hidden layer.
            dropout (float): The dropout probability.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.vocab_embed = EmbeddingLayer(self.hidden_size, vocab_size)
        self.linear = nn.Linear(self.hidden_size, vocab_size)
        # Only project the protein embeddings when their width differs from
        # the model's hidden size.
        self.apply_seq_linear = False

        if seq_emb_dim != self.hidden_size:
            self.apply_seq_linear = True
            self.linear_seq = nn.Linear(seq_emb_dim, self.hidden_size)

        self.sigma_map = TimestepEmbedder(self.hidden_size)

        self.layers = nn.ModuleList([DecoderLayer(d_model=self.hidden_size,
                                                  ffn_hidden=self.hidden_size * expand_feedforward,
                                                  n_head=nhead,
                                                  drop_prob=dropout)
                                     for _ in range(n_layers)])

    def forward(self,
                ligand: torch.Tensor,
                sigma: torch.Tensor,
                sequence: torch.Tensor,
                sequence_lengths: torch.Tensor,
                lig_padding_mask: Optional[torch.Tensor]=None,
                return_attention: bool=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Performs the forward pass of the decoder.

        It processes the ligand sequence conditioned on the protein sequence and the
        diffusion timestep (sigma). The sigma embedding is prepended to the protein
        sequence to form a single conditioning context.

        Args:
            ligand (torch.Tensor): A batch of ligand token ID tensors.
            sigma (torch.Tensor): A batch of scalar diffusion timesteps.
            sequence (torch.Tensor): A batch of conditioning protein sequence embeddings.
            sequence_lengths (torch.Tensor): The original lengths of the protein sequences.
            lig_padding_mask (Optional[torch.Tensor]): A padding mask for the ligand.
            return_attention (bool): If True, returns the cross-attention weights
                from the last decoder layer.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple of (output_logits, attention_weights).
        """
        ligand = self.vocab_embed(ligand)
        # Sigma embedding becomes a single prepended "token" on the condition.
        sigma = F.silu(self.sigma_map(sigma)).unsqueeze(1)
        if self.apply_seq_linear:
            sequence = self.linear_seq(sequence)
        condition = torch.cat([sigma, sequence], dim=1)
        # BUGFIX: `sequence_lengths += 1` mutated the caller's tensor in place
        # (augmented assignment on a tensor is in-place), so repeated forward
        # calls with the same lengths tensor kept growing it. Use an
        # out-of-place add to account for the prepended sigma token instead.
        sequence_lengths = sequence_lengths + 1

        # Valid-position mask over [sigma ; protein] for cross-attention,
        # broadcastable to (batch, head, q_len, k_len).
        range_tensor = torch.arange(condition.shape[1], device=sequence.device).unsqueeze(0)
        condition_mask = range_tensor < sequence_lengths.unsqueeze(1)
        condition_mask = condition_mask.unsqueeze(1).unsqueeze(2)
        if lig_padding_mask is not None:
            lig_padding_mask = lig_padding_mask.unsqueeze(1).unsqueeze(2)

        # Only the last layer's `attention` survives the loop, so the returned
        # weights come from the final decoder layer.
        for layer in self.layers:
            ligand, attention = layer(ligand, condition,
                                      trg_mask=lig_padding_mask, src_mask=condition_mask,
                                      return_attention=return_attention)

        output = self.linear(ligand)
        return output, attention
|
protobind_diff/esm_inference.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse, sys
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import esm
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import re
|
| 9 |
+
from Bio import SeqIO
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
import lightning.pytorch as pl
|
| 12 |
+
from protobind_diff.model import ModelGenerator
|
| 13 |
+
from protobind_diff.data_loader import InferenceDataset
|
| 14 |
+
from huggingface_hub import hf_hub_download
|
| 15 |
+
|
| 16 |
+
REPO_ID = "ai-gero/ProtoBind-Diff"
|
| 17 |
+
FILENAME = "model.ckpt"
|
| 18 |
+
TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
|
| 19 |
+
|
| 20 |
+
class ProtobindInference():
    """
    Simplified inference class that only supports ProtobindMaskedDiffusion model.
    """

    def __init__(self, checkpoint_path, tokenizer_path,
                 sequence_embedding_dim, lig_max_length: int=170, nucleus_p: float=0.9,
                 eta: float=0.1, sampling_steps: int=250,
                 **kwargs):
        """
        Args:
            checkpoint_path: Path to the trained model checkpoint.
            tokenizer_path: Path to the SMILES tokenizer JSON file.
            sequence_embedding_dim: Width of the protein sequence embeddings.
            lig_max_length: Maximum length of generated ligands.
            nucleus_p: Nucleus-sampling probability threshold.
            eta: Remasking probability during sampling.
            sampling_steps: Number of diffusion sampling steps.
        """
        self.checkpoint_path = Path(checkpoint_path)
        self.tokenizer_path = Path(tokenizer_path)
        self.sequence_embedding_dim = sequence_embedding_dim

        # Sampler hyperparameters, copied onto the model in load_model().
        self.lig_max_length = lig_max_length
        self.nucleus_p = nucleus_p
        self.eta = eta
        self.sampling_steps = sampling_steps

        self.model = self.load_model()

    def predict_on_dataloader(self, dl, devices=1, accelerator='cuda') -> Tuple[np.ndarray, np.ndarray]:
        """Runs batched prediction via a Lightning Trainer; returns per-batch outputs."""
        if accelerator == 'cuda':
            # Mixed precision on GPU for speed; full fp32 everywhere else.
            torch.set_float32_matmul_precision('medium')
            precision = "16-mixed"
        else:
            precision = "32-true"
        trainer = pl.Trainer(precision=precision, use_distributed_sampler=False,
                             inference_mode=True, accelerator=accelerator, devices=devices)
        return trainer.predict(model=self.model, dataloaders=dl)

    def load_model(self):
        """Simplified model loading - only supports ModelGenerator"""
        model = ModelGenerator.load_from_checkpoint(
            self.checkpoint_path,
            tokenizer_path=self.tokenizer_path,
            seq_embedding_dim=self.sequence_embedding_dim,
            load=True,
        )
        # Override the sampling hyperparameters stored in the checkpoint.
        model.model_length = self.lig_max_length
        model.nucleus_p = self.nucleus_p
        model.eta = self.eta
        model.sampling_steps = self.sampling_steps
        model.model.eval()
        return model
|
| 67 |
+
|
| 68 |
+
def get_esm_embedding(sequence: str, model_name: str, device: torch.device) -> torch.Tensor:
    """Generates a protein embedding using a pre-trained ESM model.

    Args:
        sequence (str): The amino acid sequence.
        model_name (str): The name of the ESM model to use. Must contain the
            layer count in the standard ESM naming scheme (a '_t<N>_' segment,
            e.g. 'esm2_t33_650M_UR50D').
        device (torch.device): The device to run the model on.

    Returns:
        torch.Tensor: The final residue-level embedding tensor of shape
        [1, seq_len, emb_dim], with start/end tokens removed.

    Raises:
        ValueError: If the layer count cannot be parsed from `model_name`.
    """
    model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
    model.eval()
    # BUGFIX: `re.search` returns None when the pattern is absent, which
    # previously crashed with an opaque AttributeError on `.group`. Fail
    # loudly with a descriptive error instead.
    match = re.search(r'_t(\d+)_', model_name)
    if match is None:
        raise ValueError(
            f"Cannot infer the number of layers from model name {model_name!r}; "
            "expected a '_t<N>_' segment (e.g. 'esm2_t33_650M_UR50D')."
        )
    number_layers = int(match.group(1))

    model = model.to(device)
    batch_converter = alphabet.get_batch_converter()
    _, _, tokens = batch_converter([("protein", sequence)])
    tokens = tokens.to(device)
    with torch.no_grad():
        out = model(tokens, repr_layers=[number_layers])
    # Drop the BOS/EOS positions so the output aligns 1:1 with residues.
    return out["representations"][number_layers][:, 1:-1, :]  # [1, seq_len, emb_dim]
|
| 91 |
+
|
| 92 |
+
def download_from_hub_hf(cache: Path, filename) -> Path:
    """
    Fetch file from Hugging Face into `cache`.
    Returns the local path to the file inside HF’s cache structure.
    """
    # Make sure the cache directory exists before HF writes into it.
    cache.mkdir(parents=True, exist_ok=True)
    local_path = hf_hub_download(repo_id=REPO_ID, filename=filename, cache_dir=cache)
    return Path(local_path)
|
| 104 |
+
|
| 105 |
+
def main():
    """CLI entry point: generate candidate SMILES for one protein sequence.

    Reads the target sequence (from --sequence or a FASTA file), embeds it
    with ESM, runs the ProtoBind diffusion sampler, and writes the generated
    SMILES to ``<output_dir>/<output>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence", help="Amino acid sequence (1-letter code)")
    parser.add_argument("--output_dir", default="./outputs", help="Output dir for SMILES")
    parser.add_argument("--output", default="generated_smiles.txt", help="Output file for generated SMILES")
    parser.add_argument("--n_batches", type=int, default=5, help="Number of batches to generate for this sequence")
    parser.add_argument("--batch_size", type=int, default=10, help="Max number of generated molecules per batch")
    parser.add_argument("--fasta_file", default="./examples/input.fasta", help="Input FASTA file")
    parser.add_argument("--checkpoint_path", type=str, help="Path to the model checkpoint")
    parser.add_argument('--model_name', type=str, default='esm2_t33_650M_UR50D',
                        help="ESM model name. See https://github.com/facebookresearch/esm")
    parser.add_argument('--tokenizer_path', help='Path to tokenizer.json file. If not provided, uses a default path and downloads if needed.')
    parser.add_argument('--cache', type=str, default="./cache", help='Cache folder for ckpt')

    parser.add_argument("--sampling_steps", type=int, default=250, help="Number of steps during sampling")
    parser.add_argument("--lig_max_length", type=int, default=170, help="Max length of generated molecules")
    parser.add_argument("--nucleus_p", type=float, default=0.9,
                        help="Value of the nucleus sampling parameter. For more details, see https://arxiv.org/abs/2503.00307")
    parser.add_argument("--eta", type=float, default=0.1,
                        help="Value of the probability of remasking. For more details, see https://arxiv.org/abs/2503.00307")

    args = parser.parse_args()

    # An explicit --sequence takes precedence: --fasta_file has a default
    # value, so testing it first would make --sequence unreachable.
    if args.sequence:
        sequence = args.sequence.strip().upper()
    elif args.fasta_file:
        # Only the first record of the FASTA file is used.
        sequence = str(next(SeqIO.parse(args.fasta_file, "fasta")).seq)
    else:
        sys.exit("Error: provide --sequence or --fasta_file")

    if args.checkpoint_path:
        ckpt_path = Path(args.checkpoint_path)
    else:
        torch.hub.set_dir(args.cache)  # for ESM model
        ckpt_path = download_from_hub_hf(Path(args.cache), FILENAME)

    if args.tokenizer_path:
        tokenizer_path = Path(args.tokenizer_path)
        if not tokenizer_path.exists():
            sys.exit(f"Error: Tokenizer file not found at specified path: {tokenizer_path}")
    else:
        tokenizer_path = download_from_hub_hf(Path(args.cache), TOKENIZER_FILENAME)

    # Determine the device: CUDA > Apple MPS > CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    embedding = get_esm_embedding(sequence, args.model_name, device).to(dtype=torch.bfloat16)
    sequence_embedding_dim = embedding.shape[2]
    dataset = InferenceDataset(embedding, batch_size=args.batch_size, n_batches=args.n_batches)
    # batch_size=None: the dataset already yields complete batches.
    loader = DataLoader(dataset, batch_size=None)
    model = ProtobindInference(ckpt_path, tokenizer_path, sequence_embedding_dim,
                               sampling_steps=args.sampling_steps, nucleus_p=args.nucleus_p,
                               eta=args.eta, lig_max_length=args.lig_max_length,)

    predictions = model.predict_on_dataloader(loader, accelerator=str(device))

    # Each prediction batch is (smiles, attention, seq_ids); flatten the SMILES.
    all_smiles = [smi for batch in predictions for smi in batch[0]]
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / args.output, "w") as f:
        f.write("SMILES\n")
        for smi in all_smiles:
            f.write(smi + "\n")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
main()
|
protobind_diff/ligands/__init__.py
ADDED
|
File without changes
|
protobind_diff/ligands/rdkit_utils.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from typing import Optional, Tuple, Union, List
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from multiprocessing import Pool
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
|
| 10 |
+
from FPSim2 import FPSim2Engine
|
| 11 |
+
import rdkit
|
| 12 |
+
from rdkit import Chem, RDLogger
|
| 13 |
+
from rdkit.Chem import DataStructs, Descriptors
|
| 14 |
+
from rdkit.DataStructs import BulkTanimotoSimilarity
|
| 15 |
+
from sklearn.cluster import DBSCAN
|
| 16 |
+
import scipy
|
| 17 |
+
RDLogger.DisableLog('rdApp.*')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class BoostWrapper(object):
    """ Help joblib to deal with boost functions """

    def __init__(self, method_name, module_name):
        # Store the target by name so the wrapper stays picklable;
        # the module object itself is resolved eagerly here.
        self.method_name = method_name
        self.module = importlib.import_module(module_name)

    @property
    def method(self):
        # Re-fetch the attribute on every access instead of caching it.
        return getattr(self.module, self.method_name)

    def __call__(self, *args, **kwargs):
        target = self.method
        return target(*args, **kwargs)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def cluster_fpsim2(distance_path, smiles_h5_path=None, dist_eps=0.15):
    """Cluster a precomputed FPSim2 distance matrix with DBSCAN.

    Args:
        distance_path: Path (or str) to a scipy sparse ``.npz`` distance matrix.
        smiles_h5_path: Path to the FPSim2 fingerprint file; defaults to
            ``all_smiles.h5`` next to ``distance_path``.
        dist_eps: DBSCAN ``eps`` threshold on the precomputed distances.

    Returns:
        np.ndarray: Cluster labels reordered to the original SMILES order.
    """
    if isinstance(distance_path, str):
        # Bug fix: previously passed smiles_h5_path as a (nonexistent)
        # keyword argument to Path(), which raises TypeError.
        distance_path = Path(distance_path)

    if smiles_h5_path is None:
        smiles_h5_path = distance_path.parent / 'all_smiles.h5'
    # First column of the FPSim2 fps array holds the original molecule ids.
    precomputed_indices = FPSim2Engine(smiles_h5_path).fps[:, 0]
    map_precomputed = np.argsort(precomputed_indices)  # maps original smiles order to FPSim2 order

    precomputed_distance = scipy.sparse.load_npz(distance_path)
    # min_samples=1: every point belongs to a cluster (no noise label).
    db = DBSCAN(eps=dist_eps, min_samples=1, metric='precomputed', n_jobs=-1)
    labels = db.fit_predict(precomputed_distance)

    # Reorder labels back to the original SMILES ordering.
    return labels[map_precomputed]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def tanimoto_smiles(mol1, mol2, fp='rdkit', bits=2048, radius=2):
    """Tanimoto similarity between two molecules (SMILES strings or Mol objects).

    Args:
        mol1, mol2: SMILES strings or ``rdkit.Chem.Mol`` objects.
        fp: Fingerprint type: 'rdkit', 'morgan' or 'maccs'.
        bits: Fingerprint size (ignored for 'maccs').
        radius: Morgan fingerprint radius (only used for 'morgan').

    Returns:
        float: Tanimoto similarity in [0, 1].

    Raises:
        ValueError: If ``fp`` is not a supported fingerprint type.
    """
    if isinstance(mol1, str):
        mol1 = Chem.MolFromSmiles(mol1)
    if isinstance(mol2, str):
        mol2 = Chem.MolFromSmiles(mol2)

    _supported_fps = {
        'rdkit': Chem.RDKFingerprint,
        'morgan': Chem.rdMolDescriptors.GetMorganFingerprintAsBitVect,
        'maccs': Chem.rdMolDescriptors.GetMACCSKeysFingerprint,
    }
    if fp not in _supported_fps:
        raise ValueError(f"Fingerprint {fp} is not supported, available fps {_supported_fps.keys()}")

    ffp = None
    if fp == 'rdkit':
        ffp = lambda x: _supported_fps[fp](x, fpSize=bits)
    elif fp == 'morgan':
        # Bug fix: GetMorganFingerprintAsBitVect takes (mol, radius, nBits=...);
        # it has no `fpSize` keyword, so the previous call raised TypeError.
        ffp = lambda x: _supported_fps[fp](x, radius, nBits=bits)
    elif fp == 'maccs':
        # MACCS keys have a fixed length; no size/radius parameters.
        ffp = _supported_fps[fp]

    return rdkit.DataStructs.TanimotoSimilarity(ffp(mol1), ffp(mol2))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def validate_smile(smile):
    """Return *smile* unchanged if RDKit can parse and sanitize it, else None."""
    try:
        candidate = Chem.MolFromSmiles(smile)
        # SanitizeMol raises for unparsable input (MolFromSmiles -> None)
        # as well as for chemically invalid molecules.
        Chem.SanitizeMol(candidate)
    except Exception:
        return None
    return smile
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def calc_chem_desc(smiles):
    """Compute a small panel of RDKit descriptors for a list of molecules.

    Accepts either SMILES strings or Mol objects; molecules that fail to
    parse yield NaN in every column.

    Returns:
        pd.DataFrame: One row per input, one column per descriptor.
    """
    descriptor_funcs = {
        'MolWt': rdkit.Chem.Descriptors.MolWt,
        'MolLogP': rdkit.Chem.Descriptors.MolLogP,
        'NumRotatableBonds': rdkit.Chem.Descriptors.NumRotatableBonds,
        'CalcTPSA': rdkit.Chem.rdMolDescriptors.CalcTPSA,
        'RingCount': rdkit.Chem.Descriptors.RingCount,
    }

    first = smiles[0]
    if isinstance(first, str):
        mols = smiles_to_mols(smiles)
    elif isinstance(first, rdkit.Chem.rdchem.Mol):
        mols = smiles
    else:
        raise TypeError(f'smiles must be a string or a rdkit.Chem.rdchem.Mol: {type(smiles[0])}')

    columns = {
        name: np.asarray([np.nan if mol is None else func(mol) for mol in mols])
        for name, func in descriptor_funcs.items()
    }
    return pd.DataFrame(columns)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def smiles_to_mols(smiles, n_jobs=8):
    """Parse a collection of SMILES strings into RDKit Mol objects in parallel.

    Args:
        smiles: list/tuple/ndarray or pd.Series of SMILES strings.
        n_jobs: Number of joblib workers.

    Returns:
        list: RDKit Mol objects (None entries for unparsable SMILES).
    """
    # Bug fix: joblib was used here but never imported at module level,
    # which raised NameError at call time.
    import joblib

    if isinstance(smiles, (list, tuple, np.ndarray)):
        pass
    elif isinstance(smiles, pd.Series):
        smiles = smiles.tolist()
    else:
        raise TypeError(f"{type(smiles)=}")

    assert len(smiles) > 0
    assert isinstance(smiles[0], str), f"expect smiles string, got f{smiles[0]}"

    # BoostWrapper makes the boost-backed RDKit function picklable for joblib.
    mols = joblib.Parallel(n_jobs=n_jobs)(
        joblib.delayed(BoostWrapper('MolFromSmiles', 'rdkit.Chem.rdmolfiles', ))(smi) for smi in smiles)
    return mols
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def smiles_to_fps(smiles_or_mols, finger_type='rdkit', n_jobs=8, fp_param=None):
    """Compute fingerprints for a collection of SMILES strings or Mol objects.

    Args:
        smiles_or_mols: list/tuple/ndarray or pd.Series of SMILES or Mols.
        finger_type: 'rdkit', 'maccs' or 'morgan'.
        n_jobs: Number of joblib workers (thread-based).
        fp_param: Optional overrides for the fingerprint function's kwargs.

    Returns:
        list: Fingerprint objects, one per input molecule.
    """
    # Bug fix: joblib was used here but never imported at module level,
    # which raised NameError at call time.
    import joblib

    if isinstance(smiles_or_mols, (list, tuple, np.ndarray)):
        pass
    elif isinstance(smiles_or_mols, pd.Series):
        smiles_or_mols = smiles_or_mols.tolist()
    else:
        raise TypeError(f"{type(smiles_or_mols)=}")

    assert len(smiles_or_mols) > 0
    assert isinstance(smiles_or_mols[0],
                      (str, rdkit.Chem.rdchem.Mol)), f"variable {smiles_or_mols[0]} has type {type(smiles_or_mols[0])}"

    if isinstance(smiles_or_mols[0], str):
        mols = smiles_to_mols(smiles_or_mols)
    else:
        mols = smiles_or_mols

    if fp_param is None:
        fp_param = {}
    fp_func, fp_func_name, fp_func_module, fp_params = _find_fingerprint_function(finger_type)
    fp_params.update(fp_param)
    if finger_type == 'morgan':
        # Morgan uses a generator object: instantiate once, then call
        # GetFingerprint per molecule with no extra kwargs.
        fp_func = fp_func(**fp_params).GetFingerprint
        fp_params = {}
    fps = joblib.Parallel(n_jobs=n_jobs, prefer="threads")(
        joblib.delayed(fp_func)(mol, **fp_params) for mol in mols)
    return fps
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _find_fingerprint_function(finger_type: str) -> Tuple[callable, str, str, dict]:
    """Resolve a fingerprint function (and its default kwargs) by name.

    Returns:
        (callable, function name, module name, default kwargs) for the
        requested fingerprint type.
    """
    default_kwargs = {}
    if finger_type == 'rdkit':
        fp_func_module, fp_func_name = 'rdkit.Chem', 'RDKFingerprint'
    elif finger_type == 'maccs':
        fp_func_module, fp_func_name = 'rdkit.Chem.rdMolDescriptors', 'GetMACCSKeysFingerprint'
    elif finger_type == 'morgan':
        fp_func_module, fp_func_name = 'rdkit.Chem.AllChem', 'GetMorganGenerator'
        # Morgan is resolved as a generator factory; these are its defaults.
        default_kwargs = dict(
            atomInvariantsGenerator=rdkit.Chem.rdFingerprintGenerator.GetMorganFeatureAtomInvGen(),
            radius=2, fpSize=2048, countSimulation=True)
    else:
        raise NotImplementedError(f"Use `rdkit` or `maccs` or `morgan` as fps")

    module = importlib.import_module(fp_func_module)
    fp_func = getattr(module, fp_func_name)
    return fp_func, fp_func_name, fp_func_module, default_kwargs
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def randomize_smiles_rotated(smiles: str, with_order_reversal: bool = True) -> str:
    """
    Randomize a SMILES string by doing a cyclic rotation of the atomic indices.

    Adapted from https://github.com/GLambard/SMILES-X/blob/758478663030580a363a9ee61c11f6d6448e18a1/SMILESX/augm.py#L19.

    The outputs of this function can be reproduced by seeding numpy's RNG
    (np.random.seed()), since the random choices below use np.random.

    Args:
        smiles: SMILES string to randomize.
        with_order_reversal: whether to reverse the atom order with 50% chance.

    Returns:
        Randomized SMILES string.
    """
    # sanitize=False keeps the original atom ordering/aromaticity untouched.
    # NOTE(review): invalid SMILES make MolFromSmiles return None, and the
    # call below then raises AttributeError rather than a dedicated error.
    mol = Chem.MolFromSmiles(smiles, sanitize=False)

    n_atoms = mol.GetNumAtoms()

    # Bug fix: np.random.randint(0, n_atoms - 1) raises ValueError when
    # n_atoms <= 1 (low >= high). For such molecules rotation is a no-op.
    if n_atoms < 2:
        return Chem.MolToSmiles(mol, canonical=False)

    # Generate random values
    rotation_index = np.random.randint(0, n_atoms - 1)
    reverse_order = with_order_reversal and np.random.choice([True, False])

    # Generate new atom indices order (cyclic rotation by rotation_index)
    atoms = list(range(n_atoms))
    new_atoms_order = (
        atoms[rotation_index % len(atoms):] + atoms[: rotation_index % len(atoms)]
    )
    if reverse_order:
        new_atoms_order.reverse()

    mol = Chem.RenumberAtoms(mol, new_atoms_order)
    # canonical=False so the renumbering is reflected in the output string.
    return Chem.MolToSmiles(mol, canonical=False)
|
protobind_diff/ligands/smiles_tokenizer.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Taken from https://github.com/MolecularAI/Chemformer/
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 3 |
+
from pysmilesutils.tokenize import SMILESTokenizer
|
| 4 |
+
|
| 5 |
+
class ChemformerTokenizer(SMILESTokenizer):
    """
    Tokenizer for the Chemformer.

    There are a few different features that sets this apart from the `SMILESTokenizer`:
    * It reserves two extra special tokens, "mask" and "sep"
    * It distinguish between chemical and non-chemical tokens

    :param smiles: A list of SMILES that are used to create the vocabulary for the tokenizer. Defaults to None.
    :param tokens: A list of tokens (strings) that the tokenizer uses when tokenizing SMILES. Defaults to None.
    :param regex_token_patterns: A list of regular expressions that the tokenizer uses when tokenizing SMILES.
    :param beginning_of_smiles_token: Token that is added to beginning of SMILES. Defaults to "^".
    :param end_of_smiles_token: Token that is added to the end of SMILES. Defaults to "&".
    :param padding_token: Token used for padding. Defalts to " ".
    :param unknown_token: Token that is used for unknown ids when decoding encoded data. Defaults to "?".
    :param mask_token: Token that is used by the Masker
    :param sep_token: Token that is used to separate sentences, currently unused
    :param filename: if given and `smiles` is None, load the vocabulary from disc
    :raises: ValueError: If the `encoding_type` is invalid.
    """

    def __init__(
        self,
        smiles: List[str] = None,
        tokens: List[str] = None,
        regex_token_patterns: List[str] = None,
        beginning_of_smiles_token: str = "^",
        end_of_smiles_token: str = "&",
        padding_token: str = "<PAD>",
        unknown_token: str = "?",
        mask_token: str = "<MASK>",
        sep_token: str = "<SEP>",
        filename: str = None,
    ) -> None:
        # Extra state must be set BEFORE super().__init__: the base class
        # constructor may build the vocabulary, which reads these fields
        # via the overridden _reset_vocabulary / _update_state hooks.
        self._mask_token = mask_token
        self._sep_token = sep_token
        self._chem_start_idx = 6  # Default, number of special tokens + 1
        # Lazily computed cache of chemical-token indices; None = stale.
        self._chem_token_idxs: Optional[List[int]] = None
        super().__init__(
            smiles=smiles,
            tokens=tokens,
            regex_token_patterns=regex_token_patterns,
            beginning_of_smiles_token=beginning_of_smiles_token,
            end_of_smiles_token=end_of_smiles_token,
            padding_token=padding_token,
            unknown_token=unknown_token,
            encoding_type="index",
            filename=filename,
        )

    @property
    def chem_token_idxs(self) -> List[int]:
        """Returns the indices of the vocabulary that are chemical tokens"""
        # All indices from _chem_start_idx onward are chemical tokens;
        # everything before it is a special or user-added token.
        if self._chem_token_idxs is None:
            self._chem_token_idxs = list(range(self._chem_start_idx, len(self.vocabulary)))
        return self._chem_token_idxs

    @property
    def mask_token_id(self):
        """Get the mask token id"""
        return self.vocabulary[self._mask_token]

    @property
    def vocab_size(self):
        # Total number of tokens, special tokens included.
        return len(self.vocabulary)

    @property
    def special_tokens(self) -> Dict[str, str]:
        """Returns a dictionary of non-character tokens"""
        return {
            "start": self._beginning_of_smiles_token,
            "end": self._end_of_smiles_token,
            "pad": self._padding_token,
            "unknown": self._unknown_token,
            "mask": self._mask_token,
            "sep": self._sep_token,
        }

    def add_tokens(self, tokens: List[str], regex: bool = False, smiles=None) -> None:
        """Adds tokens to the classes list of tokens.

        The new tokens are added to the front of the token list and take priority over old tokens. Note that that the
        vocabulary of the tokenizer is not updated after the tokens are added,
        and must be updated by calling `create_vocabulary_from_smiles`.

        If `regex` is False, the tokens are interpreted as non-chemical tokens, which distinguish
        them for processing by e.g. the masker.

        :param tokens: List of tokens to be added.
        :param regex: If `True` the input tokens are treated as
            regular expressions and are added to the list of regular expressions
            instead of token list. Defaults to False.
        :param smiles: If a list of smiles is provided, the vocabulary will be created, defaults to None

        :raises ValueError: If any of the tokens supplied are already in the list
            of tokens.
        """
        super().add_tokens(tokens, regex, smiles)
        if not regex:
            # Non-regex tokens are non-chemical: shift the chemical-token
            # boundary right and invalidate the cached index list.
            self._chem_start_idx += len(tokens)
            self._chem_token_idxs = None

    def _reset_vocabulary(self) -> Dict[str, int]:
        """Create a new tokens vocabulary.

        :return: New tokens vocabulary
        """
        # Fixed ids 0-5 for the six special tokens; everything in
        # self._tokens (populated by the base class) follows.
        dict_ = {
            self._padding_token: 0,
            self._unknown_token: 1,
            self._beginning_of_smiles_token: 2,
            self._end_of_smiles_token: 3,
            self._mask_token: 4,
            self._sep_token: 5,
        }
        for token in self._tokens:
            dict_.setdefault(token, len(dict_))
        return dict_

    def _state_properties(self) -> Dict[str, Any]:
        """Return properties to reconstruct the internal state of the tokenizer"""
        dict_ = super()._state_properties()
        dict_["chem_start_idx"] = self._chem_start_idx
        return dict_

    def _update_state(self, dict_: Dict[str, Any]) -> None:
        """Update the internal state with properties loaded from disc"""
        super()._update_state(dict_)
        self._chem_start_idx = dict_["chem_start_idx"]
        # Cached index list is stale after loading state from disc.
        self._chem_token_idxs = None
|
protobind_diff/model.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Tuple, Optional, Dict
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
import lightning.pytorch as pl
|
| 7 |
+
import logging
|
| 8 |
+
import huggingface_hub
|
| 9 |
+
|
| 10 |
+
from .ligands.rdkit_utils import validate_smile, calc_chem_desc, tanimoto_smiles
|
| 11 |
+
from .ligands.smiles_tokenizer import ChemformerTokenizer
|
| 12 |
+
from .noise_schedule import _sample_t, q_xt, _sample_categorical, LogLinearNoise
|
| 13 |
+
from .decoder_rope import Decoder_RoPE
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("lightning")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ModelGenerator(pl.LightningModule):
|
| 19 |
+
"""
|
| 20 |
+
ProtoBind-Diff model with SMILES and ESM-2 protein encodings.
|
| 21 |
+
"""
|
| 22 |
+
@staticmethod
|
| 23 |
+
def get_exp_dir(
|
| 24 |
+
exp_dir: str | None,
|
| 25 |
+
output_dir: str,
|
| 26 |
+
exp_dir_prefix: str,
|
| 27 |
+
split: str
|
| 28 |
+
) -> Path:
|
| 29 |
+
"""Determines the experiment directory path."""
|
| 30 |
+
if exp_dir:
|
| 31 |
+
return Path(exp_dir)
|
| 32 |
+
return Path(output_dir) / split / exp_dir_prefix
|
| 33 |
+
|
| 34 |
+
    def __init__(self, *args, **kwargs):
        """Initializes the Lightning Module, saves hyperparameters, and configures the model."""
        super().__init__()

        # When restoring from a checkpoint (load=True), hyperparameters are
        # already stored in the checkpoint; re-saving them is skipped.
        is_load = kwargs['load']
        if not is_load:
            self.save_hyperparameters()

        self.data_dir = Path(kwargs["data_dir"])
        exp_dir = kwargs.get('exp_dir', None)
        # Resolve the experiment output directory (explicit exp_dir wins,
        # otherwise <output_dir>/<split>/<exp_dir_prefix>).
        self.exp_dir = self.get_exp_dir(
            exp_dir=exp_dir,
            output_dir=kwargs["output_dir"],
            exp_dir_prefix=kwargs["exp_dir_prefix"],
            split=kwargs["split"]
        )

        # All remaining kwargs (tokenizer, decoder, noise, sampler settings)
        # are consumed by configure_model_params.
        self.configure_model_params(**kwargs)
|
| 52 |
+
|
| 53 |
+
    def configure_model_params(self, **kwargs):
        """Parses keyword arguments to configure the model, tokenizer, and training parameters."""

        # Optimizer hyperparameters (pop: these must not reach the decoder).
        self.learning_rate = kwargs.pop('learning_rate')
        self.weight_decay = float(kwargs.pop('weight_decay'))

        # Decoder params for masked diffusion
        decoder_params = {
            'nhead': kwargs['num_heads_decoder'],
            'n_layers': kwargs['num_decoder_layers'],
            'hidden_size': kwargs['decoder_hidd_dim'],
            'expand_feedforward': kwargs['expand_feedforward'],
            'decoder_name': kwargs['decoder_name'],
        }
        # Tokenizer params: explicit path wins; otherwise the tokenizer json
        # is looked up inside self.data_dir by name.
        tokenizer_path = kwargs.get('tokenizer_path')
        if tokenizer_path:
            self.tokenizer = ChemformerTokenizer(filename=tokenizer_path)
        else:
            self.tokenizer = ChemformerTokenizer(filename=self.data_dir / f"{kwargs['tokenizer_json_name']}.json")

        # Masking params
        self.noise = LogLinearNoise()
        self.mask_index = self.tokenizer.mask_token_id

        # Sampler params (defaults; inference scripts may override these
        # attributes directly after loading).
        self.model_length = 170       # max generated SMILES length
        self.noise_removal = True
        self.nucleus_p = 0.9          # nucleus (top-p) sampling threshold
        self.eta = 0.1                # remasking probability
        self.sampling_steps = 100
        self.time_conditioning = False

        self.return_attention = False

        self.model = ProtobindMaskedDiffusion(
            embedding_dim=kwargs['seq_embedding_dim'],
            mask_index=self.mask_index,
            vocab_size=self.tokenizer.vocab_size,
            decoder_params=decoder_params,
            dropout=kwargs['dropout'],
        )
        self.optimizer = kwargs.get('optimizer', 'Adam')
|
| 96 |
+
|
| 97 |
+
    def generate_mols(self, sequence: Tuple[torch.Tensor, torch.Tensor],
                      return_attention=False) -> Tuple[np.array, torch.Tensor, np.array]:
        """Generates and validates SMILES strings for a given protein sequence.

        This method calls the internal sampler, decodes the generated tokens into
        SMILES strings, and filters out any invalid molecules.

        Args:
            sequence (Tuple[torch.Tensor, torch.Tensor]): The conditioned protein sequence
                embedding and its length.
            return_attention (bool): Whether to return attention maps from the sampler.

        Returns:
            Tuple[np.array, torch.Tensor, np.array]: A tuple containing the valid SMILES
            strings, corresponding attention maps, and the mask of valid indices.
        """
        samples, attention = self._sample(sequence, return_attention=return_attention)
        text_samples = self.tokenizer.decode(samples.long())
        # validate_smile returns None for unparsable SMILES.
        text_samples = np.array([validate_smile(smile) for smile in text_samples])

        # Elementwise `!= None` on the numpy array is intentional here
        # (`is not None` would not broadcast); drops failed/degenerate samples.
        mask_invalid = (text_samples != None) & (text_samples != '.') & (text_samples != '')
        text_samples = text_samples[mask_invalid]
        if attention is not None:
            # Keep attention maps aligned with the surviving samples.
            attention = attention[mask_invalid]

        return text_samples, attention, mask_invalid
|
| 123 |
+
|
| 124 |
+
def predict_step(self, batch, batch_idx):
|
| 125 |
+
sequence, smiles, seq_id, smi_id = batch
|
| 126 |
+
gen_samples, attention, mask_invalid = self.generate_mols(
|
| 127 |
+
sequence, return_attention=self.return_attention)
|
| 128 |
+
seq_id = seq_id[mask_invalid]
|
| 129 |
+
return gen_samples, attention, seq_id
|
| 130 |
+
|
| 131 |
+
    def training_step(self, batch, batch_idx):
        # Delegates to common_step with the "train" metric prefix.
        return self.common_step(batch, "train", batch_idx)
|
| 133 |
+
|
| 134 |
+
    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        # dataloader_idx to predict on several validation sets
        return self.common_step(batch, "val", batch_idx, dataloader_idx)
|
| 137 |
+
|
| 138 |
+
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        # Delegates to common_step with the "test" metric prefix
        # (dataloader_idx is accepted for Lightning's API but not forwarded).
        return self.common_step(batch, "test", batch_idx)
|
| 140 |
+
|
| 141 |
+
def common_step(self, batch, description, batch_idx, dataloader_idx=None):
|
| 142 |
+
"""Performs a common training, validation, or test step.
|
| 143 |
+
|
| 144 |
+
This method takes a batch, applies noise according to the diffusion
|
| 145 |
+
timestep, runs the model forward, calculates the loss, and logs metrics.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
batch (Tuple): The input batch from the dataloader.
|
| 149 |
+
description (str): The step description (e.g., 'train', 'val').
|
| 150 |
+
batch_idx (int): The index of the batch.
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
torch.Tensor: The calculated loss for the batch.
|
| 154 |
+
"""
|
| 155 |
+
sequence, smiles, seq_id, smi_id = batch
|
| 156 |
+
|
| 157 |
+
# Get data and apply noise
|
| 158 |
+
X, length = smiles
|
| 159 |
+
bs = X.shape[0]
|
| 160 |
+
X = X.squeeze(-1)
|
| 161 |
+
padding_mask = (X != 0).float() # 0 is pad token id
|
| 162 |
+
t = _sample_t(X.shape[0], X.device)
|
| 163 |
+
sigma, dsigma = self.noise(t)
|
| 164 |
+
move_chance = 1 - torch.exp(-sigma[:, None])
|
| 165 |
+
xt = q_xt(X, move_chance, self.mask_index)
|
| 166 |
+
xt = xt.unsqueeze(dim=2)
|
| 167 |
+
smiles_t = (xt, length, None)
|
| 168 |
+
|
| 169 |
+
pred_x, _ = self.model(sequence, smiles_t, sigma, padding_mask)
|
| 170 |
+
total_loss = self.loss_mdlm(X.long(), pred_x, sigma, dsigma, padding_mask=None)
|
| 171 |
+
|
| 172 |
+
if batch_idx % 50 == 0:
|
| 173 |
+
tokens = pred_x.argmax(dim=-1) * padding_mask
|
| 174 |
+
true_smiles = self.tokenizer.decode(X.long())
|
| 175 |
+
pred_smiles = [smile for smile in self.tokenizer.decode(tokens)]
|
| 176 |
+
pred_smiles_valid = [validate_smile(smile) for smile in pred_smiles]
|
| 177 |
+
|
| 178 |
+
try:
|
| 179 |
+
tanimoto = np.asarray([tanimoto_smiles(mol_pred, mol_ref) for mol_pred, mol_ref
|
| 180 |
+
in zip(pred_smiles_valid, true_smiles) if mol_pred is not None])
|
| 181 |
+
tanimoto_mean = np.mean(tanimoto) if len(tanimoto) > 0 else 0
|
| 182 |
+
num_mols_valid = len(tanimoto)
|
| 183 |
+
except:
|
| 184 |
+
num_mols_valid = 0
|
| 185 |
+
tanimoto_mean = 0.0
|
| 186 |
+
|
| 187 |
+
self.log(f"{description}_tanimoto", tanimoto_mean, prog_bar=True,
|
| 188 |
+
on_epoch=True, sync_dist=True)
|
| 189 |
+
self.log(f"{description}_perc_of_valid", num_mols_valid / bs * 100, prog_bar=True,
|
| 190 |
+
on_epoch=True, sync_dist=True)
|
| 191 |
+
|
| 192 |
+
self.log(f"{description}_loss", total_loss, prog_bar=True, on_epoch=True,
|
| 193 |
+
sync_dist=True, batch_size=bs)
|
| 194 |
+
return total_loss
|
| 195 |
+
|
| 196 |
+
def configure_optimizers(self):
    """Build the optimizer used for training.

    AdamW is chosen when weight decay is enabled so the decay is decoupled
    from the gradient update; otherwise a plain Adam optimizer is used.

    Returns:
        torch.optim.Optimizer: The configured optimizer.
    """
    use_decay = self.weight_decay > 0.
    if use_decay:
        return torch.optim.AdamW(self.parameters(),
                                 lr=self.learning_rate,
                                 weight_decay=self.weight_decay)
    return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
| 202 |
+
|
| 203 |
+
def loss_mdlm(self, x_0, model_output, sigma, dsigma, padding_mask=None):
    """Continuous-time loss for the SUBS parameterization.

    Args:
        x_0: Ground-truth token ids, shape (batch, seq_len).
        model_output: Per-token log-probabilities, shape (batch, seq_len, vocab).
        sigma: Total noise at the sampled timestep, shape (batch,).
        dsigma: Noise rate at the sampled timestep, shape (batch,).
        padding_mask: Optional (batch, seq_len) float mask; when given, the
            loss is averaged over unmasked positions only.

    Returns:
        torch.Tensor: Scalar loss.
    """
    # Log-probability the model assigns to each ground-truth token.
    token_index = x_0[:, :, None]
    log_p_theta = model_output.gather(-1, token_index).squeeze(-1)

    # ELBO weight dsigma / (e^sigma - 1), broadcast over the sequence axis.
    weight = (dsigma / torch.expm1(sigma)).unsqueeze(1)
    per_token_loss = -log_p_theta * weight

    if padding_mask is None:
        return per_token_loss.mean()
    return (per_token_loss * padding_mask).sum() / padding_mask.sum()
|
| 215 |
+
|
| 216 |
+
def _sample_prior(self, *batch_dims):
|
| 217 |
+
return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)
|
| 218 |
+
|
| 219 |
+
def _ddpm_caching_update(self, sequence, x, t, dt, p_x0=None, conf=None,
                         return_attention=False):
    """One reverse-diffusion step with caching of the denoiser output.

    Implements the ReMDM-cap style sampler: masked positions are resampled
    from the model's x0 prediction, while already-unmasked positions may be
    re-masked with probability capped by ``self.eta``.

    Args:
        sequence: Conditioning protein-sequence inputs forwarded to self.model.
        x: Current token ids, shape (batch, seq_len).
        t: Current timestep(s); squeezed to shape (batch,).
        dt: Step size towards the next (earlier) timestep.
        p_x0: Cached model probabilities from a previous step, or None to
            recompute them.
        conf: Confidence scores; passed through unchanged.
        return_attention: If True, the model also returns attention weights.

    Returns:
        Tuple of (p_x0 cache or None, next token ids, conf, attention).
    """
    attention = None
    if t.ndim > 1:
        t = t.squeeze(-1)
    sigma_t, _ = self.noise(t)
    assert t.ndim == 1
    # NOTE(review): t itself is used as the masking probability, which assumes
    # alpha_t = 1 - t (log-linear schedule) -- confirm for other schedules.
    move_chance_t = t[:, None, None]
    move_chance_s = (t - dt)[:, None, None]
    assert move_chance_t.ndim == 3, move_chance_t.shape
    padding_mask = (x != 0).float()  # 0 is the pad token id

    if p_x0 is None:
        # Recompute the denoiser prediction only when the cache is invalid.
        p_x0, attention = self.model(sequence, (x.unsqueeze(dim=2), None, None), sigma_t,
                                     padding_mask, return_attention=return_attention)
        p_x0 = p_x0.exp()
        if self.nucleus_p < 1:
            # Nucleus (top-p) filtering: keep the smallest prefix of sorted
            # tokens whose cumulative mass reaches nucleus_p, then renormalize.
            sorted_probs, sorted_indices = torch.sort(p_x0, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            top_p_mask = cumulative_probs <= self.nucleus_p
            top_p_mask[..., 0] = True  # always keep the most likely token
            nucleus_probs = sorted_probs * top_p_mask
            nucleus_probs /= nucleus_probs.sum(dim=-1, keepdim=True)
            p_x0 = torch.zeros_like(p_x0).scatter_(-1, sorted_indices, nucleus_probs)

    assert move_chance_t.ndim == p_x0.ndim

    # Use remdm-cap sampler
    # alpha is the probability a position is unmasked; a scalar per step since
    # all batch elements share the same t.
    alpha_t = (1 - move_chance_t)[0].item()
    alpha_s = (1 - move_chance_s)[0].item()
    if alpha_t > 0:
        # Cap the re-masking probability so the transition stays a valid
        # probability distribution.
        sigma = min(self.eta, (1 - alpha_s) / alpha_t)
    else:
        sigma = self.eta
    # Transition for currently-unmasked tokens: keep the predicted token with
    # probability (1 - sigma), re-mask with probability sigma.
    q_xs = p_x0 * (1 - sigma)
    q_xs[..., self.mask_index] = sigma
    # Transition for currently-masked tokens: unmask from p_x0 or stay masked.
    q_xs_2 = p_x0 * ((alpha_s - (1 - sigma) * alpha_t) / (1 - alpha_t))
    q_xs_2[..., self.mask_index] = (1 - alpha_s - sigma * alpha_t) / (1 - alpha_t)
    copy_flag = (x != self.mask_index).to(torch.bool)
    q_xs = torch.where(copy_flag.unsqueeze(-1), q_xs, q_xs_2)
    xs = _sample_categorical(q_xs)

    if torch.allclose(xs, x) and not self.time_conditioning:
        # Nothing changed and the model ignores t, so the prediction is reusable.
        p_x0_cache = p_x0
    else:
        p_x0_cache = None

    return p_x0_cache, xs, conf, attention
|
| 267 |
+
|
| 268 |
+
@torch.no_grad()
def _sample(self, sequence, eps=1e-3, return_attention=False):
    """Generate samples from the model.

    Starts from an all-MASK sequence and runs ``self.sampling_steps`` reverse
    diffusion steps conditioned on the protein sequence.

    Args:
        sequence: Tuple whose first element is the batched conditioning input;
            its leading dimension defines the batch size.
        eps (float): Smallest timestep; integration runs from t=1 down to eps.
        return_attention (bool): If True, propagate attention weights.

    Returns:
        Tuple of sampled token ids and the attention from the last model call.
    """
    num_steps = self.sampling_steps
    bs = sequence[0].shape[0]
    # Prior: every position starts as the MASK token.
    x = self._sample_prior(bs, self.model_length).to(self.device)

    # num_steps + 1 grid points from 1 down to eps; the loop uses the first num_steps.
    timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None

    min_t = timesteps[-1].item()
    # Confidence scores start at -inf; passed through the update unchanged.
    confident_score = - torch.ones_like(x, device=self.device) * torch.inf

    # NOTE(review): if num_steps == 0 and noise_removal is False, `attention`
    # is never assigned -- this assumes sampling_steps >= 1.
    for i in range(num_steps):
        t = timesteps[i] * torch.ones(bs, 1, device=self.device)
        p_x0_cache, x_next, confident_score, attention = self._ddpm_caching_update(
            sequence, x, t, dt, p_x0=p_x0_cache, conf=confident_score,
            return_attention=return_attention)

        if (not torch.allclose(x_next, x)):
            # The state changed, so the cached prediction is stale.
            p_x0_cache = None
        x = x_next

    if self.noise_removal:
        # Final denoising pass: greedily fill any remaining masked tokens.
        t = min_t * torch.ones(bs, 1, device=self.device)
        unet_conditioning = self.noise(t)[0]
        padding_mask = (x != 0).float()  # 0 is the pad token id
        x, attention = self.model(sequence, (x, None, None), unet_conditioning.squeeze(-1),
                                  padding_mask, return_attention=return_attention)
        x = x.argmax(dim=-1)
    return x, attention
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class ProtobindMaskedDiffusion(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    """The core Protobind-Diff model, which uses a Transformer decoder with RoPE.

    This model is designed for a masked diffusion task and supports conditioning
    on ESM-2 protein embeddings and generating ligands with a ChemformerTokenizer.
    """

    def __init__(self,
                 embedding_dim: int,
                 mask_index: int,
                 vocab_size: int,
                 decoder_params: Optional[dict] = None,
                 dropout: float = 0.2,
                 parametrization_strategy: str = 'subs',
                 **kwargs) -> None:
        """Initializes the ProtobindMaskedDiffusion model.

        Args:
            embedding_dim (int): The dimension of the protein sequence embeddings.
            mask_index (int): The token ID for the MASK token.
            vocab_size (int): The size of the ligand's vocabulary.
            decoder_params (Optional[dict]): A dictionary of parameters for the
                internal Transformer decoder (e.g., nhead, n_layers). Must contain
                'decoder_name' and 'expand_feedforward' keys; the remaining
                entries are forwarded to the decoder constructor.
            dropout (float): The dropout rate.
            parametrization_strategy (str): The diffusion parameterization to use.
                Currently only 'subs' is supported.

        Raises:
            KeyError: If 'decoder_name' or 'expand_feedforward' is missing
                from decoder_params.
            ValueError: If the requested decoder is not 'decoder_re'.
        """
        super().__init__()

        self.neg_infinity = -1000000.0
        self.parametrization_strategy = parametrization_strategy
        # Work on a copy so the caller's dict is not mutated: the original
        # implementation popped keys from `decoder_params` directly, emptying
        # it for any later reuse, and crashed with AttributeError when the
        # documented default of None was used.
        params = dict(decoder_params) if decoder_params else {}
        self.decoder_name = params.pop('decoder_name')
        expand_feedforward = params.pop('expand_feedforward')
        self.mask_index = mask_index

        # Decoder options
        if self.decoder_name == 'decoder_re':
            self.decoder = Decoder_RoPE(vocab_size, embedding_dim,
                                        expand_feedforward=expand_feedforward,
                                        dropout=dropout, **params)
        else:
            raise ValueError(f"Model only supports decoder with rotary embeddings ('decoder_re'), got: {self.decoder_name}")

    def forward(self,
                sequence: Tuple[torch.Tensor, torch.Tensor],
                ligands: Tuple[torch.Tensor, torch.Tensor],
                sigma: torch.Tensor,
                mask_ligand: torch.Tensor,
                return_attention: bool = False) -> torch.Tensor:
        """Performs the main forward pass of the diffusion model.

        Args:
            sequence (Tuple[torch.Tensor, torch.Tensor]): A tuple of the conditioning
                protein sequence embeddings and their lengths.
            ligands: A 3-tuple containing the noised ligand `xt`, its length,
                and a third element that is ignored here.
            sigma (torch.Tensor): The diffusion timestep (noise level).
            mask_ligand (torch.Tensor): The padding mask for the ligand
                (currently unused; the decoder is called with lig_padding_mask=None).
            return_attention (bool): If True, returns attention weights from the decoder.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the final predicted logits
            and the attention weights.
        """

        sequence, sequence_lengths = sequence
        xt, ligand_lengths, _ = ligands

        # Decode ligand
        ligand_masked = xt.squeeze(-1).long()
        ligand_decoded, attention = self.decoder(ligand_masked,
                                                 sigma,
                                                 sequence,
                                                 sequence_lengths,
                                                 lig_padding_mask=None,
                                                 return_attention=return_attention)

        # Apply parametrization
        ligand_decoded = self.parametrization(ligand_decoded, xt)

        return ligand_decoded, attention

    def parametrization(self, logits, xt):
        """Applies the chosen parameterization to the model's output logits.

        The 'subs' strategy modifies the logits to represent the probability
        p(x_{t-1}|x_t), enforcing that unmasked tokens remain unchanged.

        Args:
            logits (torch.Tensor): The raw output logits from the decoder.
                Note: modified in place at the mask index before normalization.
            xt (torch.Tensor): The noised input ligand at timestep t.

        Returns:
            torch.Tensor: The re-parameterized log-probabilities.

        Raises:
            NotImplementedError: For any strategy other than 'subs'.
        """
        if self.parametrization_strategy == 'subs':
            # log prob at the mask index = - infinity: the model never predicts MASK.
            logits[:, :, self.mask_index] += self.neg_infinity

            # Normalize the logits into log-probabilities.
            logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

            # Unmasked tokens are carried over deterministically: probability 1
            # (log-prob 0) on the observed token, ~0 elsewhere.
            xt = xt.squeeze(-1)
            unmasked_indices = (xt != self.mask_index)
            logits[unmasked_indices] = self.neg_infinity
            logits[unmasked_indices, xt[unmasked_indices].long()] = 0
        else:
            raise NotImplementedError(f'Parametrization strategy {self.parametrization_strategy} not implemented')
        return logits
|
protobind_diff/noise_schedule.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
# Flags required to enable jit fusion kernels
|
| 7 |
+
torch._C._jit_set_profiling_mode(False)
|
| 8 |
+
torch._C._jit_set_profiling_executor(False)
|
| 9 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
| 10 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _sample_categorical(categorical_probs):
|
| 15 |
+
gumbel_norm = (
|
| 16 |
+
1e-10
|
| 17 |
+
- (torch.rand_like(categorical_probs) + 1e-10).log())
|
| 18 |
+
return (categorical_probs / gumbel_norm).argmax(dim=-1)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _unsqueeze(x, reference):
|
| 22 |
+
return x.view(
|
| 23 |
+
* x.shape,
|
| 24 |
+
* ((1,) * (len(reference.shape) - len(x.shape))))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _sample_t(n, device, antithetic_sampling=True, sampling_eps=1e-3):
|
| 28 |
+
_eps_t = torch.rand(n, device=device)
|
| 29 |
+
if antithetic_sampling:
|
| 30 |
+
offset = torch.arange(n, device=device) / n
|
| 31 |
+
_eps_t = (_eps_t / n + offset) % 1
|
| 32 |
+
t = (1 - sampling_eps) * _eps_t + sampling_eps
|
| 33 |
+
return t
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def q_xt(x, move_chance, mask_index):
    """Computes the noisy sample xt by masking tokens at random.

    Args:
        x: int torch.Tensor with shape (batch_size,
            diffusion_model_input_length), input.
        move_chance: float torch.Tensor with shape (batch_size, 1),
            per-sample probability that each token is replaced by MASK.
        mask_index: Token id of the MASK token.

    Returns:
        torch.Tensor: x with a random subset of positions set to mask_index.
    """
    masked_positions = torch.rand(x.shape, device=x.device) < move_chance
    return torch.where(masked_positions, mask_index, x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_noise(config, dtype=torch.float32):
    """Factory: build the noise schedule named by ``config.noise.type``.

    Raises:
        ValueError: If the configured noise type is unknown.
    """
    noise_type = config.noise.type
    if noise_type == 'geometric':
        return GeometricNoise(config.noise.sigma_min,
                              config.noise.sigma_max)
    if noise_type == 'loglinear':
        return LogLinearNoise()
    if noise_type == 'cosine':
        return CosineNoise()
    if noise_type == 'cosinesqr':
        return CosineSqrNoise()
    if noise_type == 'linear':
        return Linear(config.noise.sigma_min,
                      config.noise.sigma_max,
                      dtype)
    raise ValueError(f'{config.noise.type} is not a valid noise')
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def binary_discretization(z):
    """Straight-through sign discretization.

    The forward value is sign(z); the gradient flows through the
    L2-normalized z instead (straight-through estimator).
    """
    hard = torch.sign(z)
    soft = z / torch.norm(z, dim=-1, keepdim=True)
    return soft + (hard - soft).detach()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Noise(abc.ABC, nn.Module):
    """Abstract base class for noise schedules.

    Subclasses define the total noise sigma(t) (``total_noise``) and its time
    derivative (``rate_noise``); ``forward`` evaluates both for a batch of
    timesteps.
    """
    def forward(self, t):
        # Assume time goes from 0 to 1
        return self.total_noise(t), self.rate_noise(t)

    @abc.abstractmethod
    def rate_noise(self, t):
        r"""Rate of change of noise, i.e. g(t) = d\sigma/dt."""
        pass

    @abc.abstractmethod
    def total_noise(self, t):
        r"""Total noise, i.e. \int_0^t g(s) ds + g(0)."""
        pass
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class CosineNoise(Noise):
    """Cosine noise schedule: sigma(t) = -log(eps + (1 - eps) cos(pi t / 2))."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        """d sigma / dt for the cosine schedule."""
        half_pi = torch.pi / 2
        numer = (1 - self.eps) * torch.sin(half_pi * t)
        denom = (1 - self.eps) * torch.cos(half_pi * t) + self.eps
        return half_pi * numer / denom

    def total_noise(self, t):
        """sigma(t) = -log(eps + (1 - eps) cos(pi t / 2))."""
        cos_term = torch.cos(t * torch.pi / 2)
        return -torch.log(self.eps + (1 - self.eps) * cos_term)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class CosineSqrNoise(Noise):
    """Squared-cosine noise schedule: sigma(t) = -log(eps + (1 - eps) cos^2(pi t / 2))."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        """d sigma / dt; uses sin(pi t) = 2 sin(pi t / 2) cos(pi t / 2)."""
        half_pi = torch.pi / 2
        cos_sq = (1 - self.eps) * torch.cos(half_pi * t) ** 2
        sin_term = (1 - self.eps) * torch.sin(torch.pi * t)
        return half_pi * sin_term / (cos_sq + self.eps)

    def total_noise(self, t):
        """sigma(t) = -log(eps + (1 - eps) cos^2(pi t / 2))."""
        cos_sq = torch.cos(t * torch.pi / 2) ** 2
        return -torch.log(self.eps + (1 - self.eps) * cos_sq)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class Linear(Noise):
    """Linear noise schedule: sigma(t) = sigma_min + t * (sigma_max - sigma_min)."""

    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t):
        """Constant rate d sigma / dt."""
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        """sigma(t), linear in t."""
        span = self.sigma_max - self.sigma_min
        return self.sigma_min + t * span

    def importance_sampling_transformation(self, t):
        """Warp uniform t so that sigma levels are sampled by importance."""
        log_cdf_max = torch.log1p(-torch.exp(-self.sigma_max))
        log_cdf_min = torch.log1p(-torch.exp(-self.sigma_min))
        sigma_t = -torch.log1p(-torch.exp(t * log_cdf_max + (1 - t) * log_cdf_min))
        return (sigma_t - self.sigma_min) / (self.sigma_max - self.sigma_min)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class GeometricNoise(Noise):
    """Geometric noise schedule: sigma(t) = sigma_min^(1-t) * sigma_max^t."""

    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        # Stored as a float tensor: [sigma_min, sigma_max].
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        # d/dt sigma(t) = sigma(t) * log(sigma_max / sigma_min)
        log_ratio = self.sigmas[1].log() - self.sigmas[0].log()
        return self.total_noise(t) * log_ratio

    def total_noise(self, t):
        """Geometric interpolation between sigma_min and sigma_max."""
        low, high = self.sigmas[0], self.sigmas[1]
        return low ** (1 - t) * high ** t
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class LogLinearNoise(Noise):
    """Log Linear noise schedule.

    Built such that 1 - 1/e^(n(t)) interpolates between 0 and 1.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        """g(t) = (1 - eps) / (1 - (1 - eps) t)."""
        scale = 1 - self.eps
        return scale / (1 - scale * t)

    def total_noise(self, t):
        """sigma(t) = -log(1 - (1 - eps) t)."""
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        """Map uniform t to importance-sampled t over sigma levels."""
        log_cdf_max = torch.log1p(-torch.exp(-self.sigma_max))
        log_cdf_min = torch.log1p(-torch.exp(-self.sigma_min))
        sigma_t = -torch.log1p(-torch.exp(t * log_cdf_max + (1 - t) * log_cdf_min))
        return -torch.expm1(-sigma_t) / (1 - self.eps)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "ProtoBind-Diff"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "ProtoBind-Diff: A Structure-Free Diffusion Language Model for Protein Sequence-Conditioned Ligand Design"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10,<3.13"
|
| 7 |
+
license = "MIT AND (Apache-2.0 OR BSD-2-Clause)"
|
| 8 |
+
authors = [
|
| 9 |
+
{ name = "Lukia Mistryukova", email = "lukiia.mistriukova@gero.ai" },
|
| 10 |
+
{ name = "Vladimir Manuilov", email = "vladimir.manuylov@gero.ai" },
|
| 11 |
+
{ name = "Konstantin Avchaciov", email = "ka@gero.ai" },
|
| 12 |
+
{ name = "Peter Fedichev", email = "pf@gero.ai" },
|
| 13 |
+
]
|
| 14 |
+
dependencies = [
|
| 15 |
+
"torch>=2.2",
|
| 16 |
+
"numpy>=1.26,<2.0",
|
| 17 |
+
"lightning>=2.3.0",
|
| 18 |
+
"rdkit>=2024.3.2",
|
| 19 |
+
"requests==2.32.3",
|
| 20 |
+
"pandas>=2.2.2",
|
| 21 |
+
"PyYAML>=6.0",
|
| 22 |
+
"scipy>=1.13.0",
|
| 23 |
+
"scikit-learn>=1.1.0",
|
| 24 |
+
"fair-esm==2.0.0",
|
| 25 |
+
"biopython>=1.80",
|
| 26 |
+
"pysmilesutils @ git+https://github.com/MolecularAI/pysmilesutils.git",
|
| 27 |
+
"FPSim2==0.5.2",
|
| 28 |
+
"huggingface_hub",
|
| 29 |
+
"einops==0.8.0",
|
| 30 |
+
"easydict>=1.11",
|
| 31 |
+
"tensorboard>=2.14.0",
|
| 32 |
+
"rich>=13.5.0"
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
[project.scripts]
|
| 36 |
+
protobind-train = "protobind_diff.train:main"
|
| 37 |
+
protobind-infer = "protobind_diff.esm_inference:main"
|
| 38 |
+
|
| 39 |
+
[build-system]
|
| 40 |
+
requires = ["setuptools>=61.0"]
|
| 41 |
+
build-backend = "setuptools.build_meta"
|
| 42 |
+
|
| 43 |
+
[tool.setuptools.packages.find]
|
| 44 |
+
include = ["protobind_diff"]
|