vladimir.manuylov committed
Commit bd082dc · 1 Parent(s): 8e21d42

initial commit

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: ProtoBind Diff
- emoji: 🏢
+ emoji: 💊
  colorFrom: yellow
  colorTo: gray
  sdk: gradio
@@ -11,4 +11,31 @@ license: mit
  short_description: Structure-free target-specific molecule generation
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## A Structure-Free Diffusion Language Model for Protein Sequence-Conditioned Ligand Design
+
+ <a href="https://www.biorxiv.org/content/10.1101/2025.06.16.659955v1">
+ <img
+ src="https://img.shields.io/badge/bioRxiv-paper-blue?logo=biorxiv&logoColor=white"
+ alt="Paper on bioRxiv"
+ />
+ </a>
+ <a href="https://github.com/gero-science/ProtoBind-Diff">
+ <img
+ src="https://img.shields.io/badge/GitHub-code-black?logo=github&logoColor=white"
+ alt="View on GitHub"
+ />
+ </a>
+
+ ## Citation
+
+ ```bibtex
+ @article{Mistryukova2025.06.16.659955,
+ author = {Mistryukova, Lukia and Manuilov, Vladimir and Avchaciov, Konstantin and Fedichev, Peter O.},
+ title = {ProtoBind-Diff: A Structure-Free Diffusion Language Model for Protein Sequence-Conditioned Ligand Design},
+ year = {2025},
+ journal = {bioRxiv}
+ }
+ ```
+
+ ## License
+ The code and model weights are released under MIT license. See the [LICENSE](LICENSE) file for details.
app.py ADDED
@@ -0,0 +1,146 @@
1
+ # app.py
2
+ # --- IMPORTS ---
3
+ import re
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ import lightning.pytorch as pl
10
+ from protobind_diff.esm_inference import get_esm_embedding
11
+ from protobind_diff.model import ModelGenerator
12
+ from protobind_diff.data_loader import InferenceDataset
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ # Hugging Face Hub details
16
+ REPO_ID = "ai-gero/ProtoBind-Diff"
17
+ MODEL_FILENAME = "model.ckpt"
18
+ TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
19
+
20
+
21
+ def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
22
+ """
23
+ The main prediction function that runs the full pipeline.
24
+ """
25
+ if not protein_sequence:
26
+ raise gr.Error("Protein sequence cannot be empty.")
27
+ protein_sequence = re.sub(r"[^A-Z]", "", protein_sequence.upper())
28
+ if len(protein_sequence) < 10:
29
+ raise gr.Error("Protein sequence is too short.")
30
+
31
+ embedding = get_esm_embedding(
32
+ protein_sequence,
33
+ 'esm2_t33_650M_UR50D',
34
+ device
35
+ ).to(dtype=torch.bfloat16)
36
+ n_batches = num_samples // 10
37
+ dataset = InferenceDataset(embedding, batch_size=10, n_batches=n_batches)
38
+ loader = DataLoader(dataset, batch_size=None)
39
+
40
+ trainer = pl.Trainer(
41
+ accelerator="auto",
42
+ devices=1,
43
+ logger=False,
44
+ precision="16-mixed" if device == "cuda" else "32-true"
45
+ )
46
+
47
+ predictions_batches = trainer.predict(model=protobind_model, dataloaders=loader)
48
+
49
+ all_smiles = [smi for batch in predictions_batches for smi in batch[0]]
50
+ unique_smiles = list(set(all_smiles))
51
+
52
+ return ",\n".join(unique_smiles)
53
+
54
+
55
+ # --- GRADIO APP DEFINITION ---
56
+
57
+ # Load models on app startup
58
+ device = "cuda" if torch.cuda.is_available() else "cpu"
59
+ tokenizer_path = hf_hub_download(
60
+ repo_id=REPO_ID,
61
+ filename=TOKENIZER_FILENAME,
62
+ )
63
+ ckpt_path = hf_hub_download(
64
+ repo_id=REPO_ID,
65
+ filename=MODEL_FILENAME,
66
+ )
67
+ protobind_model = ModelGenerator.load_from_checkpoint(
68
+ ckpt_path,
69
+ map_location=device,
70
+ tokenizer_path=tokenizer_path,
71
+ seq_embedding_dim=1280,
72
+ load=True,
73
+ )
74
+ protobind_model.eval()
75
+ protobind_model.to(device)
76
+ # Define the UI
77
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
78
+ gr.Markdown(
79
+ """
80
+ # ProtoBind-Diff: Protein-Conditioned Ligand Generation
81
+ This Space demonstrates **ProtoBind-Diff**, a diffusion model for generating novel drug-like molecules (ligands)
82
+ conditioned on a target protein sequence. Provide a protein's amino acid sequence to generate potential binding molecules in SMILES format.
83
+ """
84
+ )
85
+
86
+ with gr.Row():
87
+ with gr.Column(scale=2):
88
+ protein_sequence = gr.Textbox(
89
+ lines=10,
90
+ label="Protein Amino Acid Sequence",
91
+ placeholder="Enter your protein sequence here (e.g., MGY...)"
92
+ )
93
+ num_samples = gr.Slider(
94
+ minimum=10,
95
+ maximum=200,
96
+ value=50,
97
+ step=10,
98
+ label="Generation Attempts",
99
+ info=(
100
+ "Upper limit on generation attempts. Duplicates and invalid molecules "
101
+ "are discarded, so the final count of unique molecules may be lower. "
102
+ "More attempts increase runtime but can improve diversity."
103
+ )
104
+ )
105
+ submit_btn = gr.Button("Generate Molecules", variant="primary")
106
+
107
+ with gr.Column(scale=3):
108
+ output_smiles = gr.Textbox(
109
+ lines=15,
110
+ label="Generated SMILES",
111
+ info="A list of unique, valid SMILES strings generated for the target protein.",
112
+ interactive=True
113
+ )
114
+
115
+ gr.Markdown("### Examples")
116
+ gr.Examples(
117
+ examples=[
118
+ ["MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPGYRS",
119
+ 50],
120
+ ["MDILCEENTSLSSTTNSLMQLNDDTRLYSNDFNSGEANTSDAFNWTVDSENRTNLSCEGCLSPSCLSLLHLQEKNWSALLTAVVIILTIAGNILVIMAVSLEKKLQNATNYFLMSLAIADMLLGFLVMPVSMLTILYGYRWPLPSKLCAVWIYLDVLFSTASIMHLCAISLDRYVAIQNPIHHSRFNSRTKAFLKIIAVWTISVGISMPIPVFGLQDDSKVFKEGSCLLADDNFVLIGSFVSFFIPLTIMVITYFLTIKSLQKEATLCVSDLGTRAKLASFSFLPQSSLSSEKLFQRSIHREPGSYTGRRTMQSISNEQKACKVLGIVFFLFVVMWCPFFITNIMAVICKESCNEDVIGALLNVFVWIGYLSSAVNPLVYTLFNKTYRSAFSRYIQCQYKENKKPLQLILVNTIPALAYKSSQLQMGQKKNSKQDAKTTDNDCSMVALGKQHSEEASKDNSDGVNEKVSCV",
121
+ 100]
122
+ ],
123
+ inputs=[protein_sequence, num_samples],
124
+ outputs=output_smiles,
125
+ fn=generate_smiles_for_sequence,
126
+ cache_examples=False,
127
+ )
128
+
129
+ gr.Markdown(
130
+ """
131
+ ---
132
+ *Model developed by Gero AI. For more details, check out the [original repository](https://github.com/gero-science/ProtoBind-Diff).*
133
+ """
134
+ )
135
+
136
+ submit_btn.click(
137
+ fn=generate_smiles_for_sequence,
138
+ inputs=[protein_sequence, num_samples],
139
+ outputs=output_smiles
140
+ )
141
+
142
+ # Launch the app
143
+ if __name__ == "__main__":
144
+ demo.launch(share=True)
145
+
146
+
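Because `demo` is a plain Gradio Blocks app, the deployed Space can also be queried programmatically. Below is a minimal, hypothetical client-side sketch using `gradio_client`; the Space id and the endpoint name are assumptions (they are not stated in this commit), so inspect `client.view_api()` for the actual values.

```python
# Hypothetical client call to the deployed Space; the Space id and api_name are assumptions.
from gradio_client import Client

client = Client("ai-gero/ProtoBind-Diff")            # assumed Space id
smiles = client.predict(
    "MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMV",       # protein sequence (shortened illustration)
    50,                                              # generation attempts
    api_name="/generate_smiles_for_sequence",        # assumed endpoint name
)
print(smiles)
```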
protobind_diff/__init__.py ADDED
File without changes
protobind_diff/data_loader.py ADDED
@@ -0,0 +1,761 @@
1
+ # Data loader for the protobind-diff.
2
+ # This version only supports ProtobindMaskedDiffusion with SMILES and ESM-2 protein encodings.
3
+ import os.path
4
+ import json
5
+ import logging
6
+ from pathlib import Path
7
+ from enum import Enum
8
+ from typing import Dict, List, Tuple, Optional, Union
9
+ from zipfile import ZipFile
10
+
11
+ import lightning.pytorch as pl
12
+ import numpy as np
13
+ import torch
14
+ import pandas as pd
15
+ from torch.utils.data import Dataset, DataLoader
16
+ from torch.nn.utils.rnn import pad_sequence
17
+ from tqdm.auto import tqdm
18
+
19
+ from .ligands.smiles_tokenizer import ChemformerTokenizer
20
+ from .ligands.rdkit_utils import randomize_smiles_rotated, cluster_fpsim2
21
+
22
+ logger = logging.getLogger("lightning")
23
+
24
+
25
+ class SplittingMethod(Enum):
26
+ # enum that describes various train/val/test splitting methods.
27
+ RANDOM = 1
28
+
29
+
30
+ def split_at_random(df: pd.DataFrame, valid_fraction=0.1, test_fraction=0.1, seed=777):
31
+ """Randomly splits a DataFrame into training, validation, and test sets.
32
+
33
+ Args:
34
+ df (pd.DataFrame): The DataFrame to split.
35
+ valid_fraction (float): The fraction of the data to allocate to the validation set.
36
+ test_fraction (float): The fraction of the data to allocate to the test set.
37
+ seed (int): The random seed for shuffling to ensure reproducibility.
38
+
39
+ Returns:
40
+ Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing the
41
+ training, validation, and test DataFrames.
42
+ """
43
+ df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
44
+ valid_size = int(len(df) * valid_fraction)
45
+ test_size = int(len(df) * test_fraction)
46
+ train_size = len(df) - valid_size - test_size
47
+ train_df = df[:train_size]
48
+ valid_df = df[train_size:train_size + valid_size]
49
+ test_df = df[train_size + valid_size:]
50
+ return train_df, valid_df, test_df
51
+
52
+
53
+ class RandomizedSmilesDataset(object):
54
+ """Creates a dataset of tokenized SMILES strings, with an option for on-the-fly randomization.
55
+
56
+ This dataset maps integer indices to SMILES strings and provides tokenized
57
+ representations. It can randomize SMILES strings during data retrieval to
58
+ augment the training data.
59
+
60
+ Attributes:
61
+ smiles (pd.Series): A series of SMILES strings indexed by integers.
62
+ tokenizer (ChemformerTokenizer): The tokenizer for converting SMILES to tokens.
63
+ randomize (bool): If True, applies SMILES randomization at retrieval time.
64
+ """
65
+ def __init__(self, smiles: dict, tokenizer: ChemformerTokenizer,
66
+ randomize: bool = True):
67
+ self.smiles = pd.Series(data=smiles.keys(), index=smiles.values()).sort_index()
68
+ assert len(self.smiles) == self.smiles.index[-1] + 1, (f"{len(self.smiles)}"
69
+ f" {self.smiles.index[:5]} {self.smiles.index[-5:]}")
70
+ self.tokenizer = tokenizer
71
+ self.randomize = randomize
72
+ logger.info(f"Molecular dataset initialized: RandomizedSmilesDataset {type(self.tokenizer)}"
73
+ f" random: {self.randomize}")
74
+
75
+ def __len__(self):
76
+ return len(self.smiles)
77
+
78
+ def __getitem__(self, item):
79
+ smi = self.smiles[item]
80
+ if self.randomize:
81
+ smi = randomize_smiles_rotated(smi)
82
+ mol = self.tokenizer.encode(smi)[0]
83
+ return mol
84
+
85
+ @classmethod
86
+ def from_json(cls, path, **kwargs):
87
+ with open(path) as f:
88
+ categorical_mappings = json.load(f)
89
+ smiles = categorical_mappings['smiles']
90
+ loaded = cls(smiles, **kwargs)
91
+ return loaded
92
+
93
+
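A minimal usage sketch for the class above, assuming the tokenizer JSON and the `categorical_mappings.json` file expected in the data directory; the paths below are placeholders.

```python
# Sketch only: paths are placeholders for files expected in the data directory.
from protobind_diff.ligands.smiles_tokenizer import ChemformerTokenizer

tokenizer = ChemformerTokenizer(filename="data/tokenizer_smiles_diffusion.json")
dataset = RandomizedSmilesDataset.from_json(
    "data/categorical_mappings.json",   # contains the 'smiles' string-to-id mapping
    tokenizer=tokenizer,
    randomize=True,                     # re-randomize the SMILES string on every access
)
tokens = dataset[0]                     # token ids for (a randomized form of) SMILES #0
```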
94
+ class RandomizedBatchSampler(torch.utils.data.Sampler):
95
+ """A batch sampler that minimizes padding while maximizing batch randomness.
96
+
97
+ To achieve this, the sampler employs a two-level shuffling strategy:
98
+ 1. The data is first sorted by sequence length and grouped into buckets.
99
+ 2. Within each bucket, the sample indices are shuffled.
100
+ 3. Batches are created by slicing across the globally sorted list of indices,
101
+ which keeps sequence lengths within a batch similar.
102
+ 4. The order of these batches is then shuffled to ensure randomness across epochs.
103
+
104
+ This approach balances the trade-off between minimizing padding (by batching
105
+ similar-length sequences) and maintaining randomness required for effective training.
106
+ """
107
+
108
+ def __init__(self, sequence_length: np.ndarray, shuffle: bool, batch_volume: int,
109
+ generator: torch.Generator = None, num_ranges: int = 150, batch_size: int = 128):
110
+ """Initializes the RandomizedBatchSampler.
111
+
112
+ Args:
113
+ sequence_length (np.ndarray): An array of sequence lengths for each item in the dataset.
114
+ shuffle (bool): If True, shuffle batches and indices within length buckets.
115
+ batch_volume (int): The maximum total number of elements (seq_len^2) per batch.
116
+ generator (torch.Generator, optional): PyTorch random number generator. Defaults to None.
117
+ num_ranges (int): The number of buckets to partition the sequence lengths into.
118
+ batch_size (int): The maximum number of samples per batch.
119
+ """
120
+ self.shuffle = shuffle
121
+ # For val/test (i.e. when we don't shuffle) we can fit more batches in memory as we don't need grads.
122
+ batch_volume_factor = 1 if shuffle else 2
123
+ self.batch_volume = batch_volume * batch_volume_factor
124
+ assert max(sequence_length) ** 2 < self.batch_volume, \
125
+ f"Cannot fit sequence {max(sequence_length)=} to {batch_volume=}"
126
+
127
+ if generator is None:
128
+ self.generator = self._init_generator()
129
+ else:
130
+ self.generator = generator
131
+ self.num_ranges = num_ranges
132
+ self.sequence_length = sequence_length
133
+ self.sequence_length_2 = self.sequence_length ** 2
134
+ self.batch_size = batch_size
135
+
136
+ bins = np.linspace(np.min(sequence_length), np.max(sequence_length) + 1, num_ranges)
137
+ digit_bins = np.digitize(sequence_length, bins=bins, right=True)
138
+ self.sequence_length_buckets = [torch.tensor(np.where(digit_bins == i)[0],
139
+ dtype=torch.int32) for i in range(num_ranges)]
140
+ self._prepared_batches = None
141
+
142
+ def _get_sliced_batches(self):
143
+ if self.shuffle:
144
+ # reshuffle the sequence length buckets.
145
+ for i in range(len(self.sequence_length_buckets)):
146
+ self.sequence_length_buckets[i] = self.sequence_length_buckets[i][torch.randperm(
147
+ len(self.sequence_length_buckets[i]), generator=self.generator)]
148
+
149
+ current_batch = []
150
+ current_batch_volume = 0
151
+ current_batch_size = 0
152
+ for i in range(self.num_ranges):
153
+ for idx in self.sequence_length_buckets[i]:
154
+ if (current_batch_volume + self.sequence_length_2[idx] >= self.batch_volume
155
+ or current_batch_size >= self.batch_size):
156
+ yield current_batch
157
+ current_batch = []
158
+ current_batch_volume = 0
159
+ current_batch_size = 0
160
+ current_batch.append(idx.item())
161
+ current_batch_volume += self.sequence_length_2[idx]
162
+ current_batch_size += 1
163
+ if len(current_batch) > 0:
164
+ yield current_batch
165
+
166
+ @staticmethod
167
+ def _init_generator():
168
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
169
+ generator = torch.Generator()
170
+ generator.manual_seed(seed)
171
+ return generator
172
+
173
+ @property
174
+ def _length(self):
175
+ if self._prepared_batches is None:
176
+ self._prepared_batches = list(self._get_sliced_batches())
177
+ return len(self._prepared_batches)
178
+
179
+ def __len__(self):
180
+ return self._length
181
+
182
+ def __iter__(self):
183
+ if self.shuffle:
184
+ # Then get the batches and serve them in random order
185
+ if self._prepared_batches is None:
186
+ self._prepared_batches = list(self._get_sliced_batches())
187
+ for batch_idx in torch.randperm(self._length, generator=self.generator):
188
+ yield self._prepared_batches[batch_idx]
189
+ self._prepared_batches = None # Destroy _prepared_batches to recreate it again in __len__
190
+ else:
191
+ for batch in self._get_sliced_batches():
192
+ yield batch
193
+
194
+
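A toy illustration of what the sampler above yields: indices of similar-length sequences are batched together, and only the order of whole batches is randomized. The numbers are made up for illustration.

```python
# Toy example (made-up lengths): short and long sequences end up in separate batches.
import numpy as np

lengths = np.array([120, 125, 130, 480, 495, 500])          # per-sample sequence lengths
sampler = RandomizedBatchSampler(sequence_length=lengths, shuffle=True,
                                 batch_volume=800_000, batch_size=3)
for batch in sampler:
    print(batch)   # [0, 1, 2] and [3, 4, 5], with the batch order shuffled between epochs
```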
195
+ class ProtobindDataModule(pl.LightningDataModule):
196
+ """PyTorch Lightning DataModule for Protobind-diffusion datasets.
197
+
198
+ This module handles the loading, processing, and batching of protein-ligand
199
+ data. It is designed to work with ESM-2 protein embeddings and tokenized
200
+ SMILES representations for ligands. The module manages data splitting,
201
+ feature loading, and provides DataLoaders with an efficient batching
202
+ strategy to minimize padding.
203
+
204
+ Key Features:
205
+ - Loads pre-computed ESM-2 protein embeddings.
206
+ - Utilizes tokenized SMILES for ligands via `ChemformerTokenizer`.
207
+ - Implements a `RandomizedBatchSampler` to create efficient, low-padding batches.
208
+ - Handles dataset splitting into training, validation, and test sets.
209
+ """
210
+ MASK_VALUE = 0
211
+
212
+ def __init__(self, *,
213
+ data_dir: Path,
214
+ exp_dir: Path,
215
+ splitting_method: SplittingMethod,
216
+ batch_volume: int,
217
+ num_workers: int,
218
+ sequence_type: str = 'esm_zip',
219
+ esm_model_name: str = "esm2_t33_650M_UR50D",
220
+ max_size_batch: int = 16,
221
+ dataset_params: Optional[dict] = None,
222
+ float_type: str = 'float32'):
223
+ super().__init__()
224
+ """Initializes the ProtobindDataModule.
225
+
226
+ Args:
227
+ data_dir (Path): The directory containing the raw dataset files (e.g., data.csv, embeddings).
228
+ exp_dir (Path): The directory to save experiment artifacts, including data splits.
229
+ splitting_method (SplittingMethod): The method for splitting data (e.g., RANDOM).
230
+ batch_volume (int): The target batch volume for the RandomizedBatchSampler.
231
+ num_workers (int): The number of workers for the DataLoader.
232
+ sequence_type (str): The type of protein sequence data. Must be 'esm_zip'.
233
+ esm_model_name (str): The specific ESM model name for embeddings.
234
+ max_size_batch (int): The maximum number of samples in a batch.
235
+ dataset_params (Optional[dict]): Parameters for the underlying molecular dataset.
236
+ float_type (str): The floating-point precision to use.
237
+ """
238
+ self.csv_path = data_dir / "data.csv"
239
+ self.categorical_mappings_path = data_dir / "categorical_mappings.json"
240
+
241
+ # Validate sequence type - only allow ESM variants
242
+ if sequence_type not in ['esm_zip']:
243
+ raise ValueError(f"DataModule only supports the 'esm_zip' sequence type, got: {sequence_type}")
244
+
245
+ # directory structure:
246
+ # output_dir / split / exp_dir_prefix
247
+ self.exp_dir: Path = Path(exp_dir)
248
+ self.split_dir: Path = self.exp_dir.parent
249
+ self.exp_data_dir: Path = self.split_dir.parent
250
+ self.data_dir = data_dir
251
+
252
+ if dataset_params is None:
253
+ dataset_params = {}
254
+
255
+ # Create simplified SMILES dataloader
256
+ self.molecular_dataloader = MolecularDataloaderSMILES(
257
+ data_dir=data_dir,
258
+ dataset_options=dataset_params,
259
+ )
260
+
261
+ self.float_type = float_type
262
+ self.batch_volume = batch_volume
263
+ self.max_size_batch = max_size_batch
264
+ self.num_workers = num_workers
265
+ self.splitting_method = splitting_method
266
+ self.esm_model_name = esm_model_name
267
+
268
+ # Only support ESM embeddings (float type data)
269
+ self.sequence_dtype = getattr(torch, self.float_type)
270
+
271
+ # Will be initialized in setup()
272
+ self.train_dataset: Optional[torch.utils.data.Dataset] = None
273
+ self.val_dataset: Optional[torch.utils.data.Dataset] = None
274
+ self.test_dataset: Optional[torch.utils.data.Dataset] = None
275
+
276
+ self.datasets: Dict[str, pd.DataFrame] = {}
277
+ self.torch_datasets: Dict[str, torch.utils.data.Dataset] = {}
278
+
279
+ @staticmethod
280
+ def _read_df(csv_path: Path) -> pd.DataFrame:
281
+ _use_columns = ['smiles', 'sequence', 'log_IC50', 'log_Ki', 'log_Kd', 'log_EC50', 'label', 'split',
282
+ 'cluster_smi']
283
+ df = pd.read_csv(csv_path, nrows=1)
284
+ _use_columns = df.columns.intersection(_use_columns)
285
+
286
+ dtypes = {"smiles": int, "sequence": int, "log_IC50": float,
287
+ "log_Ki": float, "log_Kd": float, "log_EC50": float,
288
+ "label": float, "split": str, "cluster_smi": str}
289
+
290
+ df = pd.read_csv(csv_path, dtype=dtypes, usecols=_use_columns)
291
+ return df
292
+
293
+ @staticmethod
294
+ def _read_df_and_compute_sequence_lengths(csv_path: Path, length_dict: dict) -> pd.DataFrame:
295
+ # to reduce RAM, load only the necessary columns
296
+ df = ProtobindDataModule._read_df(csv_path)
297
+ df['sequence_length'] = df["sequence"].map(length_dict)
298
+
299
+ # sort by sequence length to increase the batching efficiency.
300
+ df.sort_values(by="sequence_length", inplace=True)
301
+ return df
302
+
303
+ def check_splits_exist(self):
304
+ """ Checks whether the train/valid/test split files already exist """
305
+ if (self.split_dir / "train.csv").exists():
306
+ assert (self.split_dir / "valid.csv").exists()
307
+ assert (self.split_dir / "test.csv").exists()
308
+ logger.info(f"train.csv/valid.csv/test.csv exist, "
309
+ f"no new splits will be created for {self.splitting_method}")
310
+ return True
311
+
312
+ return False
313
+
314
+ def prepare_data_split(self, seed=777, valid_fraction=0.1, test_fraction=0.1):
315
+ """ Create train.csv, val.csv and test.csv in the experiment dir """
316
+
317
+ if self.check_splits_exist():
318
+ return
319
+
320
+ # Check that data exists
321
+ for path in [self.csv_path, self.categorical_mappings_path]:
322
+ if not path.exists():
323
+ raise FileNotFoundError(
324
+ f"Could not find {path}. Please download the data.")
325
+
326
+ # load label data
327
+ data_df = pd.read_csv(self.csv_path)
328
+
329
+ # add clusters
330
+ distance_data = list(self.csv_path.parent.glob('all_smiles_sparse_*.npz'))
331
+ if len(distance_data) > 0:
332
+ logger.info(f"Calculating clusters for SMILES and distance data {distance_data[0]}")
333
+ clusters_smi = cluster_fpsim2(distance_data[0])
334
+ len_ = len(data_df)
335
+ data_df = data_df.merge(pd.Series(clusters_smi, name='cluster_smi'), left_on='smiles', right_index=True)
336
+ assert data_df.shape[0] == len_, (f"Failed to merge clusters, {len_=} {data_df.shape=}"
337
+ f" {clusters_smi.min()} {clusters_smi.max()}")
338
+ else:
339
+ raise FileNotFoundError(f'Could not find any all_smiles_sparse_*.npz in {str(self.csv_path.parent)}')
340
+
341
+ # Create splits
342
+ if self.splitting_method == SplittingMethod.RANDOM:
343
+ train, valid, test = split_at_random(data_df, valid_fraction=valid_fraction,
344
+ test_fraction=test_fraction, seed=seed)
345
+ else:
346
+ raise NotImplementedError(
347
+ f"Splitting method {self.splitting_method} is not implemented in simplified version.")
348
+
349
+ train.to_csv(self.split_dir / "train.csv", index=False)
350
+ valid.to_csv(self.split_dir / "valid.csv", index=False)
351
+ test.to_csv(self.split_dir / "test.csv", index=False)
352
+
353
+ def prepare_data(self, **kwargs):
354
+
355
+ if kwargs.get('load', False):
356
+ return
357
+
358
+ if self.exp_dir.exists():
359
+ logger.info(f"Experiment directory {self.exp_dir} already exists. All existing files "
360
+ f" will be kept. To create new data/split remove {self.exp_data_dir} or {self.split_dir}")
361
+ self.exp_dir.mkdir(parents=True, exist_ok=True)
362
+
363
+ # Make train-test split
364
+ default_split_kwargs = {'seed': 777,
365
+ 'valid_fraction': 0.1,
366
+ 'test_fraction': 0.1,
367
+ }
368
+ # update from kwargs
369
+ for key in default_split_kwargs.keys():
370
+ if key in kwargs:
371
+ default_split_kwargs[key] = kwargs[key]
372
+ # Create new split or skip if exist
373
+ self.prepare_data_split(**default_split_kwargs)
374
+
375
+ # Prepare smiles (simplified - only tokenized smiles)
376
+ self.molecular_dataloader.prepare_molecular_features()
377
+
378
+ def setup(self, stage=None):
379
+ """Loads and prepares the datasets for a given stage.
380
+
381
+ This method is called by PyTorch Lightning. It performs the following steps:
382
+ 1. Loads molecular features (tokenized SMILES).
383
+ 2. Loads protein features (pre-computed ESM embeddings).
384
+ 3. Loads data splits (train/val/test) from CSV files.
385
+ 4. Initializes the PyTorch Datasets for each split.
386
+
387
+ Args:
388
+ stage (str, optional): The stage to setup ('fit', 'validate', 'test', 'predict').
389
+ """
390
+ logger.info("Loading molecular features")
391
+
392
+ # Load molecular features (simplified - only SMILES)
393
+ self.molecular_dataloader.load_molecular_features()
394
+
395
+ # Load protein features (only ESM embeddings)
396
+ logger.info(f"Loading protein features {self.esm_model_name}")
397
+ prot_embbeding_pt = self.data_dir / f'all_prots_{self.esm_model_name}.pt'
398
+
399
+ if prot_embbeding_pt.exists():
400
+ self.idx_to_sequence_data = torch.load(prot_embbeding_pt, map_location='cpu', weights_only=False)
401
+ length_dict = {idx: emb.shape[0] for idx, emb in self.idx_to_sequence_data.items()}
402
+ self.sequence_embedding_dim = next(iter(self.idx_to_sequence_data.values())).shape[1]
403
+ else:
404
+ raise FileNotFoundError(
405
+ f"Packed proteins `all_prots_{self.esm_model_name}.pt` is not found in {self.data_dir}")
406
+
407
+ # load data. Use integer dtypes for categorical features and float for labels.
408
+ logger.info("Loading activity table")
409
+
410
+ self.datasets = dict(zip(["train", "val", "test"],
411
+ [self._read_df_and_compute_sequence_lengths(self.split_dir / f"{split}.csv",
412
+ length_dict)
413
+ for split in ["train", "valid", "test"]]))
414
+
415
+
416
+ # initialise self.train_dataset, self.val_dataset, self.test_dataset
417
+ for ds in ['train', 'val', 'test']:
418
+ df_ds = self.datasets[ds]
419
+ assert len(df_ds) > 0, f"{ds=} is empty"
420
+ ds_proto = self.create_dataset(df_ds)
421
+ ds_proto._is_train = (ds == 'train')
422
+ self.torch_datasets[ds] = ds_proto
423
+
424
+ def create_dataset(self, df, **kwargs):
425
+ dataset_kwargs = self.molecular_dataloader.dataset_kwargs
426
+ dataset_class = DatasetMolecularEmbeddings
427
+
428
+ cluster_smi = None
429
+ sample_smiles = dataset_kwargs.get('sample_smiles', False)
430
+ if sample_smiles:
431
+ cluster_smi = df['cluster_smi'].values
432
+
433
+ logger.info(f"Creating dataset: using {dataset_class=}")
434
+ ds_proto = dataset_class(
435
+ sequence_embedding=(self.idx_to_sequence_data),
436
+ smiles_embeddings=self.molecular_dataloader.get_features(),
437
+ sequences=df['sequence'].values,
438
+ sequences_length=df['sequence_length'].values,
439
+ smiles=df['smiles'].values,
440
+ dtype=self.float_type,
441
+ cluster_smi=cluster_smi,
442
+ **dataset_kwargs,
443
+ **kwargs,
444
+ )
445
+ return ds_proto
446
+
447
+ def get_dataloader(self, dataset, shuffle, use_sampler=True, pin_memory=True):
448
+ if use_sampler:
449
+ sampler = RandomizedBatchSampler(sequence_length=dataset.sequences_length,
450
+ shuffle=shuffle,
451
+ batch_volume=self.batch_volume,
452
+ batch_size=self.max_size_batch)
453
+ return DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.collate_fn,
454
+ num_workers=self.num_workers, pin_memory=pin_memory)
455
+ else:
456
+ return DataLoader(dataset=dataset, collate_fn=dataset.collate_fn, batch_size=self.max_size_batch,
457
+ num_workers=self.num_workers, pin_memory=pin_memory, shuffle=shuffle)
458
+
459
+ def train_dataloader(self, use_sampler=True, shuffle=True):
460
+ return self.get_dataloader(self.torch_datasets['train'], shuffle=shuffle, use_sampler=use_sampler)
461
+
462
+ def val_dataloader(self, use_sampler=True, shuffle=False):
463
+ return self.get_dataloader(self.torch_datasets['val'], shuffle=shuffle, use_sampler=use_sampler)
464
+
465
+ def test_dataloader(self, use_sampler=True, shuffle=False):
466
+ return self.get_dataloader(self.torch_datasets['test'], shuffle=shuffle, use_sampler=use_sampler)
467
+
468
+ def predict_dataloader(self, dataset='test', use_sampler=False, shuffle=False):
469
+ return self.get_dataloader(self.torch_datasets[dataset], shuffle=shuffle, use_sampler=use_sampler)
470
+
471
+ def get_smiles_embedding_dim(self):
472
+ return self.molecular_dataloader.embedding_size
473
+
474
+ def get_sequence_embedding_dim(self):
475
+ return self.sequence_embedding_dim
476
+
477
+
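A hypothetical end-to-end setup of the DataModule above. The directory layout is a placeholder, and the data files named in the comments (including an `all_smiles_sparse_*.npz` fingerprint file used for clustering) must already be present for `prepare_data` to succeed.

```python
# Hypothetical setup; paths and sizes are placeholders, not values from this commit.
from pathlib import Path

dm = ProtobindDataModule(
    data_dir=Path("data"),                    # needs data.csv, categorical_mappings.json,
                                              # all_prots_esm2_t33_650M_UR50D.pt, tokenizer JSON,
                                              # and an all_smiles_sparse_*.npz file
    exp_dir=Path("experiments/random/run1"),  # train/valid/test CSVs go to the parent directory
    splitting_method=SplittingMethod.RANDOM,
    batch_volume=300_000,                     # max sum of sequence_length**2 per batch
    num_workers=4,
)
dm.prepare_data()
dm.setup()
train_loader = dm.train_dataloader()
```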
478
+ class DatasetNumpy(Dataset):
479
+ """ Dataset for feeding model with sequences and ligands embeddings """
480
+
481
+ def __init__(self, *, sequence_embedding: Tuple[np.array, np.array],
482
+ smiles_embeddings: np.ndarray,
483
+ sequences: np.ndarray,
484
+ sequences_length: np.ndarray,
485
+ smiles: np.ndarray,
486
+ dtype='float16',
487
+ **kwargs,
488
+ ):
489
+ """
490
+ Args:
491
+ sequence_embedding: embedding for sequences - 1 per each sequence
492
+ smiles_embeddings: embedding for smiles - 1 per each smile
493
+ sequences: sequence label in the dataset - 1 per sample
494
+ sequences_length: sequence length in the dataset - 1 per sample
495
+ smiles: smile label in the dataset - 1 per sample
496
+ """
497
+ assert len(sequences) == len(sequences_length), f"{len(sequences)=} {len(sequences_length)=}"
498
+ assert len(sequences) == len(smiles), f"{len(sequences)=} {len(smiles)=}"
499
+
500
+ self.data_sequence = sequence_embedding
501
+ self.smiles_embeddings = self.init_smiles_embeddings(smiles_embeddings)
502
+ self.sequences_length = sequences_length
503
+ self.sequences = sequences
504
+ self.smiles = smiles
505
+ self.float_type = getattr(torch, dtype)
506
+
507
+ # Only support ESM embeddings (float type)
508
+ self.sequence_dtype = self.float_type
509
+ self._is_train = False # this attribute is assigned during ProtobindDataModule.setup()
510
+
511
+ # SMILES SAMPLER
512
+ sample_smiles = kwargs.get('sample_smiles', False)
513
+ self.cluster_smiles = kwargs.get('cluster_smi', None)
514
+ self.smiles_to_cluster = None
515
+ if sample_smiles:
516
+ self.group_smiles(self.cluster_smiles)
517
+ self.get_smiles_id = self._smiles_id_sample
518
+ else:
519
+ self.get_smiles_id = self._smiles_id_as_ind
520
+
521
+ def init_smiles_embeddings(self, smiles_embeddings):
522
+ return smiles_embeddings
523
+
524
+ def group_smiles(self, clusters):
525
+ """ For each sequence, group similar SMILES into a list for random sampling during training """
526
+
527
+ len_ = len(self.sequences)
528
+ df = pd.DataFrame(data={'smiles': self.smiles, 'sequence': self.sequences, 'cluster_smi': clusters,
529
+ 'sequences_length': self.sequences_length}
530
+ ).groupby(['cluster_smi', 'sequence', 'sequences_length'], as_index=False).agg(list)
531
+ self.smiles_to_cluster = df['smiles'].values
532
+ self.sequences = df['sequence'].values
533
+ self.cluster_smiles = df['cluster_smi'].values
534
+ self.sequences_length = df['sequences_length'].values
535
+ logger.info(f"Sampling from similar smiles is ON, dataset size reduced from {len_} to {len(self.sequences)}")
536
+
537
+ def _smiles_id_as_ind(self, idx: int) -> int:
538
+ """ Get the SMILES id from the array self.smiles """
539
+ return self.smiles[idx]
540
+
541
+ def _smiles_id_sample(self, idx) -> int:
542
+ """ Sample a SMILES id from the grouped SMILES of the same cluster """
543
+ return np.random.choice(self.smiles_to_cluster[idx])
544
+
545
+ def __len__(self) -> int:
546
+ # the number of entries in the dataset
547
+ return len(self.sequences)
548
+
549
+ def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray, int]:
550
+
551
+ seq_id = self.sequences[idx]
552
+ smi_id = self.get_smiles_id(idx)
553
+
554
+ return (self.parametrize_sequence(seq_id),
555
+ self.parametrize_smiles(smi_id),
556
+ self.sequences_length[idx])
557
+
558
+ def parametrize_smiles(self, smiles_id: int) -> np.array:
559
+ return self.smiles_embeddings[smiles_id]
560
+
561
+ def parametrize_sequence(self, sequence_id: int) -> np.array:
562
+ return self.data_sequence[sequence_id]
563
+
564
+ @staticmethod
565
+ def _collate_fn_pack(batch):
566
+ """ Unzip dataset samples into tuples of sequences, smiles and sequence_lengths """
567
+ return zip(*batch)
568
+
569
+ def _pad_sequence(self, sequences: List[np.ndarray]) -> torch.Tensor:
570
+ return pad_sequence([torch.tensor(s, dtype=self.sequence_dtype) for s in sequences], batch_first=True,
571
+ padding_value=ProtobindDataModule.MASK_VALUE)
572
+
573
+ def collate_fn(self, batch: Tuple[np.ndarray, np.ndarray, int ]) -> Tuple[
574
+ Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
575
+ """Collates samples into a single batch, padding sequences to the same length.
576
+
577
+ Args:
578
+ batch : A tuple of samples, where each sample is the output of `__getitem__`.
579
+
580
+ Returns:
581
+ Tuple: A tuple containing batched tensors:
582
+ - ((torch.Tensor, torch.Tensor)): A tuple of padded protein sequences
583
+ and a tensor of their original lengths.
584
+ - (torch.Tensor): A batch of SMILES embeddings.
585
+ """
586
+
587
+ sequences, smiles, sequence_lengths = self._collate_fn_pack(batch)
588
+
589
+ padded_sequences = self._pad_sequence(sequences)
590
+
591
+ return ((padded_sequences, torch.tensor(sequence_lengths, dtype=torch.int32)),
592
+ torch.tensor(np.array(smiles), dtype=self.float_type))
593
+
594
+
595
+ class DatasetMolecularEmbeddings(DatasetNumpy):
596
+ """A dataset for masked diffusion models using protein embeddings and tokenized SMILES.
597
+
598
+ This class extends `DatasetNumpy` to handle variable-length, tokenized SMILES
599
+ representations from a `RandomizedSmilesDataset`. It overrides methods for
600
+ SMILES parameterization and batch collation to support this token-based approach,
601
+ which is required for diffusion models.
602
+ """
603
+
604
+ def parametrize_smiles(self, smiles_id: int) -> Tuple[np.array, int]:
605
+ mol = self.smiles_embeddings[smiles_id]
606
+ return mol, len(mol)
607
+
608
+ def __getitem__(self, idx) -> Tuple[np.ndarray, np.array, int, int, int, int]:
609
+ """Retrieves a single data sample with tokenized SMILES.
610
+
611
+ Unlike the parent class, this method returns the tokenized SMILES
612
+ and its length instead of a fixed-size embedding.
613
+ """
614
+ seq_id = self.sequences[idx]
615
+ smi_id = self.smiles[idx]
616
+ return (self.parametrize_sequence(seq_id),) + self.parametrize_smiles(smi_id) + (
617
+ self.sequences_length[idx], seq_id, smi_id)
618
+
619
+ def init_smiles_embeddings(self, smiles_embeddings):
620
+ if isinstance(smiles_embeddings, RandomizedSmilesDataset):
621
+ return smiles_embeddings
622
+ else:
623
+ raise ValueError("This version only supports RandomizedSmilesDataset")
624
+
625
+ def collate_fn(self, batch: List[Tuple[np.ndarray, np.array, int, int, int, int]]) -> Tuple[
626
+ Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
627
+ torch.Tensor, torch.Tensor]:
628
+
629
+ """Collates samples into a batch, padding both protein and SMILES sequences.
630
+
631
+ Args:
632
+ batch (list): A list of samples, where each sample is the output of __getitem__.
633
+
634
+ Returns:
635
+ Tuple: A tuple containing the final batched tensors for the model:
636
+ - ((torch.Tensor, torch.Tensor)): Padded protein sequences and their lengths.
637
+ - ((torch.Tensor, torch.Tensor)): Padded tokenized SMILES and their lengths.
638
+ - (torch.Tensor): A batch of sequence IDs.
639
+ - (torch.Tensor): A batch of SMILES IDs.
640
+ """
641
+
642
+ sequences, atom, atom_lengths, sequence_lengths, seq_id, smi_id \
643
+ = self._collate_fn_pack(batch)
644
+
645
+ padded_sequences = self._pad_sequence(sequences) # pad protein sequences
646
+ padded_atom = pad_sequence([s.to(dtype=self.float_type) for s in atom], batch_first=True)
647
+ atom_lengths = torch.tensor(atom_lengths, dtype=torch.int32)
648
+
649
+ return ((padded_sequences, torch.tensor(sequence_lengths, dtype=torch.int32)),
650
+ (padded_atom, atom_lengths),
651
+ torch.tensor(seq_id, dtype=torch.int32),
652
+ torch.tensor(smi_id, dtype=torch.int32),
653
+ )
654
+
655
+
656
+ class MolecularDataloaderSMILES(object):
657
+ """
658
+ molecular dataloader that only supports tokenized SMILES
659
+ with ChemformerTokenizer for masked diffusion models.
660
+ """
661
+
662
+ def __init__(self, *,
663
+ data_dir: Path,
664
+ dataset_options: Optional[dict] = None):
665
+ """
666
+ Args:
667
+ data_dir: path to data folder containing tokenizer files and dict with all smiles and fasta sequences
668
+ dataset_options: dictionary with additional parameters used to create pytorch Dataset
669
+ """
670
+ self.data_dir = data_dir
671
+ if dataset_options is None:
672
+ logger.info('Setting tokenizer file name to tokenizer_smiles_diffusion.json')
673
+ dataset_options = {'tokenizer_json_name': 'tokenizer_smiles_diffusion'}
674
+ self.dataset_options = dataset_options
675
+
676
+ self.tokenizer_path = self.data_dir / f"{dataset_options['tokenizer_json_name']}.json"
677
+ self.tokenizer = ChemformerTokenizer(filename=str(self.tokenizer_path))
678
+ self.randomize = dataset_options.get('randomize', False)
679
+ self.smiles_embedding_dim = 1 # For tokenized SMILES, embedding dim is 1
680
+ self.baseline_dim = 0 # this version doesn't support baseline features
681
+
682
+ def prepare_molecular_features(self):
683
+ """Prepare molecular features"""
684
+ if not self.tokenizer_path.exists():
685
+ raise FileNotFoundError(
686
+ f"Could not find tokenizer at {self.tokenizer_path}. Please ensure the tokenizer file exists.")
687
+ logger.info(f"Found ChemformerTokenizer at {self.tokenizer_path}")
688
+
689
+ def load_molecular_features(self):
690
+ """Load molecular features - loads SMILES mappings"""
691
+ categorical_mappings_path = self.data_dir / 'categorical_mappings.json'
692
+ if not categorical_mappings_path.exists():
693
+ raise FileNotFoundError(f"categorical_mappings.json not found in data_dir: {self.data_dir}")
694
+
695
+ self.smiles_dataset = RandomizedSmilesDataset.from_json(
696
+ categorical_mappings_path,
697
+ tokenizer=self.tokenizer,
698
+ randomize=self.randomize
699
+ )
700
+
701
+ def get_features(self):
702
+ """Get the SMILES dataset for tokenized molecular features"""
703
+ return self.smiles_dataset
704
+
705
+ @property
706
+ def dataset_kwargs(self):
707
+ """Return dataset options for creating pytorch datasets"""
708
+ return self.dataset_options
709
+
710
+ @property
711
+ def embedding_size(self):
712
+ """Get embedding size for tokenized SMILES"""
713
+ return self.smiles_embedding_dim
714
+
715
+
716
+ class InferenceDataset(Dataset):
717
+ """Creates a dataset for running inference on a single protein embedding.
718
+
719
+ This utility dataset repeatedly yields the same batch, created by expanding
720
+ a single input embedding. It's designed for generating a large number of
721
+ ligand samples for one protein target without a traditional dataset structure.
722
+ """
723
+ def __init__(self, embedding: torch.Tensor, batch_size: int, n_batches: int):
724
+ """Initializes the inference dataset.
725
+
726
+ Args:
727
+ embedding (torch.Tensor): The single protein embedding tensor to be used.
728
+ batch_size (int): The number of times to repeat the embedding in each batch.
729
+ n_batches (int): The total number of identical batches the dataset should yield.
730
+ """
731
+ self.embedding_single = embedding
732
+ self.batch_size = batch_size
733
+ self.n_batches = n_batches
734
+ self.seq_len = embedding.shape[1]
735
+
736
+ def __len__(self) -> int:
737
+ return self.n_batches
738
+
739
+ def __getitem__(self, idx: int) -> Tuple:
740
+ """Generates a full batch ready for model inference.
741
+
742
+ Note: This method ignores the `idx` argument and always returns the same
743
+ batch, which is constructed by expanding the stored protein embedding.
744
+ It includes dummy values to match the data structure expected by the model.
745
+
746
+ Returns:
747
+ Tuple: A tuple containing pre-batched tensors:
748
+ - ((torch.Tensor, torch.Tensor)): Expanded protein embeddings and their lengths.
749
+ - (torch.Tensor): A dummy NaN tensor (placeholder for SMILES).
750
+ - (torch.Tensor): A batch of placeholder sequence IDs (-1).
751
+ - (torch.Tensor): A dummy NaN tensor (placeholder for smiles IDs).
752
+ """
753
+ embedding = self.embedding_single.expand(self.batch_size, -1, -1).contiguous()
754
+ lengths = torch.full((self.batch_size,), self.seq_len, dtype=torch.int32)
755
+ seq_ids = torch.full((self.batch_size,), -1, dtype=torch.int32) # seq_ids don't exist for new sequences
756
+ return (
757
+ (embedding, lengths),
758
+ torch.tensor(float('nan')),
759
+ seq_ids,
760
+ torch.tensor(float('nan')),
761
+ )
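A short sketch of how the inference dataset above is consumed (a dummy tensor stands in for real ESM output): every item is already a full batch, so the DataLoader is used with `batch_size=None`, exactly as in `app.py`.

```python
# Sketch with a dummy embedding; shapes follow the esm2_t33_650M_UR50D default (dim 1280).
import torch
from torch.utils.data import DataLoader

embedding = torch.randn(1, 128, 1280)                              # (1, seq_len, esm_dim)
dataset = InferenceDataset(embedding, batch_size=10, n_batches=5)
loader = DataLoader(dataset, batch_size=None)                      # items are pre-batched
(prot_emb, lengths), _, seq_ids, _ = next(iter(loader))
# prot_emb: (10, 128, 1280); lengths: ten 128s; seq_ids: ten -1 placeholders
```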
protobind_diff/decoder_rope.py ADDED
@@ -0,0 +1,769 @@
1
+ import math
2
+ from math import pi
3
+ import typing
4
+ from typing import Tuple, Optional, Literal
5
+
6
+ from einops import rearrange, repeat
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.amp import autocast
12
+ from torch.nn import Module, ModuleList
13
+ from torch import nn, einsum, broadcast_tensors, Tensor
14
+
15
+
16
+
17
+ #################################################################################
18
+ # Rotary Encoding #
19
+ #################################################################################
20
+
21
+ # helper functions
22
+
23
+ def exists(val):
24
+ return val is not None
25
+
26
+
27
+ def default(val, d):
28
+ return val if exists(val) else d
29
+
30
+ def slice_at_dim(t, dim_slice: slice, *, dim):
31
+ dim += (t.ndim if dim < 0 else 0)
32
+ colons = [slice(None)] * t.ndim
33
+ colons[dim] = dim_slice
34
+ return t[tuple(colons)]
35
+
36
+ # rotary embedding helper functions
37
+
38
+ def rotate_half(x):
39
+ """Splits the last dimension of a tensor, swaps halves, and negates the first half."""
40
+ x = rearrange(x, '... (d r) -> ... d r', r=2)
41
+ x1, x2 = x.unbind(dim=-1)
42
+ x = torch.stack((-x2, x1), dim=-1)
43
+ return rearrange(x, '... d r -> ... (d r)')
44
+
45
+
46
+ @autocast('cuda', enabled=False)
47
+ def apply_rotary_emb(
48
+ freqs,
49
+ t,
50
+ start_index=0,
51
+ scale=1.,
52
+ seq_dim=-2,
53
+ freqs_seq_dim=None
54
+ ):
55
+ """Applies rotary positional embeddings to a given tensor.
56
+
57
+ Args:
58
+ freqs (torch.Tensor): The rotary frequencies.
59
+ t (torch.Tensor): The tensor to apply embeddings to (e.g., queries or keys).
60
+ start_index (int): The feature dimension index to start applying rotations from.
61
+ scale (float): A scaling factor, used for xPos.
62
+ seq_dim (int): The sequence dimension of the input tensor `t`.
63
+ freqs_seq_dim (Optional[int]): The sequence dimension of the freqs tensor.
64
+ """
65
+ dtype = t.dtype
66
+
67
+ if not exists(freqs_seq_dim):
68
+ if freqs.ndim == 2 or t.ndim == 3:
69
+ freqs_seq_dim = 0
70
+
71
+ if t.ndim == 3 or exists(freqs_seq_dim):
72
+ seq_len = t.shape[seq_dim]
73
+ freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)
74
+
75
+ rot_dim = freqs.shape[-1]
76
+ end_index = start_index + rot_dim
77
+
78
+ assert rot_dim <= t.shape[
79
+ -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
80
+
81
+ # Split t into three parts: left, middle (to be transformed), and right
82
+ t_left = t[..., :start_index]
83
+ t_middle = t[..., start_index:end_index]
84
+ t_right = t[..., end_index:]
85
+
86
+ # Apply rotary embeddings without modifying t in place
87
+ t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
88
+
89
+ out = torch.cat((t_left, t_transformed, t_right), dim=-1)
90
+
91
+ return out.type(dtype)
92
+
93
+
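For reference, the transform above is the standard RoPE rotation written in "rotate-half" form: for each feature pair \((x_{2i}, x_{2i+1})\) at position \(m\) with frequency \(\theta_i\),

$$
\begin{pmatrix} x_{2i}' \\ x_{2i+1}' \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},
$$

which is what `t_middle * freqs.cos() + rotate_half(t_middle) * freqs.sin()` computes, since `rotate_half` maps \((x_{2i}, x_{2i+1}) \mapsto (-x_{2i+1}, x_{2i})\); the optional xPos `scale` simply multiplies both terms.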
94
+ # learned rotation helpers
95
+
96
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
97
+ if exists(freq_ranges):
98
+ rotations = einsum('..., f -> ... f', rotations, freq_ranges)
99
+ rotations = rearrange(rotations, '... r f -> ... (r f)')
100
+
101
+ rotations = repeat(rotations, '... n -> ... (n r)', r=2)
102
+ return apply_rotary_emb(rotations, t, start_index=start_index)
103
+
104
+
105
+ # classes
106
+
107
+ class RotaryEmbedding(Module):
108
+ """
109
+ original paper: https://arxiv.org/abs/2104.09864
110
+ rescale rotary embeddings to longer sequence length without fine-tuning
111
+ code source: https://github.com/lucidrains/rotary-embedding-torch
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ dim,
117
+ custom_freqs: Tensor | None = None,
118
+ freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
119
+ theta=10000,
120
+ max_freq=10,
121
+ num_freqs=1,
122
+ learned_freq=False,
123
+ use_xpos=False,
124
+ xpos_scale_base=512,
125
+ interpolate_factor=1.,
126
+ theta_rescale_factor=1.,
127
+ seq_before_head_dim=False,
128
+ cache_if_possible=True,
129
+ cache_max_seq_len=8192
130
+ ):
131
+ super().__init__()
132
+ """Initializes the RotaryEmbedding module.
133
+
134
+ Args:
135
+ dim (int): The feature dimension to apply rotary embeddings to.
136
+ custom_freqs ([Tensor]): An optional tensor of custom frequencies.
137
+ freqs_for : The method for generating
138
+ frequencies. 'lang' is standard for transformers.
139
+ theta (int): A core hyperparameter for frequency calculation.
140
+ learned_freq (bool): If True, the frequencies are trainable parameters.
141
+ use_xpos (bool): If True, enables the xPos (extrapolatable) variant.
142
+ interpolate_factor (float): A factor for positional interpolation, which
143
+ can help with length generalization.
144
+ cache_if_possible (bool): If True, caches calculated frequencies for efficiency.
145
+ """
146
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
147
+
148
+ self.freqs_for = freqs_for
149
+
150
+ if exists(custom_freqs):
151
+ freqs = custom_freqs
152
+ elif freqs_for == 'lang':
153
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
154
+ elif freqs_for == 'pixel':
155
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
156
+ elif freqs_for == 'constant':
157
+ freqs = torch.ones(num_freqs).float()
158
+
159
+ self.cache_if_possible = cache_if_possible
160
+ self.cache_max_seq_len = cache_max_seq_len
161
+
162
+ self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent=False)
163
+ self.cached_freqs_seq_len = 0
164
+
165
+ self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
166
+
167
+ self.learned_freq = learned_freq
168
+
169
+ # dummy for device
170
+
171
+ self.register_buffer('dummy', torch.tensor(0), persistent=False)
172
+
173
+ # default sequence dimension
174
+
175
+ self.seq_before_head_dim = seq_before_head_dim
176
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
177
+
178
+ # interpolation factors
179
+
180
+ assert interpolate_factor >= 1.
181
+ self.interpolate_factor = interpolate_factor
182
+
183
+ # xpos
184
+
185
+ self.use_xpos = use_xpos
186
+
187
+ if not use_xpos:
188
+ return
189
+
190
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
191
+ self.scale_base = xpos_scale_base
192
+
193
+ self.register_buffer('scale', scale, persistent=False)
194
+ self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent=False)
195
+ self.cached_scales_seq_len = 0
196
+
197
+ # add apply_rotary_emb as static method
198
+
199
+ self.apply_rotary_emb = staticmethod(apply_rotary_emb)
200
+
201
+ @property
202
+ def device(self):
203
+ return self.dummy.device
204
+
205
+ def get_seq_pos(self, seq_len, device, dtype, offset=0):
206
+ return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
207
+
208
+ def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
209
+ """Applies rotary embeddings to a single tensor (queries or keys).
210
+
211
+ Args:
212
+ t (torch.Tensor): The input tensor (queries or keys).
213
+ seq_dim : The sequence dimension of the tensor.
214
+ offset (int): An offset for the position sequence, used for caching.
215
+ scale (Optional[float]): A scaling factor, required if using xPos.
216
+
217
+ Returns:
218
+ torch.Tensor: The tensor with rotary embeddings applied.
219
+ """
220
+ seq_dim = default(seq_dim, self.default_seq_dim)
221
+
222
+ assert not self.use_xpos or exists(
223
+ scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
224
+
225
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
226
+
227
+ seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
228
+
229
+ freqs = self.forward(seq, seq_len=seq_len, offset=offset)
230
+
231
+ if seq_dim == -3:
232
+ freqs = rearrange(freqs, 'n d -> n 1 d')
233
+
234
+ return apply_rotary_emb(freqs, t, scale=default(scale, 1.), seq_dim=seq_dim)
235
+
236
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
237
+ dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)
238
+
239
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
240
+ assert q_len <= k_len
241
+
242
+ q_scale = k_scale = 1.
243
+
244
+ if self.use_xpos:
245
+ seq = self.get_seq_pos(k_len, dtype=dtype, device=device)
246
+
247
+ q_scale = self.get_scale(seq[-q_len:]).type(dtype)
248
+ k_scale = self.get_scale(seq).type(dtype)
249
+
250
+ rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
251
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale ** -1)
252
+
253
+ rotated_q = rotated_q.type(q.dtype)
254
+ rotated_k = rotated_k.type(k.dtype)
255
+
256
+ return rotated_q, rotated_k
257
+
258
+ def rotate_queries_and_keys(self, q, k, seq_dim=None):
259
+ seq_dim = default(seq_dim, self.default_seq_dim)
260
+
261
+ assert self.use_xpos
262
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
263
+
264
+ seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
265
+
266
+ freqs = self.forward(seq, seq_len=seq_len)
267
+ scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
268
+
269
+ if seq_dim == -3:
270
+ freqs = rearrange(freqs, 'n d -> n 1 d')
271
+ scale = rearrange(scale, 'n d -> n 1 d')
272
+
273
+ rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
274
+ rotated_k = apply_rotary_emb(freqs, k, scale=scale ** -1, seq_dim=seq_dim)
275
+
276
+ rotated_q = rotated_q.type(q.dtype)
277
+ rotated_k = rotated_k.type(k.dtype)
278
+
279
+ return rotated_q, rotated_k
280
+
281
+ def get_scale(
282
+ self,
283
+ t: Tensor,
284
+ seq_len = None,
285
+ offset=0
286
+ ):
287
+ assert self.use_xpos
288
+
289
+ should_cache = (
290
+ self.cache_if_possible and
291
+ exists(seq_len) and
292
+ (offset + seq_len) <= self.cache_max_seq_len
293
+ )
294
+
295
+ if (
296
+ should_cache and \
297
+ exists(self.cached_scales) and \
298
+ (seq_len + offset) <= self.cached_scales_seq_len
299
+ ):
300
+ return self.cached_scales[offset:(offset + seq_len)]
301
+
302
+ scale = 1.
303
+ if self.use_xpos:
304
+ power = (t - len(t) // 2) / self.scale_base
305
+ scale = self.scale ** rearrange(power, 'n -> n 1')
306
+ scale = repeat(scale, 'n d -> n (d r)', r=2)
307
+
308
+ if should_cache and offset == 0:
309
+ self.cached_scales[:seq_len] = scale.detach()
310
+ self.cached_scales_seq_len = seq_len
311
+
312
+ return scale
313
+
314
+ def get_axial_freqs(self, *dims):
315
+ Colon = slice(None)
316
+ all_freqs = []
317
+
318
+ for ind, dim in enumerate(dims):
319
+ if self.freqs_for == 'pixel':
320
+ pos = torch.linspace(-1, 1, steps=dim, device=self.device)
321
+ else:
322
+ pos = torch.arange(dim, device=self.device)
323
+
324
+ freqs = self.forward(pos, seq_len=dim)
325
+
326
+ all_axis = [None] * len(dims)
327
+ all_axis[ind] = Colon
328
+
329
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
330
+ all_freqs.append(freqs[new_axis_slice])
331
+
332
+ all_freqs = broadcast_tensors(*all_freqs)
333
+ return torch.cat(all_freqs, dim=-1)
334
+
335
+ @autocast('cuda', enabled=False)
336
+ def forward(
337
+ self,
338
+ t: Tensor,
339
+ seq_len = None,
340
+ offset=0
341
+ ):
342
+ """Calculates the rotary frequencies for a given sequence of positions.
343
+
344
+ Args:
345
+ t (torch.Tensor): A tensor of position indices.
346
+ seq_len (int): The total sequence length, used for caching.
347
+ offset (int): The starting position offset.
348
+
349
+ Returns:
350
+ torch.Tensor: A tensor of calculated rotation frequencies.
351
+ """
352
+ should_cache = (
353
+ self.cache_if_possible and
354
+ not self.learned_freq and
355
+ exists(seq_len) and
356
+ self.freqs_for != 'pixel' and
357
+ (offset + seq_len) <= self.cache_max_seq_len
358
+ )
359
+
360
+ if (
361
+ should_cache and \
362
+ exists(self.cached_freqs) and \
363
+ (offset + seq_len) <= self.cached_freqs_seq_len
364
+ ):
365
+ return self.cached_freqs[offset:(offset + seq_len)].detach()
366
+
367
+ freqs = self.freqs
368
+
369
+ freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
370
+ freqs = repeat(freqs, '... n -> ... (n r)', r=2)
371
+
372
+ if should_cache and offset == 0:
373
+ self.cached_freqs[:seq_len] = freqs.detach()
374
+ self.cached_freqs_seq_len = seq_len
375
+
376
+ return freqs
377
+
378
+
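A small, hypothetical call pattern for the class above (shapes are made up): with the default `seq_before_head_dim=False`, tensors are expected as `(batch, heads, seq_len, head_dim)`, and only the first `dim` features of each head are rotated.

```python
# Hypothetical shapes, just to show the call pattern; not taken from this repository.
import torch

rope = RotaryEmbedding(dim=32)        # rotate the first 32 features of each 64-dim head
q = torch.randn(2, 8, 256, 64)        # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 256, 64)
q = rope.rotate_queries_or_keys(q)    # positions are now encoded as rotations in q and k
k = rope.rotate_queries_or_keys(k)
```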
379
+ #################################################################################
380
+ # Multi Head Attention #
381
+ #################################################################################
382
+
383
+ class LayerNorm(nn.Module):
384
+ """Implements a Layer Normalization module."""
385
+ def __init__(self, d_model, eps=1e-12):
386
+ """Initializes the LayerNorm module.
387
+
388
+ Args:
389
+ d_model (int): The dimension of the model's features.
390
+ eps (float): A small value added to the variance for numerical stability.
391
+ """
392
+ super(LayerNorm, self).__init__()
393
+ self.gamma = nn.Parameter(torch.ones(d_model))
394
+ self.beta = nn.Parameter(torch.zeros(d_model))
395
+ self.eps = eps
396
+
397
+ def forward(self, x):
398
+ """Applies Layer Normalization to the input tensor along the last dimension.
399
+ Args:
400
+ x (torch.Tensor): The input tensor to normalize.
401
+ Returns:
402
+ torch.Tensor: The normalized tensor.
403
+ """
404
+ mean = x.mean(-1, keepdim=True)
405
+ var = x.var(-1, unbiased=False, keepdim=True)
406
+ # '-1' means last dimension.
407
+
408
+ out = (x - mean) / torch.sqrt(var + self.eps)
409
+ out = self.gamma * out + self.beta
410
+ return out
411
+
412
+
413
+ class PositionwiseFeedForward(nn.Module):
414
+ """Implements the Position-wise Feed-Forward network of a Transformer block."""
415
+
416
+ def __init__(self, d_model, hidden, drop_prob=0.1):
417
+ """Initializes the PositionwiseFeedForward module.
418
+
419
+ Args:
420
+ d_model (int): The input and output dimension of the layer.
421
+ hidden (int): The dimension of the inner hidden layer.
422
+ drop_prob (float): The probability for the dropout layer.
423
+ """
424
+ super(PositionwiseFeedForward, self).__init__()
425
+ self.linear1 = nn.Linear(d_model, hidden)
426
+ self.linear2 = nn.Linear(hidden, d_model)
427
+ self.relu = nn.ReLU()
428
+ self.dropout = nn.Dropout(p=drop_prob)
429
+
430
+ def forward(self, x):
431
+ """Passes the input through the feed-forward network.
432
+ The process is: Linear -> ReLU -> Dropout -> Linear.
433
+ Args:
434
+ x (torch.Tensor): The input tensor.
435
+ Returns:
436
+ torch.Tensor: The output tensor.
437
+ """
438
+ x = self.linear1(x)
439
+ x = self.relu(x)
440
+ x = self.dropout(x)
441
+ x = self.linear2(x)
442
+ return x
443
+
444
+
445
+ class ScaleDotProductAttention(nn.Module):
446
+
447
+ def __init__(self):
448
+ super(ScaleDotProductAttention, self).__init__()
449
+ self.softmax = nn.Softmax(dim=-1)
450
+
451
+ def forward(self, q, k, v, mask=None, e=1e-12):
452
+ """
453
+ Performs the Scaled Dot-Product Attention calculation.
454
+
455
+ Args:
456
+ q (torch.Tensor): The query tensor.
457
+ k (torch.Tensor): The key tensor.
458
+ v (torch.Tensor): The value tensor.
459
+ mask (torch.Tensor, optional): A mask to prevent attention to
460
+ certain positions. Defaults to None.
461
+
462
+ Returns:
463
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing the attention
464
+ output and the attention scores.
465
+ """
466
+ batch_size, head, length, d_tensor = k.size()
467
+ k_t = k.transpose(2, 3) # transpose
468
+ score = (q @ k_t) / math.sqrt(d_tensor) # scaled dot product
469
+ if mask is not None:
470
+ score = score.masked_fill(mask == 0, -10000)
471
+ score = self.softmax(score)
472
+ v = score @ v
473
+ return v, score
474
+
475
+
476
+ class MultiHeadAttention(nn.Module):
477
+ """Implements a Multi-Head Attention layer with optional Rotary Position Embeddings."""
478
+
479
+ def __init__(self, d_model, n_head):
480
+ """Initializes the MultiHeadAttention layer.
481
+
482
+ Args:
483
+ d_model (int): The total dimension of the model.
484
+ n_head (int): The number of attention heads. d_model must be divisible by n_head.
485
+ """
486
+ super(MultiHeadAttention, self).__init__()
487
+ self.n_head = n_head
488
+ self.attention = ScaleDotProductAttention()
489
+ self.w_q = nn.Linear(d_model, d_model)
490
+ self.w_k = nn.Linear(d_model, d_model)
491
+ self.w_v = nn.Linear(d_model, d_model)
492
+ self.w_concat = nn.Linear(d_model, d_model)
493
+
494
+ self.rotary_emb = RotaryEmbedding(dim=d_model // n_head)
495
+
496
+ def forward(self, q, k, v, mask=None, apply_rotary=False):
497
+ """Performs the forward pass for multi-head attention.
498
+
499
+ Args:
500
+ q (torch.Tensor): The query tensor.
501
+ k (torch.Tensor): The key tensor.
502
+ v (torch.Tensor): The value tensor.
503
+ mask (torch.Tensor, optional): An attention mask. Defaults to None.
504
+ apply_rotary (bool): If True, applies Rotary Position Embeddings to Q and K
505
+ before the attention calculation. Defaults to False.
506
+
507
+ Returns:
508
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing the final output tensor
509
+ and the attention scores.
510
+ """
511
+ q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
512
+ q, k, v = self.split(q), self.split(k), self.split(v)
513
+
514
+ if apply_rotary:
515
+ # add Rotary Positional Embeddings (RoPE)
516
+ # https://arxiv.org/abs/2104.09864
517
+ q = self.rotary_emb.rotate_queries_or_keys(q)
518
+ k = self.rotary_emb.rotate_queries_or_keys(k)
519
+
520
+ out, attention = self.attention(q, k, v, mask=mask)
521
+ out = self.concat(out)
522
+ out = self.w_concat(out)
523
+ return out, attention
524
+
525
+ def split(self, tensor):
526
+ """Splits the last dimension of a tensor into multiple heads."""
527
+ batch_size, length, d_model = tensor.size()
528
+ d_tensor = d_model // self.n_head
529
+ tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
530
+ return tensor
531
+
532
+ def concat(self, tensor):
533
+ """Concatenates multiple heads back into a single tensor."""
534
+ batch_size, head, length, d_tensor = tensor.size()
535
+ d_model = head * d_tensor
536
+ tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
537
+ return tensor
538
+
539
+
540
+ #################################################################################
541
+ # Embedding Layers #
542
+ #################################################################################
543
+
544
+ class EmbeddingLayer(nn.Module):
545
+ """A simple lookup-based embedding layer with Kaiming uniform initialization."""
546
+ def __init__(self, dim, vocab_dim):
547
+ super().__init__()
548
+ self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
549
+ torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))
550
+
551
+ def forward(self, x):
552
+ """Looks up the embeddings for the given indices.
553
+ Args:
554
+ x (torch.Tensor): A tensor of integer indices.
555
+ Returns:
556
+ torch.Tensor: The corresponding embedding vectors.
557
+ """
558
+ return self.embedding[x]
559
+
560
+
561
+ class TimestepEmbedder(nn.Module):
562
+ """
563
+ Embeds scalar timesteps into vector representations.
564
+ """
565
+
566
+ def __init__(self, hidden_size, frequency_embedding_size=256):
567
+ """Initializes the TimestepEmbedder.
568
+
569
+ Args:
570
+ hidden_size (int): The final dimension of the timestep embedding.
571
+ frequency_embedding_size (int): The number of frequencies to use for
572
+ the sinusoidal embedding.
573
+ """
574
+ super().__init__()
575
+ self.mlp = nn.Sequential(
576
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
577
+ nn.SiLU(),
578
+ nn.Linear(hidden_size, hidden_size, bias=True))
579
+ self.frequency_embedding_size = frequency_embedding_size
580
+
581
+ @staticmethod
582
+ def timestep_embedding(t, dim, max_period=10000):
583
+ """
584
+ Create sinusoidal timestep embeddings.
585
+ :param t: a 1-D Tensor of N indices, one per batch element.
586
+ These may be fractional.
587
+ :param dim: the dimension of the output.
588
+ :param max_period: controls the minimum frequency of the embeddings.
589
+ :return: an (N, D) Tensor of positional embeddings.
590
+ """
591
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
592
+ half = dim // 2
593
+ freqs = torch.exp(
594
+ - math.log(max_period)
595
+ * torch.arange(start=0, end=half, dtype=torch.float32)
596
+ / half).to(device=t.device)
597
+ args = t[:, None].float() * freqs[None]
598
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
599
+ if dim % 2:
600
+ embedding = torch.cat(
601
+ [embedding,
602
+ torch.zeros_like(embedding[:, :1])], dim=-1)
603
+ return embedding
604
+
605
+ def forward(self, t):
606
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
607
+ t_emb = self.mlp(t_freq)
608
+ return t_emb
609
+
610
+
611
+ #################################################################################
612
+ # Decoder #
613
+ #################################################################################
614
+
615
+ class DecoderLayer(nn.Module):
616
+ """
617
+ code source: https://github.com/hyunwoongko/transformer
618
+ """
619
+
620
+ def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
621
+ """Initializes the DecoderLayer.
622
+
623
+ Args:
624
+ d_model (int): The dimension of the model.
625
+ ffn_hidden (int): The dimension of the hidden layer in the feed-forward network.
626
+ n_head (int): The number of attention heads.
627
+ drop_prob (float): The dropout probability.
628
+ """
629
+ super(DecoderLayer, self).__init__()
630
+
631
+ self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
632
+ self.norm1 = LayerNorm(d_model=d_model)
633
+ self.dropout1 = nn.Dropout(p=drop_prob)
634
+
635
+ self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
636
+ self.norm2 = LayerNorm(d_model=d_model)
637
+ self.dropout2 = nn.Dropout(p=drop_prob)
638
+
639
+ self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
640
+ self.norm3 = LayerNorm(d_model=d_model)
641
+ self.dropout3 = nn.Dropout(p=drop_prob)
642
+
643
+ def forward(self, dec, enc, trg_mask, src_mask, return_attention=False):
644
+ """Performs one forward pass of the decoder layer.
645
+
646
+ Args:
647
+ dec (torch.Tensor): The input tensor from the previous decoder layer.
648
+ enc (torch.Tensor): The output tensor from the encoder (for conditioning).
649
+ trg_mask (torch.Tensor): The mask for the decoder's self-attention.
650
+ src_mask (torch.Tensor): The mask for the cross-attention.
651
+ return_attention (bool): If True, returns the cross-attention weights.
652
+
653
+ Returns:
654
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing the output tensor
655
+ and the attention weights (or None).
656
+ """
657
+ attention = None
658
+
659
+ _x = dec
660
+ x, _ = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask, apply_rotary=True)
661
+ x = self.dropout1(x)
662
+ x = self.norm1(x + _x)
663
+
664
+ if enc is not None:
665
+ _x = x
666
+ if return_attention:
667
+ x, attention = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)
668
+ else:
669
+ x, _ = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)
670
+ x = self.dropout2(x)
671
+ x = self.norm2(x + _x)
672
+
673
+ _x = x
674
+ x = self.ffn(x)
675
+ x = self.dropout3(x)
676
+ x = self.norm3(x + _x)
677
+ return x, attention
678
+
679
+
680
+ class Decoder_RoPE(nn.Module):
681
+ """A decoder that uses Rotary Position Embeddings (RoPE).
682
+
683
+ This model is designed for a diffusion task, taking a ligand sequence, a
684
+ conditioning protein sequence, and a diffusion timestep (sigma) as input
685
+ to predict the output logits for the ligand.
686
+ """
687
+ def __init__(self,
688
+ vocab_size,
689
+ seq_emb_dim,
690
+ hidden_size: int=640,
691
+ nhead: int=8,
692
+ n_layers: int=4,
693
+ expand_feedforward: int=3,
694
+ dropout: float=0.1):
695
+
696
+ """Args:
697
+ vocab_size (int): The size of the output vocabulary (e.g., ligand tokens).
698
+ seq_emb_dim (int): The dimension of the input sequence embeddings.
699
+ hidden_size (int): The main hidden dimension of the Transformer model.
700
+ nhead (int): The number of attention heads in each DecoderLayer.
701
+ n_layers (int): The number of DecoderLayers to stack.
702
+ expand_feedforward (int): The expansion factor for the feed-forward
703
+ network's hidden layer.
704
+ dropout (float): The dropout probability.
705
+ """
706
+ super().__init__()
707
+
708
+ self.hidden_size = hidden_size
709
+ self.vocab_embed = EmbeddingLayer(self.hidden_size, vocab_size)
710
+ self.linear = nn.Linear(self.hidden_size, vocab_size)
711
+ self.apply_seq_linear = False
712
+
713
+ if seq_emb_dim != self.hidden_size:
714
+ self.apply_seq_linear = True
715
+ self.linear_seq = nn.Linear(seq_emb_dim, self.hidden_size)
716
+
717
+ self.sigma_map = TimestepEmbedder(self.hidden_size)
718
+
719
+ self.layers = nn.ModuleList([DecoderLayer(d_model=self.hidden_size,
720
+ ffn_hidden=self.hidden_size * expand_feedforward,
721
+ n_head=nhead,
722
+ drop_prob=dropout)
723
+ for _ in range(n_layers)])
724
+
725
+ def forward(self,
726
+ ligand: torch.Tensor,
727
+ sigma: torch.Tensor,
728
+ sequence: torch.Tensor,
729
+ sequence_lengths: torch.Tensor,
730
+ lig_padding_mask: Optional[torch.Tensor]=None,
731
+ return_attention: bool=False) -> Tuple[torch.Tensor, torch.Tensor]:
732
+ """Performs the forward pass of the decoder.
733
+
734
+ It processes the ligand sequence conditioned on the protein sequence and the
735
+ diffusion timestep (sigma). The sigma embedding is prepended to the protein
736
+ sequence to form a single conditioning context.
737
+
738
+ Args:
739
+ ligand (torch.Tensor): A batch of ligand token ID tensors.
740
+ sigma (torch.Tensor): A batch of scalar diffusion timesteps.
741
+ sequence (torch.Tensor): A batch of conditioning protein sequence embeddings.
742
+ sequence_lengths (torch.Tensor): The original lengths of the protein sequences.
743
+ lig_padding_mask (Optional[torch.Tensor]): A padding mask for the ligand.
744
+ return_attention (bool): If True, returns the cross-attention weights
745
+ from the last decoder layer.
746
+
747
+ Returns:
748
+ Tuple[torch.Tensor, torch.Tensor]: A tuple of (output_logits, attention_weights).
749
+ """
750
+ ligand = self.vocab_embed(ligand)
751
+ sigma = F.silu(self.sigma_map(sigma)).unsqueeze(1)
752
+ if self.apply_seq_linear:
753
+ sequence = self.linear_seq(sequence)
754
+ condition = torch.cat([sigma, sequence], dim=1)
755
+ sequence_lengths += 1
756
+
757
+ range_tensor = torch.arange(condition.shape[1], device=sequence.device).unsqueeze(0)
758
+ condition_mask = range_tensor < sequence_lengths.unsqueeze(1)
759
+ condition_mask = condition_mask.unsqueeze(1).unsqueeze(2)
760
+ if lig_padding_mask is not None:
761
+ lig_padding_mask = lig_padding_mask.unsqueeze(1).unsqueeze(2)
762
+
763
+ for layer in self.layers:
764
+ ligand, attention = layer(ligand, condition,
765
+ trg_mask=lig_padding_mask, src_mask=condition_mask,
766
+ return_attention=return_attention)
767
+
768
+ output = self.linear(ligand)
769
+ return output, attention
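
The decoder above is the module that the app drives at generation time. A minimal shape-check sketch, not part of the repository: the sizes are illustrative, the import path follows `protobind_diff/decoder_rope.py`, and it assumes the vendored `RotaryEmbedding` defaults defined earlier in this file.

```python
# Hypothetical smoke test for Decoder_RoPE; all dimensions are illustrative only.
import torch
from protobind_diff.decoder_rope import Decoder_RoPE

decoder = Decoder_RoPE(vocab_size=300, seq_emb_dim=1280)   # 1280 = ESM-2 650M width
ligand = torch.randint(0, 300, (2, 170))                   # batch of ligand token ids
sigma = torch.rand(2)                                      # diffusion timesteps
sequence = torch.randn(2, 512, 1280)                       # protein residue embeddings
sequence_lengths = torch.tensor([512, 480])

logits, _ = decoder(ligand, sigma, sequence, sequence_lengths)
print(logits.shape)                                        # torch.Size([2, 170, 300])
```

The sigma embedding is prepended to the protein embeddings, so the cross-attention context has `seq_len + 1` positions, and the returned logits have one row per ligand position over the ligand vocabulary.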
protobind_diff/esm_inference.py ADDED
@@ -0,0 +1,175 @@
1
+ import argparse, sys
2
+ from typing import Optional, Tuple
3
+ from pathlib import Path
4
+ import esm
5
+ import os
6
+ import torch
7
+ import numpy as np
8
+ import re
9
+ from Bio import SeqIO
10
+ from torch.utils.data import Dataset, DataLoader
11
+ import lightning.pytorch as pl
12
+ from protobind_diff.model import ModelGenerator
13
+ from protobind_diff.data_loader import InferenceDataset
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ REPO_ID = "ai-gero/ProtoBind-Diff"
17
+ FILENAME = "model.ckpt"
18
+ TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
19
+
20
+ class ProtobindInference():
21
+ """
22
+ Simplified inference class that only supports ProtobindMaskedDiffusion model.
23
+ """
24
+
25
+ def __init__(self, checkpoint_path, tokenizer_path,
26
+ sequence_embedding_dim, lig_max_length: int=170, nucleus_p: float=0.9,
27
+ eta: float=0.1, sampling_steps: int=250,
28
+ **kwargs):
29
+ self.checkpoint_path = Path(checkpoint_path)
30
+ self.tokenizer_path = Path(tokenizer_path)
31
+ self.sequence_embedding_dim = sequence_embedding_dim
32
+
33
+ # Set up sampler params
34
+ self.lig_max_length = lig_max_length
35
+ self.nucleus_p = nucleus_p
36
+ self.eta = eta
37
+ self.sampling_steps = sampling_steps
38
+
39
+ # Load model
40
+ self.model = self.load_model()
41
+
42
+ def predict_on_dataloader(self, dl, devices=1, accelerator='cuda') -> Tuple[np.ndarray, np.ndarray]:
43
+ if accelerator == 'cuda':
44
+ torch.set_float32_matmul_precision('medium')
45
+ precision = "16-mixed"
46
+ else:
47
+ precision = "32-true"
48
+ trainer = pl.Trainer(precision=precision, use_distributed_sampler=False,
49
+ inference_mode=True, accelerator=accelerator, devices=devices)
50
+ predictions_batches = trainer.predict(model=self.model, dataloaders=dl)
51
+ return predictions_batches
52
+
53
+ def load_model(self):
54
+ """Simplified model loading - only supports ModelGenerator"""
55
+ model = ModelGenerator.load_from_checkpoint(
56
+ self.checkpoint_path,
57
+ tokenizer_path=self.tokenizer_path,
58
+ seq_embedding_dim=self.sequence_embedding_dim,
59
+ load=True,
60
+ )
61
+ model.model_length = self.lig_max_length
62
+ model.nucleus_p = self.nucleus_p
63
+ model.eta = self.eta
64
+ model.sampling_steps = self.sampling_steps
65
+ model.model.eval()
66
+ return model
67
+
68
+ def get_esm_embedding(sequence: str, model_name: str, device: torch.device) -> torch.Tensor:
69
+ """Generates a protein embedding using a pre-trained ESM model.
70
+
71
+ Args:
72
+ sequence (str): The amino acid sequence.
73
+ model_name (str): The name of the ESM model to use.
74
+ device (torch.device): The device to run the model on.
75
+
76
+ Returns:
77
+ torch.Tensor: The final residue-level embedding tensor, with start/end tokens removed.
78
+ """
79
+ model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
80
+ model.eval()
81
+ number_layers = re.search(r'_t(\d+)_', model_name)
82
+ number_layers = int(number_layers.group(1))
83
+
84
+ model = model.to(device)
85
+ batch_converter = alphabet.get_batch_converter()
86
+ _, _, tokens = batch_converter([("protein", sequence)])
87
+ tokens = tokens.to(device)
88
+ with torch.no_grad():
89
+ out = model(tokens, repr_layers=[number_layers])
90
+ return out["representations"][number_layers][:, 1:-1, :] # [1, seq_len, emb_dim]
91
+
92
+ def download_from_hub_hf(cache: Path, filename) -> Path:
93
+ """
94
+ Fetch file from Hugging Face into `cache`.
95
+ Returns the local path to the file inside HF’s cache structure.
96
+ """
97
+ cache.mkdir(parents=True, exist_ok=True)
98
+ local_path = hf_hub_download(
99
+ repo_id=REPO_ID,
100
+ filename=filename,
101
+ cache_dir=cache,
102
+ )
103
+ return Path(local_path)
104
+
105
+ def main():
106
+ parser = argparse.ArgumentParser()
107
+ parser.add_argument("--sequence", help="Amino acid sequence (1-letter code)")
108
+ parser.add_argument("--output_dir", default="./outputs", help="Output dir for SMILES")
109
+ parser.add_argument("--output", default="generated_smiles.txt", help="Output file for generated SMILES")
110
+ parser.add_argument("--n_batches", type=int, default=5, help="Number of batches to generate for this sequence")
111
+ parser.add_argument("--batch_size", type=int, default=10, help="Max number of generated molecules per batch")
112
+ parser.add_argument("--fasta_file", default="./examples/input.fasta", help="Input FASTA file")
113
+ parser.add_argument("--checkpoint_path", type=str, help="Path to the model checkpoint")
114
+ parser.add_argument('--model_name', type=str, default='esm2_t33_650M_UR50D',
115
+ help="ESM model name. See https://github.com/facebookresearch/esm")
116
+ parser.add_argument('--tokenizer_path', help='Path to tokenizer.json file. If not provided, uses a default path and downloads if needed.')
117
+ parser.add_argument('--cache', type=str, default = "./cache", help='Cache folder for ckpt')
118
+
119
+ parser.add_argument("--sampling_steps", type=int, default=250, help="Number of steps during sampling")
120
+ parser.add_argument("--lig_max_length", type=int, default=170, help="Max length of generated molecules")
121
+ parser.add_argument("--nucleus_p", type=float, default=0.9,
122
+ help="Value of the nucleus sampling parameter. For more details, see https://arxiv.org/abs/2503.00307")
123
+ parser.add_argument("--eta", type=float, default=0.1,
124
+ help="Value of the probability of remasking. For more details, see https://arxiv.org/abs/2503.00307")
125
+
126
+ args = parser.parse_args()
127
+ # An explicit --sequence takes precedence over the default FASTA example path
+ if args.sequence:
128
+ sequence = args.sequence.strip().upper()
129
+ elif args.fasta_file:
130
+ sequence = str(next(SeqIO.parse(args.fasta_file, "fasta")).seq)
131
+ else:
132
+ sys.exit("Error: provide --sequence or --fasta_file")
133
+
134
+ if args.checkpoint_path:
135
+ ckpt_path = Path(args.checkpoint_path)
136
+ else:
137
+ torch.hub.set_dir(args.cache) # for ESM model
138
+ ckpt_path = download_from_hub_hf(Path(args.cache), FILENAME)
139
+
140
+ if args.tokenizer_path:
141
+ tokenizer_path = Path(args.tokenizer_path)
142
+ if not tokenizer_path.exists():
143
+ sys.exit(f"Error: Tokenizer file not found at specified path: {tokenizer_path}")
144
+ else:
145
+ tokenizer_path = download_from_hub_hf(Path(args.cache), TOKENIZER_FILENAME)
146
+
147
+ # Determine the device
148
+ if torch.cuda.is_available():
149
+ device = torch.device("cuda") # Use CUDA if available
150
+ elif torch.backends.mps.is_available():
151
+ device = torch.device("mps") # Use MPS for Apple Silicon if available
152
+ else:
153
+ device = torch.device("cpu") # Fallback to CPU
154
+
155
+ embedding = get_esm_embedding(sequence, args.model_name, device).to(dtype=torch.bfloat16)
156
+ sequence_embedding_dim = embedding.shape[2]
157
+ dataset = InferenceDataset(embedding, batch_size=args.batch_size, n_batches=args.n_batches)
158
+ loader = DataLoader(dataset, batch_size=None)
159
+ model = ProtobindInference(ckpt_path, tokenizer_path, sequence_embedding_dim,
160
+ sampling_steps=args.sampling_steps, nucleus_p=args.nucleus_p,
161
+ eta=args.eta, lig_max_length=args.lig_max_length,)
162
+
163
+ predictions = model.predict_on_dataloader(loader, accelerator=str(device))
164
+
165
+ all_smiles = [smi for batch in predictions for smi in batch[0]]
166
+ out_dir = Path(args.output_dir)
167
+ os.makedirs(out_dir, exist_ok=True)
168
+ with open(out_dir / args.output, "w") as f:
169
+ f.write("SMILES\n")
170
+ for smi in all_smiles:
171
+ f.write(smi + "\n")
172
+
173
+
174
+ if __name__ == "__main__":
175
+ main()
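
For reference, the embedding helper can also be used on its own. A hedged sketch follows: the peptide is arbitrary, the model name matches the script's default, and the call downloads the ESM-2 weights on first use (several GB for the 650M model).

```python
# Hypothetical standalone use of get_esm_embedding; the peptide string is arbitrary.
import torch
from protobind_diff.esm_inference import get_esm_embedding

seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb = get_esm_embedding(seq, "esm2_t33_650M_UR50D", device)
print(emb.shape)   # (1, len(seq), 1280); start/end tokens are already stripped
```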
protobind_diff/ligands/__init__.py ADDED
File without changes
protobind_diff/ligands/rdkit_utils.py ADDED
@@ -0,0 +1,209 @@
1
+ import importlib
2
+ from typing import Optional, Tuple, Union, List
3
+ import numpy as np
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+ from multiprocessing import Pool
+ import joblib  # used by smiles_to_mols and smiles_to_fps below
7
+ from pathlib import Path
8
+ from rdkit import Chem
9
+
10
+ from FPSim2 import FPSim2Engine
11
+ import rdkit
12
+ from rdkit import Chem, RDLogger
13
+ from rdkit.Chem import DataStructs, Descriptors
14
+ from rdkit.DataStructs import BulkTanimotoSimilarity
15
+ from sklearn.cluster import DBSCAN
16
+ import scipy
17
+ RDLogger.DisableLog('rdApp.*')
18
+
19
+
20
+ class BoostWrapper(object):
21
+ """ Help joblib to deal with boost functions """
22
+ def __init__(self, method_name, module_name):
23
+ self.method_name = method_name
24
+ self.module = importlib.import_module(module_name)
25
+
26
+ @property
27
+ def method(self):
28
+ return getattr(self.module, self.method_name)
29
+
30
+ def __call__(self, *args, **kwargs):
31
+ return self.method(*args, **kwargs)
32
+
33
+
34
+ def cluster_fpsim2(distance_path, smiles_h5_path=None, dist_eps=0.15):
35
+ """ Cluster precomputed FPSim2 distance matrix using DBSCAN algorithm """
36
+ if isinstance(distance_path, str):
37
+ distance_path = Path(distance_path)
38
+
39
+ if smiles_h5_path is None:
40
+ smiles_h5_path = distance_path.parent / 'all_smiles.h5'
41
+ precomputed_indices = FPSim2Engine(smiles_h5_path).fps[:, 0]
42
+ map_precomputed = np.argsort(precomputed_indices) # maps original smiles order to FPSim2 order
43
+
44
+ precomputed_distance = scipy.sparse.load_npz(distance_path)
45
+ db = DBSCAN(eps=dist_eps, min_samples=1, metric='precomputed', n_jobs=-1)
46
+ labels = db.fit_predict(precomputed_distance)
47
+
48
+ # df_ = pd.DataFrame(data=smiles.keys(), index=list(smiles.values()), columns=['SMILES'])
49
+ # df_ = df_.sort_index()
50
+ # df_['cluster'] = labels[map_precomputed]
51
+ return labels[map_precomputed]
52
+
53
+
54
+ def tanimoto_smiles(mol1, mol2, fp='rdkit', bits=2048, radius=2):
55
+
56
+ if isinstance(mol1, str):
57
+ mol1 = Chem.MolFromSmiles(mol1)
58
+ if isinstance(mol2, str):
59
+ mol2 = Chem.MolFromSmiles(mol2)
60
+
61
+ _supported_fps = {
62
+ 'rdkit': Chem.RDKFingerprint,
63
+ 'morgan': Chem.rdMolDescriptors.GetMorganFingerprintAsBitVect,
64
+ 'maccs': Chem.rdMolDescriptors.GetMACCSKeysFingerprint,
65
+ }
66
+ if fp not in _supported_fps:
67
+ raise ValueError(f"Fingerprint {fp} is not supported, available fps {_supported_fps.keys()}")
68
+
69
+ ffp = None
70
+ if fp == 'rdkit':
71
+ ffp = lambda x: _supported_fps[fp](x, fpSize=bits)
72
+ elif fp == 'morgan':
73
+ ffp = lambda x: _supported_fps[fp](x, radius=radius, nBits=bits)
74
+ elif fp == 'maccs':
75
+ ffp = _supported_fps[fp]
76
+
77
+ return rdkit.DataStructs.TanimotoSimilarity(ffp(mol1), ffp(mol2))
78
+
79
+
80
+ def validate_smile(smile):
81
+ try:
82
+ mol = Chem.MolFromSmiles(smile)
83
+ Chem.SanitizeMol(mol)
84
+ return smile
85
+ except Exception:
86
+ return None
87
+
88
+
89
+ def calc_chem_desc(smiles):
90
+ rdkit_features = {'MolWt': rdkit.Chem.Descriptors.MolWt,
91
+ 'MolLogP': rdkit.Chem.Descriptors.MolLogP,
92
+ 'NumRotatableBonds': rdkit.Chem.Descriptors.NumRotatableBonds,
93
+ 'CalcTPSA': rdkit.Chem.rdMolDescriptors.CalcTPSA,
94
+ 'RingCount': rdkit.Chem.Descriptors.RingCount,
95
+ }
96
+ if isinstance(smiles[0], str):
97
+ mols = smiles_to_mols(smiles)
98
+ elif isinstance(smiles[0], rdkit.Chem.rdchem.Mol):
99
+ mols = smiles
100
+ else:
101
+ raise TypeError(f'smiles must be a string or a rdkit.Chem.rdchem.Mol: {type(smiles[0])}')
102
+ res = {}
103
+ for name, func in rdkit_features.items():
104
+ res[name] = np.asarray([func(m) if m is not None else np.nan for m in mols ])
105
+ return pd.DataFrame(res)
106
+
107
+
108
+ def smiles_to_mols(smiles, n_jobs=8):
109
+ if isinstance(smiles, (list, tuple, np.ndarray)):
110
+ pass
111
+ elif isinstance(smiles, pd.Series):
112
+ smiles = smiles.tolist()
113
+ else:
114
+ raise TypeError(f"{type(smiles)=}")
115
+
116
+ assert len(smiles) > 0
117
+ assert isinstance(smiles[0], str), f"expect smiles string, got f{smiles[0]}"
118
+
119
+ mols = joblib.Parallel(n_jobs=n_jobs)(
120
+ joblib.delayed(BoostWrapper('MolFromSmiles', 'rdkit.Chem.rdmolfiles', ))(smi) for smi in smiles)
121
+ return mols
122
+
123
+
124
+ def smiles_to_fps(smiles_or_mols, finger_type='rdkit', n_jobs=8, fp_param=None):
125
+ if isinstance(smiles_or_mols, (list, tuple, np.ndarray)):
126
+ pass
127
+ elif isinstance(smiles_or_mols, pd.Series):
128
+ smiles_or_mols = smiles_or_mols.tolist()
129
+ else:
130
+ raise TypeError(f"{type(smiles_or_mols)=}")
131
+
132
+ assert len(smiles_or_mols) > 0
133
+ assert isinstance(smiles_or_mols[0],
134
+ (str, rdkit.Chem.rdchem.Mol)), f"variable {smiles_or_mols[0]} has type {type(smiles_or_mols[0])}"
135
+
136
+ if isinstance(smiles_or_mols[0], str):
137
+ mols = smiles_to_mols(smiles_or_mols)
138
+ else:
139
+ mols = smiles_or_mols
140
+
141
+ if fp_param is None:
142
+ fp_param = {}
143
+ fp_func, fp_func_name, fp_func_module, fp_params = _find_fingerprint_function(finger_type)
144
+ fp_params.update(fp_param)
145
+ if finger_type == 'morgan':
146
+ fp_func = fp_func(**fp_params).GetFingerprint
147
+ fp_params = {}
148
+ fps = joblib.Parallel(n_jobs=n_jobs, prefer="threads")(
149
+ joblib.delayed(fp_func)(mol, **fp_params) for mol in mols)
150
+ return fps
151
+
152
+
153
+ def _find_fingerprint_function(finger_type: str) -> Tuple[callable, str, str, dict]:
154
+ kwargs = {}
155
+ if finger_type == 'rdkit':
156
+ fp_func_name = 'RDKFingerprint'
157
+ fp_func_module = 'rdkit.Chem'
158
+ elif finger_type == 'maccs':
159
+ fp_func_name = 'GetMACCSKeysFingerprint'
160
+ fp_func_module = 'rdkit.Chem.rdMolDescriptors'
161
+ elif finger_type == 'morgan':
162
+ fp_func_name = 'GetMorganGenerator'
163
+ fp_func_module = 'rdkit.Chem.AllChem'
164
+ kwargs = dict(atomInvariantsGenerator=rdkit.Chem.rdFingerprintGenerator.GetMorganFeatureAtomInvGen(),
165
+ radius=2, fpSize=2048, countSimulation=True)
166
+ else:
167
+ raise NotImplementedError(f"Use `rdkit` or `maccs` or `morgan` as fps")
168
+
169
+ fp_func = getattr(importlib.import_module(fp_func_module), fp_func_name)
170
+ return fp_func, fp_func_name, fp_func_module, kwargs
171
+
172
+
173
+ def randomize_smiles_rotated(smiles: str, with_order_reversal: bool = True) -> str:
174
+ """
175
+ Randomize a SMILES string by doing a cyclic rotation of the atomic indices.
176
+
177
+ Adapted from https://github.com/GLambard/SMILES-X/blob/758478663030580a363a9ee61c11f6d6448e18a1/SMILESX/augm.py#L19.
178
+
179
+ The outputs of this function can be reproduced by setting the seed with np.random.seed().
180
+
181
+ Raises:
182
+ InvalidSmiles: for invalid molecules.
183
+
184
+ Args:
185
+ smiles: SMILES string to randomize.
186
+ with_order_reversal: whether to reverse the atom order with 50% chance.
187
+
188
+ Returns:
189
+ Randomized SMILES string.
190
+ """
191
+
192
+ mol = Chem.MolFromSmiles(smiles, sanitize=False)
193
+
194
+ n_atoms = mol.GetNumAtoms()
195
+
196
+ # Generate random values
197
+ rotation_index = np.random.randint(0, n_atoms - 1)
198
+ reverse_order = with_order_reversal and np.random.choice([True, False])
199
+
200
+ # Generate new atom indices order
201
+ atoms = list(range(n_atoms))
202
+ new_atoms_order = (
203
+ atoms[rotation_index % len(atoms) :] + atoms[: rotation_index % len(atoms)]
204
+ )
205
+ if reverse_order:
206
+ new_atoms_order.reverse()
207
+
208
+ mol = Chem.RenumberAtoms(mol, new_atoms_order)
209
+ return Chem.MolToSmiles(mol, canonical=False)
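
A quick sanity check of the validation and similarity helpers above. This is an illustrative sketch only; it assumes RDKit is installed and uses arbitrary example SMILES.

```python
# Hypothetical usage of validate_smile and tanimoto_smiles from rdkit_utils.
from protobind_diff.ligands.rdkit_utils import validate_smile, tanimoto_smiles

print(validate_smile("c1ccccc1O"))     # returns the SMILES itself: phenol parses
print(validate_smile("C1CC"))          # returns None: unclosed ring, invalid
print(tanimoto_smiles("CCO", "CCCO"))  # RDKit-fingerprint Tanimoto similarity in (0, 1]
```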
protobind_diff/ligands/smiles_tokenizer.py ADDED
@@ -0,0 +1,135 @@
1
+ # Taken from https://github.com/MolecularAI/Chemformer/
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+ from pysmilesutils.tokenize import SMILESTokenizer
4
+
5
+ class ChemformerTokenizer(SMILESTokenizer):
6
+ """
7
+ Tokenizer for the Chemformer.
8
+
9
+ There are a few different features that sets this apart from the `SMILESTokenizer`:
10
+ * It reserves two extra special tokens, "mask" and "sep"
11
+ * It distinguish between chemical and non-chemical tokens
12
+
13
+ :param smiles: A list of SMILES that are used to create the vocabulary for the tokenizer. Defaults to None.
14
+ :param tokens: A list of tokens (strings) that the tokenizer uses when tokenizing SMILES. Defaults to None.
15
+ :param regex_token_patterns: A list of regular expressions that the tokenizer uses when tokenizing SMILES.
16
+ :param beginning_of_smiles_token: Token that is added to beginning of SMILES. Defaults to "^".
17
+ :param end_of_smiles_token: Token that is added to the end of SMILES. Defaults to "&".
18
+ :param padding_token: Token used for padding. Defalts to " ".
19
+ :param unknown_token: Token that is used for unknown ids when decoding encoded data. Defaults to "?".
20
+ :param mask_token: Token that is used by the Masker
21
+ :param sep_token: Token that is used to separate sentences, currently unused
22
+ :param filename: if given and `smiles` is None, load the vocabulary from disc
23
+ :raises: ValueError: If the `encoding_type` is invalid.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ smiles: List[str] = None,
29
+ tokens: List[str] = None,
30
+ regex_token_patterns: List[str] = None,
31
+ beginning_of_smiles_token: str = "^",
32
+ end_of_smiles_token: str = "&",
33
+ padding_token: str = "<PAD>",
34
+ unknown_token: str = "?",
35
+ mask_token: str = "<MASK>",
36
+ sep_token: str = "<SEP>",
37
+ filename: str = None,
38
+ ) -> None:
39
+ self._mask_token = mask_token
40
+ self._sep_token = sep_token
41
+ self._chem_start_idx = 6 # Default: number of special tokens (indices 0-5 are special)
42
+ self._chem_token_idxs: Optional[List[int]] = None
43
+ super().__init__(
44
+ smiles=smiles,
45
+ tokens=tokens,
46
+ regex_token_patterns=regex_token_patterns,
47
+ beginning_of_smiles_token=beginning_of_smiles_token,
48
+ end_of_smiles_token=end_of_smiles_token,
49
+ padding_token=padding_token,
50
+ unknown_token=unknown_token,
51
+ encoding_type="index",
52
+ filename=filename,
53
+ )
54
+
55
+
56
+ @property
57
+ def chem_token_idxs(self) -> List[int]:
58
+ """Returns the indices of the vocabulary that are chemical tokens"""
59
+ if self._chem_token_idxs is None:
60
+ self._chem_token_idxs = list(range(self._chem_start_idx, len(self.vocabulary)))
61
+ return self._chem_token_idxs
62
+
63
+ @property
64
+ def mask_token_id(self):
65
+ """Get the mask token id"""
66
+ return self.vocabulary[self._mask_token]
67
+
68
+ @property
69
+ def vocab_size(self):
70
+ return len(self.vocabulary)
71
+
72
+ @property
73
+ def special_tokens(self) -> Dict[str, str]:
74
+ """Returns a dictionary of non-character tokens"""
75
+ return {
76
+ "start": self._beginning_of_smiles_token,
77
+ "end": self._end_of_smiles_token,
78
+ "pad": self._padding_token,
79
+ "unknown": self._unknown_token,
80
+ "mask": self._mask_token,
81
+ "sep": self._sep_token,
82
+ }
83
+
84
+ def add_tokens(self, tokens: List[str], regex: bool = False, smiles=None) -> None:
85
+ """Adds tokens to the classes list of tokens.
86
+
87
+ The new tokens are added to the front of the token list and take priority over old tokens. Note that the
88
+ vocabulary of the tokenizer is not updated after the tokens are added,
89
+ and must be updated by calling `create_vocabulary_from_smiles`.
90
+
91
+ If `regex` is False, the tokens are interpreted as non-chemical tokens, which distinguish
92
+ them for processing by e.g. the masker.
93
+
94
+ :param tokens: List of tokens to be added.
95
+ :param regex: If `True` the input tokens are treated as
96
+ regular expressions and are added to the list of regular expressions
97
+ instead of token list. Defaults to False.
98
+ :param smiles: If a list of smiles is provided, the vocabulary will be created, defaults to None
99
+
100
+ :raises ValueError: If any of the tokens supplied are already in the list
101
+ of tokens.
102
+ """
103
+ super().add_tokens(tokens, regex, smiles)
104
+ if not regex:
105
+ self._chem_start_idx += len(tokens)
106
+ self._chem_token_idxs = None
107
+
108
+ def _reset_vocabulary(self) -> Dict[str, int]:
109
+ """Create a new tokens vocabulary.
110
+
111
+ :return: New tokens vocabulary
112
+ """
113
+ dict_ = {
114
+ self._padding_token: 0,
115
+ self._unknown_token: 1,
116
+ self._beginning_of_smiles_token: 2,
117
+ self._end_of_smiles_token: 3,
118
+ self._mask_token: 4,
119
+ self._sep_token: 5,
120
+ }
121
+ for token in self._tokens:
122
+ dict_.setdefault(token, len(dict_))
123
+ return dict_
124
+
125
+ def _state_properties(self) -> Dict[str, Any]:
126
+ """Return properties to reconstruct the internal state of the tokenizer"""
127
+ dict_ = super()._state_properties()
128
+ dict_["chem_start_idx"] = self._chem_start_idx
129
+ return dict_
130
+
131
+ def _update_state(self, dict_: Dict[str, Any]) -> None:
132
+ """Update the internal state with properties loaded from disc"""
133
+ super()._update_state(dict_)
134
+ self._chem_start_idx = dict_["chem_start_idx"]
135
+ self._chem_token_idxs = None
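
A minimal sketch of how this tokenizer lays out its vocabulary. It assumes pysmilesutils is installed and that building the vocabulary from a SMILES list behaves as described in the constructor docstring above; the two SMILES are arbitrary.

```python
# Hypothetical inspection of the ChemformerTokenizer special-token layout.
from protobind_diff.ligands.smiles_tokenizer import ChemformerTokenizer

tok = ChemformerTokenizer(smiles=["CCO", "c1ccccc1"])
print(tok.special_tokens)        # {'start': '^', 'end': '&', 'pad': '<PAD>', ...}
print(tok.mask_token_id)         # 4, per the order in _reset_vocabulary
print(tok.vocab_size)            # 6 special tokens + the learned chemical tokens
print(tok.chem_token_idxs[:5])   # indices of the chemical (non-special) tokens
```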
protobind_diff/model.py ADDED
@@ -0,0 +1,411 @@
1
+ from pathlib import Path
2
+ from typing import Tuple, Optional, Dict
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ import lightning.pytorch as pl
7
+ import logging
8
+ import huggingface_hub
9
+
10
+ from .ligands.rdkit_utils import validate_smile, calc_chem_desc, tanimoto_smiles
11
+ from .ligands.smiles_tokenizer import ChemformerTokenizer
12
+ from .noise_schedule import _sample_t, q_xt, _sample_categorical, LogLinearNoise
13
+ from .decoder_rope import Decoder_RoPE
14
+
15
+ logger = logging.getLogger("lightning")
16
+
17
+
18
+ class ModelGenerator(pl.LightningModule):
19
+ """
20
+ ProtoBind-Diff model with SMILES and ESM-2 protein encodings.
21
+ """
22
+ @staticmethod
23
+ def get_exp_dir(
24
+ exp_dir: str | None,
25
+ output_dir: str,
26
+ exp_dir_prefix: str,
27
+ split: str
28
+ ) -> Path:
29
+ """Determines the experiment directory path."""
30
+ if exp_dir:
31
+ return Path(exp_dir)
32
+ return Path(output_dir) / split / exp_dir_prefix
33
+
34
+ def __init__(self, *args, **kwargs):
35
+ """Initializes the Lightning Module, saves hyperparameters, and configures the model."""
36
+ super().__init__()
37
+
38
+ is_load = kwargs['load']
39
+ if not is_load:
40
+ self.save_hyperparameters()
41
+
42
+ self.data_dir = Path(kwargs["data_dir"])
43
+ exp_dir = kwargs.get('exp_dir', None)
44
+ self.exp_dir = self.get_exp_dir(
45
+ exp_dir=exp_dir,
46
+ output_dir=kwargs["output_dir"],
47
+ exp_dir_prefix=kwargs["exp_dir_prefix"],
48
+ split=kwargs["split"]
49
+ )
50
+
51
+ self.configure_model_params(**kwargs)
52
+
53
+ def configure_model_params(self, **kwargs):
54
+ """Parses keyword arguments to configure the model, tokenizer, and training parameters."""
55
+
56
+ self.learning_rate = kwargs.pop('learning_rate')
57
+ self.weight_decay = float(kwargs.pop('weight_decay'))
58
+
59
+ # Decoder params for masked diffusion
60
+ decoder_params = {
61
+ 'nhead': kwargs['num_heads_decoder'],
62
+ 'n_layers': kwargs['num_decoder_layers'],
63
+ 'hidden_size': kwargs['decoder_hidd_dim'],
64
+ 'expand_feedforward': kwargs['expand_feedforward'],
65
+ 'decoder_name': kwargs['decoder_name'],
66
+ }
67
+ # Tokenizer params
68
+ tokenizer_path = kwargs.get('tokenizer_path')
69
+ if tokenizer_path:
70
+ self.tokenizer = ChemformerTokenizer(filename=tokenizer_path)
71
+ else:
72
+ self.tokenizer = ChemformerTokenizer(filename=self.data_dir / f"{kwargs['tokenizer_json_name']}.json")
73
+
74
+ # Masking params
75
+ self.noise = LogLinearNoise()
76
+ self.mask_index = self.tokenizer.mask_token_id
77
+
78
+ # Sampler params
79
+ self.model_length = 170
80
+ self.noise_removal = True
81
+ self.nucleus_p = 0.9
82
+ self.eta = 0.1
83
+ self.sampling_steps = 100
84
+ self.time_conditioning = False
85
+
86
+ self.return_attention = False
87
+
88
+ self.model = ProtobindMaskedDiffusion(
89
+ embedding_dim=kwargs['seq_embedding_dim'],
90
+ mask_index=self.mask_index,
91
+ vocab_size=self.tokenizer.vocab_size,
92
+ decoder_params=decoder_params,
93
+ dropout=kwargs['dropout'],
94
+ )
95
+ self.optimizer = kwargs.get('optimizer', 'Adam')
96
+
97
+ def generate_mols(self, sequence: Tuple[torch.Tensor, torch.Tensor],
98
+ return_attention=False) -> Tuple[np.array, torch.Tensor,np.array]:
99
+ """Generates and validates SMILES strings for a given protein sequence.
100
+
101
+ This method calls the internal sampler, decodes the generated tokens into
102
+ SMILES strings, and filters out any invalid molecules.
103
+
104
+ Args:
105
+ sequence (Tuple[torch.Tensor, torch.Tensor]): The conditioned protein sequence
106
+ embedding and its length.
107
+ return_attention (bool): Whether to return attention maps from the sampler.
108
+
109
+ Returns:
110
+ Tuple[np.array, torch.Tensor, np.array]: A tuple containing the valid SMILES
111
+ strings, corresponding attention maps, and the mask of valid indices.
112
+ """
113
+ samples, attention = self._sample(sequence, return_attention=return_attention)
114
+ text_samples = self.tokenizer.decode(samples.long())
115
+ text_samples = np.array([validate_smile(smile) for smile in text_samples])
116
+
117
+ mask_invalid = (text_samples != None) & (text_samples != '.') & (text_samples != '')
118
+ text_samples = text_samples[mask_invalid]
119
+ if attention is not None:
120
+ attention = attention[mask_invalid]
121
+
122
+ return text_samples, attention, mask_invalid
123
+
124
+ def predict_step(self, batch, batch_idx):
125
+ sequence, smiles, seq_id, smi_id = batch
126
+ gen_samples, attention, mask_invalid = self.generate_mols(
127
+ sequence, return_attention=self.return_attention)
128
+ seq_id = seq_id[mask_invalid]
129
+ return gen_samples, attention, seq_id
130
+
131
+ def training_step(self, batch, batch_idx):
132
+ return self.common_step(batch, "train", batch_idx)
133
+
134
+ def validation_step(self, batch, batch_idx, dataloader_idx=None):
135
+ # dataloader_idx to predict on several validation sets
136
+ return self.common_step(batch, "val", batch_idx, dataloader_idx)
137
+
138
+ def test_step(self, batch, batch_idx, dataloader_idx=0):
139
+ return self.common_step(batch, "test", batch_idx)
140
+
141
+ def common_step(self, batch, description, batch_idx, dataloader_idx=None):
142
+ """Performs a common training, validation, or test step.
143
+
144
+ This method takes a batch, applies noise according to the diffusion
145
+ timestep, runs the model forward, calculates the loss, and logs metrics.
146
+
147
+ Args:
148
+ batch (Tuple): The input batch from the dataloader.
149
+ description (str): The step description (e.g., 'train', 'val').
150
+ batch_idx (int): The index of the batch.
151
+
152
+ Returns:
153
+ torch.Tensor: The calculated loss for the batch.
154
+ """
155
+ sequence, smiles, seq_id, smi_id = batch
156
+
157
+ # Get data and apply noise
158
+ X, length = smiles
159
+ bs = X.shape[0]
160
+ X = X.squeeze(-1)
161
+ padding_mask = (X != 0).float() # 0 is pad token id
162
+ t = _sample_t(X.shape[0], X.device)
163
+ sigma, dsigma = self.noise(t)
164
+ move_chance = 1 - torch.exp(-sigma[:, None])
165
+ xt = q_xt(X, move_chance, self.mask_index)
166
+ xt = xt.unsqueeze(dim=2)
167
+ smiles_t = (xt, length, None)
168
+
169
+ pred_x, _ = self.model(sequence, smiles_t, sigma, padding_mask)
170
+ total_loss = self.loss_mdlm(X.long(), pred_x, sigma, dsigma, padding_mask=None)
171
+
172
+ if batch_idx % 50 == 0:
173
+ tokens = pred_x.argmax(dim=-1) * padding_mask
174
+ true_smiles = self.tokenizer.decode(X.long())
175
+ pred_smiles = [smile for smile in self.tokenizer.decode(tokens)]
176
+ pred_smiles_valid = [validate_smile(smile) for smile in pred_smiles]
177
+
178
+ try:
179
+ tanimoto = np.asarray([tanimoto_smiles(mol_pred, mol_ref) for mol_pred, mol_ref
180
+ in zip(pred_smiles_valid, true_smiles) if mol_pred is not None])
181
+ tanimoto_mean = np.mean(tanimoto) if len(tanimoto) > 0 else 0
182
+ num_mols_valid = len(tanimoto)
183
+ except Exception:
184
+ num_mols_valid = 0
185
+ tanimoto_mean = 0.0
186
+
187
+ self.log(f"{description}_tanimoto", tanimoto_mean, prog_bar=True,
188
+ on_epoch=True, sync_dist=True)
189
+ self.log(f"{description}_perc_of_valid", num_mols_valid / bs * 100, prog_bar=True,
190
+ on_epoch=True, sync_dist=True)
191
+
192
+ self.log(f"{description}_loss", total_loss, prog_bar=True, on_epoch=True,
193
+ sync_dist=True, batch_size=bs)
194
+ return total_loss
195
+
196
+ def configure_optimizers(self):
197
+ if self.weight_decay > 0.:
198
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
199
+ else:
200
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
201
+ return optimizer
202
+
203
+ def loss_mdlm(self, x_0, model_output, sigma, dsigma, padding_mask=None):
204
+ """Loss for SUBS parameterization, continuous time case"""
205
+ log_p_theta = torch.gather(
206
+ input=model_output,
207
+ dim=-1,
208
+ index=x_0[:, :, None]).squeeze(-1)
209
+
210
+ loss = - log_p_theta * (dsigma / torch.expm1(sigma))[:, None]
211
+
212
+ if padding_mask is not None:
213
+ return (loss * padding_mask).sum() / padding_mask.sum()
214
+ return loss.mean()
215
+
216
+ def _sample_prior(self, *batch_dims):
217
+ return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)
218
+
219
+ def _ddpm_caching_update(self, sequence, x, t, dt, p_x0=None, conf=None,
220
+ return_attention=False):
221
+ attention = None
222
+ if t.ndim > 1:
223
+ t = t.squeeze(-1)
224
+ sigma_t, _ = self.noise(t)
225
+ assert t.ndim == 1
226
+ move_chance_t = t[:, None, None]
227
+ move_chance_s = (t - dt)[:, None, None]
228
+ assert move_chance_t.ndim == 3, move_chance_t.shape
229
+ padding_mask = (x != 0).float()
230
+
231
+ if p_x0 is None:
232
+ p_x0, attention = self.model(sequence, (x.unsqueeze(dim=2), None, None), sigma_t,
233
+ padding_mask, return_attention=return_attention)
234
+ p_x0 = p_x0.exp()
235
+ if self.nucleus_p < 1:
236
+ sorted_probs, sorted_indices = torch.sort(p_x0, descending=True, dim=-1)
237
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
238
+ top_p_mask = cumulative_probs <= self.nucleus_p
239
+ top_p_mask[..., 0] = True
240
+ nucleus_probs = sorted_probs * top_p_mask
241
+ nucleus_probs /= nucleus_probs.sum(dim=-1, keepdim=True)
242
+ p_x0 = torch.zeros_like(p_x0).scatter_(-1, sorted_indices, nucleus_probs)
243
+
244
+ assert move_chance_t.ndim == p_x0.ndim
245
+
246
+ # Use remdm-cap sampler
247
+ alpha_t = (1 - move_chance_t)[0].item()
248
+ alpha_s = (1 - move_chance_s)[0].item()
249
+ if alpha_t > 0:
250
+ sigma = min(self.eta, (1 - alpha_s) / alpha_t)
251
+ else:
252
+ sigma = self.eta
253
+ q_xs = p_x0 * (1 - sigma)
254
+ q_xs[..., self.mask_index] = sigma
255
+ q_xs_2 = p_x0 * ((alpha_s - (1 - sigma) * alpha_t) / (1 - alpha_t))
256
+ q_xs_2[..., self.mask_index] = (1 - alpha_s - sigma * alpha_t) / (1 - alpha_t)
257
+ copy_flag = (x != self.mask_index).to(torch.bool)
258
+ q_xs = torch.where(copy_flag.unsqueeze(-1), q_xs, q_xs_2)
259
+ xs = _sample_categorical(q_xs)
260
+
261
+ if torch.allclose(xs, x) and not self.time_conditioning:
262
+ p_x0_cache = p_x0
263
+ else:
264
+ p_x0_cache = None
265
+
266
+ return p_x0_cache, xs, conf, attention
267
+
268
+ @torch.no_grad()
269
+ def _sample(self, sequence, eps=1e-3, return_attention=False):
270
+ """Generate samples from the model"""
271
+ num_steps = self.sampling_steps
272
+ bs = sequence[0].shape[0]
273
+ x = self._sample_prior(bs, self.model_length).to(self.device)
274
+
275
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
276
+ dt = (1 - eps) / num_steps
277
+ p_x0_cache = None
278
+
279
+ min_t = timesteps[-1].item()
280
+ confident_score = - torch.ones_like(x, device=self.device) * torch.inf
281
+
282
+ for i in range(num_steps):
283
+ t = timesteps[i] * torch.ones(bs, 1, device=self.device)
284
+ p_x0_cache, x_next, confident_score, attention = self._ddpm_caching_update(
285
+ sequence, x, t, dt, p_x0=p_x0_cache, conf=confident_score,
286
+ return_attention=return_attention)
287
+
288
+ if (not torch.allclose(x_next, x)):
289
+ p_x0_cache = None
290
+ x = x_next
291
+
292
+ if self.noise_removal:
293
+ t = min_t * torch.ones(bs, 1, device=self.device)
294
+ unet_conditioning = self.noise(t)[0]
295
+ padding_mask = (x != 0).float()
296
+ x, attention = self.model(sequence, (x, None, None), unet_conditioning.squeeze(-1),
297
+ padding_mask, return_attention=return_attention)
298
+ x = x.argmax(dim=-1)
299
+ return x, attention
300
+
301
+
302
+ class ProtobindMaskedDiffusion(nn.Module, huggingface_hub.PyTorchModelHubMixin):
303
+ """The core Protobind-Diff model, which uses a Transformer decoder with RoPE.
304
+
305
+ This model is designed for a masked diffusion task and supports conditioning
306
+ on ESM-2 protein embeddings and generating ligands with a ChemformerTokenizer.
307
+ """
308
+
309
+
310
+ def __init__(self,
311
+ embedding_dim: int,
312
+ mask_index: int,
313
+ vocab_size: int,
314
+ decoder_params: Optional[dict] = None,
315
+ dropout: float = 0.2,
316
+ parametrization_strategy: str = 'subs',
317
+ **kwargs) -> None:
318
+ """Initializes the ProtobindMaskedDiffusion model.
319
+
320
+ Args:
321
+ embedding_dim (int): The dimension of the protein sequence embeddings.
322
+ mask_index (int): The token ID for the MASK token.
323
+ vocab_size (int): The size of the ligand's vocabulary.
324
+ decoder_params (Optional[dict]): A dictionary of parameters for the
325
+ internal Transformer decoder (e.g., nhead, n_layers).
326
+ dropout (float): The dropout rate.
327
+ parametrization_strategy (str): The diffusion parameterization to use.
328
+ Currently only 'subs' is supported.
329
+ """
330
+ super().__init__()
331
+
332
+ self.neg_infinity = -1000000.0
333
+ self.parametrization_strategy = parametrization_strategy
334
+ self.decoder_name = decoder_params.pop('decoder_name')
335
+ expand_feedforward = decoder_params.pop('expand_feedforward')
336
+ self.mask_index = mask_index
337
+
338
+ # Decoder options
339
+ if self.decoder_name == 'decoder_re':
340
+ self.decoder = Decoder_RoPE(vocab_size, embedding_dim, expand_feedforward=expand_feedforward,
341
+ dropout=dropout, **decoder_params)
342
+ else:
343
+ raise ValueError(f"Model only supports decoder with rotary embeddings ('decoder_re'), got: {self.decoder_name}")
344
+
345
+ def forward(self,
346
+ sequence: Tuple[torch.Tensor, torch.Tensor],
347
+ ligands: Tuple[torch.Tensor, torch.Tensor],
348
+ sigma: torch.Tensor,
349
+ mask_ligand: torch.Tensor,
350
+ return_attention: bool = False) -> torch.Tensor:
351
+ """Performs the main forward pass of the diffusion model.
352
+
353
+ Args:
354
+ sequence (Tuple[torch.Tensor, torch.Tensor]): A tuple of the conditioning
355
+ protein sequence embeddings and their lengths.
356
+ ligands (Tuple[torch.Tensor, torch.Tensor]): A tuple
357
+ containing the noised ligand `xt` and its length.
358
+ sigma (torch.Tensor): The diffusion timestep (noise level).
359
+ mask_ligand (torch.Tensor): The padding mask for the ligand.
360
+ return_attention (bool): If True, returns attention weights from the decoder.
361
+
362
+ Returns:
363
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing the final predicted logits
364
+ and the attention weights.
365
+ """
366
+
367
+ sequence, sequence_lengths = sequence
368
+ xt, ligand_lengths, _ = ligands
369
+
370
+ # Decode ligand
371
+ ligand_masked = xt.squeeze(-1).long()
372
+ ligand_decoded, attention = self.decoder(ligand_masked,
373
+ sigma,
374
+ sequence,
375
+ sequence_lengths,
376
+ lig_padding_mask=None,
377
+ return_attention=return_attention)
378
+
379
+ # Apply parametrization
380
+ ligand_decoded = self.parametrization(ligand_decoded, xt)
381
+
382
+ return ligand_decoded, attention
383
+
384
+ def parametrization(self, logits, xt):
385
+ """Applies the chosen parameterization to the model's output logits.
386
+
387
+ The 'subs' strategy modifies the logits to represent the probability
388
+ p(x_{t-1}|x_t), enforcing that unmasked tokens remain unchanged.
389
+
390
+ Args:
391
+ logits (torch.Tensor): The raw output logits from the decoder.
392
+ xt (torch.Tensor): The noised input ligand at timestep t.
393
+
394
+ Returns:
395
+ torch.Tensor: The re-parameterized logits.
396
+ """
397
+ if self.parametrization_strategy == 'subs':
398
+ # log prob at the mask index = - infinity
399
+ logits[:, :, self.mask_index] += self.neg_infinity
400
+
401
+ # Normalize the logits
402
+ logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
403
+
404
+ # Apply updates for unmasked tokens
405
+ xt = xt.squeeze(-1)
406
+ unmasked_indices = (xt != self.mask_index)
407
+ logits[unmasked_indices] = self.neg_infinity
408
+ logits[unmasked_indices, xt[unmasked_indices].long()] = 0
409
+ else:
410
+ raise NotImplementedError(f'Parametrization strategy {self.parametrization_strategy} not implemented')
411
+ return logits
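
The nucleus (top-p) filter inside `_ddpm_caching_update` is easy to check in isolation. The following self-contained sketch uses made-up probabilities and the same tensor operations as the sampler above.

```python
# Stand-alone illustration of the top-p (nucleus) filtering used during sampling.
import torch

p = torch.tensor([[0.05, 0.50, 0.30, 0.10, 0.05]])   # toy categorical distribution
nucleus_p = 0.9

sorted_probs, sorted_idx = torch.sort(p, descending=True, dim=-1)
top_p_mask = torch.cumsum(sorted_probs, dim=-1) <= nucleus_p
top_p_mask[..., 0] = True                             # always keep the top token
nucleus = sorted_probs * top_p_mask
nucleus = nucleus / nucleus.sum(dim=-1, keepdim=True)
filtered = torch.zeros_like(p).scatter_(-1, sorted_idx, nucleus)
print(filtered)   # tail probabilities zeroed out, the rest renormalized to sum to 1
```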
protobind_diff/noise_schedule.py ADDED
@@ -0,0 +1,185 @@
1
+ import abc
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ # Flags required to enable jit fusion kernels
7
+ torch._C._jit_set_profiling_mode(False)
8
+ torch._C._jit_set_profiling_executor(False)
9
+ torch._C._jit_override_can_fuse_on_cpu(True)
10
+ torch._C._jit_override_can_fuse_on_gpu(True)
11
+
12
+
13
+
14
+ def _sample_categorical(categorical_probs):
15
+ gumbel_norm = (
16
+ 1e-10
17
+ - (torch.rand_like(categorical_probs) + 1e-10).log())
18
+ return (categorical_probs / gumbel_norm).argmax(dim=-1)
19
+
20
+
21
+ def _unsqueeze(x, reference):
22
+ return x.view(
23
+ * x.shape,
24
+ * ((1,) * (len(reference.shape) - len(x.shape))))
25
+
26
+
27
+ def _sample_t(n, device, antithetic_sampling=True, sampling_eps=1e-3):
28
+ _eps_t = torch.rand(n, device=device)
29
+ if antithetic_sampling:
30
+ offset = torch.arange(n, device=device) / n
31
+ _eps_t = (_eps_t / n + offset) % 1
32
+ t = (1 - sampling_eps) * _eps_t + sampling_eps
33
+ return t
34
+
35
+
36
+ def q_xt(x, move_chance, mask_index):
37
+ """Computes the noisy sample xt.
38
+
39
+ Args:
40
+ x: int torch.Tensor with shape (batch_size,
41
+ diffusion_model_input_length), input.
42
+ move_chance: float torch.Tensor with shape (batch_size, 1).
+ mask_index: int id of the mask token used at corrupted positions.
43
+ """
44
+ move_indices = torch.rand(
45
+ * x.shape, device=x.device) < move_chance
46
+ xt = torch.where(move_indices, mask_index, x)
47
+ return xt
48
+
49
+
50
+ def get_noise(config, dtype=torch.float32):
51
+ if config.noise.type == 'geometric':
52
+ return GeometricNoise(config.noise.sigma_min,
53
+ config.noise.sigma_max)
54
+ elif config.noise.type == 'loglinear':
55
+ return LogLinearNoise()
56
+ elif config.noise.type == 'cosine':
57
+ return CosineNoise()
58
+ elif config.noise.type == 'cosinesqr':
59
+ return CosineSqrNoise()
60
+ elif config.noise.type == 'linear':
61
+ return Linear(config.noise.sigma_min,
62
+ config.noise.sigma_max,
63
+ dtype)
64
+ else:
65
+ raise ValueError(f'{config.noise.type} is not a valid noise')
66
+
67
+
68
+ def binary_discretization(z):
69
+ z_hard = torch.sign(z)
70
+ z_soft = z / torch.norm(z, dim=-1, keepdim=True)
71
+ return z_soft + (z_hard - z_soft).detach()
72
+
73
+
74
+ class Noise(abc.ABC, nn.Module):
75
+ """
76
+ Baseline forward method to get the total + rate of noise at a timestep
77
+ """
78
+ def forward(self, t):
79
+ # Assume time goes from 0 to 1
80
+ return self.total_noise(t), self.rate_noise(t)
81
+
82
+ @abc.abstractmethod
83
+ def rate_noise(self, t):
84
+ """
85
+ Rate of change of noise ie g(t)
86
+ """
87
+ pass
88
+
89
+ @abc.abstractmethod
90
+ def total_noise(self, t):
91
+ """
92
+ Total noise ie \int_0^t g(t) dt + g(0)
93
+ """
94
+ pass
95
+
96
+
97
+ class CosineNoise(Noise):
98
+ def __init__(self, eps=1e-3):
99
+ super().__init__()
100
+ self.eps = eps
101
+
102
+ def rate_noise(self, t):
103
+ cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
104
+ sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
105
+ scale = torch.pi / 2
106
+ return scale * sin / (cos + self.eps)
107
+
108
+ def total_noise(self, t):
109
+ cos = torch.cos(t * torch.pi / 2)
110
+ return - torch.log(self.eps + (1 - self.eps) * cos)
111
+
112
+
113
+ class CosineSqrNoise(Noise):
114
+ def __init__(self, eps=1e-3):
115
+ super().__init__()
116
+ self.eps = eps
117
+
118
+ def rate_noise(self, t):
119
+ cos = (1 - self.eps) * (
120
+ torch.cos(t * torch.pi / 2) ** 2)
121
+ sin = (1 - self.eps) * torch.sin(t * torch.pi)
122
+ scale = torch.pi / 2
123
+ return scale * sin / (cos + self.eps)
124
+
125
+ def total_noise(self, t):
126
+ cos = torch.cos(t * torch.pi / 2) ** 2
127
+ return - torch.log(self.eps + (1 - self.eps) * cos)
128
+
129
+
130
+ class Linear(Noise):
131
+ def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
132
+ super().__init__()
133
+ self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
134
+ self.sigma_max = torch.tensor(sigma_max, dtype=dtype)
135
+
136
+ def rate_noise(self, t):
137
+ return self.sigma_max - self.sigma_min
138
+
139
+ def total_noise(self, t):
140
+ return self.sigma_min + t * (self.sigma_max - self.sigma_min)
141
+
142
+ def importance_sampling_transformation(self, t):
143
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
144
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
145
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
146
+ return (sigma_t - self.sigma_min) / (
147
+ self.sigma_max - self.sigma_min)
148
+
149
+
150
+ class GeometricNoise(Noise):
151
+ def __init__(self, sigma_min=1e-3, sigma_max=1):
152
+ super().__init__()
153
+ self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
154
+
155
+ def rate_noise(self, t):
156
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
157
+ self.sigmas[1].log() - self.sigmas[0].log())
158
+
159
+ def total_noise(self, t):
160
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
161
+
162
+
163
+ class LogLinearNoise(Noise):
164
+ """Log Linear noise schedule.
165
+
166
+ Built such that 1 - 1/e^(n(t)) interpolates between 0 and 1.
167
+ """
168
+ def __init__(self, eps=1e-3):
169
+ super().__init__()
170
+ self.eps = eps
171
+ self.sigma_max = self.total_noise(torch.tensor(1.0))
172
+ self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
173
+
174
+ def rate_noise(self, t):
175
+ return (1 - self.eps) / (1 - (1 - self.eps) * t)
176
+
177
+ def total_noise(self, t):
178
+ return -torch.log1p(-(1 - self.eps) * t)
179
+
180
+ def importance_sampling_transformation(self, t):
181
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
182
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
183
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
184
+ t = - torch.expm1(- sigma_t) / (1 - self.eps)
185
+ return t
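
The schedules above are consumed elsewhere in the package; as a minimal orientation sketch, the snippet below combines `_sample_t`, a `LogLinearNoise` schedule, and `q_xt` to corrupt a token batch. The import path, the mask token id (4), and the `move_chance = 1 - exp(-sigma)` conversion are assumptions for illustration (the conversion follows the `LogLinearNoise` docstring), not code from this commit.

```python
# Hypothetical usage sketch: the module path, the mask id, and the
# sigma -> move_chance conversion are assumptions, not part of this commit.
import torch
from protobind_diff.noise_schedule import LogLinearNoise, _sample_t, q_xt

noise = LogLinearNoise()
x = torch.randint(5, 100, (8, 64))       # batch of token ids (mask id 4 excluded)
t = _sample_t(x.shape[0], x.device)      # stratified timesteps in [eps, 1)
sigma, dsigma = noise(t)                 # total noise and its rate at each t
move_chance = 1 - torch.exp(-sigma)      # per-sequence masking probability
xt = q_xt(x, move_chance[:, None], mask_index=4)
print(xt.shape, (xt == 4).float().mean())  # fraction of tokens replaced by the mask id
```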
pyproject.toml ADDED
@@ -0,0 +1,44 @@
1
+ [project]
2
+ name = "ProtoBind-Diff"
3
+ version = "0.1.0"
4
+ description = "ProtoBind-Diff: A Structure-Free Diffusion Language Model for Protein Sequence-Conditioned Ligand Design"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10,<3.13"
7
+ license = "MIT AND (Apache-2.0 OR BSD-2-Clause)"
8
+ authors = [
9
+ { name = "Lukia Mistryukova", email = "lukiia.mistriukova@gero.ai" },
10
+ { name = "Vladimir Manuilov", email = "vladimir.manuylov@gero.ai" },
11
+ { name = "Konstantin Avchaciov", email = "ka@gero.ai" },
12
+ { name = "Peter Fedichev", email = "pf@gero.ai" },
13
+ ]
14
+ dependencies = [
15
+ "torch>=2.2",
16
+ "numpy>=1.26,<2.0",
17
+ "lightning>=2.3.0",
18
+ "rdkit>=2024.3.2",
19
+ "requests==2.32.3",
20
+ "pandas>=2.2.2",
21
+ "PyYAML>=6.0",
22
+ "scipy>=1.13.0",
23
+ "scikit-learn>=1.1.0",
24
+ "fair-esm==2.0.0",
25
+ "biopython>=1.80",
26
+ "pysmilesutils @ git+https://github.com/MolecularAI/pysmilesutils.git",
27
+ "FPSim2==0.5.2",
28
+ "huggingface_hub",
29
+ "einops==0.8.0",
30
+ "easydict>=1.11",
31
+ "tensorboard>=2.14.0",
32
+ "rich>=13.5.0"
33
+ ]
34
+
35
+ [project.scripts]
36
+ protobind-train = "protobind_diff.train:main"
37
+ protobind-infer = "protobind_diff.esm_inference:main"
38
+
39
+ [build-system]
40
+ requires = ["setuptools>=61.0"]
41
+ build-backend = "setuptools.build_meta"
42
+
43
+ [tool.setuptools.packages.find]
44
+ include = ["protobind_diff"]
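
For the `[project.scripts]` table above: installing the package (for example with `pip install -e .` from a checkout) exposes `protobind-train` and `protobind-infer` as console commands wired to `protobind_diff.train:main` and `protobind_diff.esm_inference:main`; their command-line options are defined in those modules.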