update
Browse files- .gitattributes +1 -0
- README.md +43 -3
- data_preprocessing/data.py +235 -0
- data_preprocessing/data_split.py +101 -0
- mcts.png +3 -0
- mdlm.png +0 -0
- peptune.png +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
mcts.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,43 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
extra_gated_fields:
|
| 3 |
+
Name: text
|
| 4 |
+
Company: text
|
| 5 |
+
Country: country
|
| 6 |
+
Specific date: date_picker
|
| 7 |
+
I want to use this model for:
|
| 8 |
+
type: select
|
| 9 |
+
options:
|
| 10 |
+
- Research
|
| 11 |
+
- Education
|
| 12 |
+
- label: Other
|
| 13 |
+
value: other
|
| 14 |
+
extra_gated_prompt: "PepTune License: https://duke.box.com/s/5ghseh23rpsyou66kg60qr89sxt5twyu"
|
| 15 |
+
extra_gated_heading: Acknowledge license to access the repository
|
| 16 |
+
extra_gated_button_content: Acknowledge license
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
<div align="center">
|
| 20 |
+
<img src="peptune.png" alt="peptune" width="300" height="300">
|
| 21 |
+
</div>
|
| 22 |
+
|
| 23 |
+
# PepTune: *De Novo* Generation of Therapeutic Peptides with Multi-Objective-Guided Discrete Diffusion
|
| 24 |
+
|
| 25 |
+
Peptide therapeutics, a major class of medicines, have achieved remarkable success across diseases like diabetes and cancer, with landmark examples such as GLP-1 receptor agonists revolutionizing the treatment of type-2 diabetes and obesity. Despite their success, designing peptides that satisfy multiple conflicting objectives, such as binding affinity, solubility, and membrane permeability, remains a major challenge. Classical drug development and target structure-based design methods are ineffective for such tasks, as they fail to optimize global functional properties critical for therapeutic efficacy. Existing generative frameworks are largely limited to continuous spaces, unconditioned outputs, or single-objective guidance, making them unsuitable for discrete sequence optimization across multiple properties. To address this, we present **PepTune**, a multi-objective discrete diffusion model for the simultaneous generation and optimization of therapeutic peptide SMILES. Built on the Masked Discrete Language Model (MDLM) framework, PepTune ensures valid peptide structures with state-dependent masking schedules and penalty-based objectives. To guide the diffusion process, we propose a Monte Carlo Tree Search (MCTS)-based strategy that balances exploration and exploitation to iteratively refine Pareto-optimal sequences. MCTS integrates classifier-based rewards with search-tree expansion, overcoming gradient estimation challenges and data sparsity inherent to discrete spaces. Using PepTune, we generate diverse, chemically-modified peptides optimized for multiple therapeutic properties, including target binding affinity, membrane permeability, solubility, hemolysis, and non-fouling characteristics on various disease-relevant targets. In total, our results demonstrate that MCTS-guided discrete diffusion is a powerful and modular approach for multi-objective sequence design in discrete state spaces.
|
| 26 |
+
|
| 27 |
+
## We build our training framework on top of [Masked Discrete Language Model](https://huggingface.co/kuleshov-group/mdlm-owt).
|
| 28 |
+

|
| 29 |
+
|
| 30 |
+
## We optimize desired therapeutic properties of generated sequences based on Monte Carlo Tree Search
|
| 31 |
+

|
| 32 |
+
|
| 33 |
+
## Inference API, datasets, and sequences will be freely accessible to the academic community via a non-commercial license upon publication and provisional patent filing
|
| 34 |
+
|
| 35 |
+
## Interactive Demo
|
| 36 |
+
|
| 37 |
+
You can try out the our peptide visualizer directly in your browser, other property classifiers will be added soon:
|
| 38 |
+
|
| 39 |
+
<https://huggingface.co/spaces/ChatterjeeLab/SMILES2PEPTIDE>
|
| 40 |
+
|
| 41 |
+
## Usage
|
| 42 |
+
|
| 43 |
+
To use this repository, you agree to abide by the [PepTune License](https://duke.box.com/s/5ghseh23rpsyou66kg60qr89sxt5twyu).
|
data_preprocessing/data.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import torch
|
| 3 |
+
from datasets import Dataset, DatasetDict, load_from_disk
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import os
|
| 6 |
+
from multiprocessing import Pool
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import lightning.pytorch as pl
|
| 9 |
+
sys.path.append('/home/yz927/projects/peptune/scripts/')
|
| 10 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 11 |
+
global_tokenizer = None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def init_pool(tokenizer):
|
| 15 |
+
global global_tokenizer
|
| 16 |
+
global_tokenizer = tokenizer
|
| 17 |
+
|
| 18 |
+
class SequenceDataset:
|
| 19 |
+
def __init__(self, sequences, tokenizer, max_sequence_length, num_cores=8):
|
| 20 |
+
self.sequences = sequences
|
| 21 |
+
self.tokenizer = tokenizer
|
| 22 |
+
self.max_sequence_length = max_sequence_length
|
| 23 |
+
self.num_cores = 8
|
| 24 |
+
self.tokenized_sequences = []
|
| 25 |
+
self.original_sequences = []
|
| 26 |
+
|
| 27 |
+
def tokenize_sequences(self):
|
| 28 |
+
print(f"Starting parallel tokenization using {self.num_cores} cores")
|
| 29 |
+
with Pool(processes=self.num_cores, initializer=init_pool, initargs=(self.tokenizer,)) as pool:
|
| 30 |
+
results = list(tqdm(
|
| 31 |
+
pool.imap(standalone_tokenize_function, self.sequences),
|
| 32 |
+
total=len(self.sequences)
|
| 33 |
+
))
|
| 34 |
+
|
| 35 |
+
for result, seq in zip(results, self.sequences):
|
| 36 |
+
if result is not None and len(result['input_ids'][0]) <= self.max_sequence_length:
|
| 37 |
+
self.tokenized_sequences.append(result)
|
| 38 |
+
self.original_sequences.append(seq)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def process_sequences(self, batch_size):
|
| 42 |
+
self.tokenize_sequences()
|
| 43 |
+
|
| 44 |
+
lengths = [(len(seq['input_ids'][0]), i) for i, seq in enumerate(self.tokenized_sequences)]
|
| 45 |
+
lengths.sort()
|
| 46 |
+
|
| 47 |
+
batches = []
|
| 48 |
+
sequence_batches = []
|
| 49 |
+
current_batch = []
|
| 50 |
+
current_sequence_batch = []
|
| 51 |
+
current_length = 0
|
| 52 |
+
|
| 53 |
+
for length, idx in tqdm(lengths):
|
| 54 |
+
if current_length + length > self.max_sequence_length or len(current_batch) == batch_size:
|
| 55 |
+
if current_batch:
|
| 56 |
+
batches.append([self.tokenized_sequences[i] for i in current_batch])
|
| 57 |
+
sequence_batches.append([self.original_sequences[i] for i in current_batch])
|
| 58 |
+
current_batch = [idx]
|
| 59 |
+
current_sequence_batch = [self.original_sequences[idx]]
|
| 60 |
+
current_length = length
|
| 61 |
+
else:
|
| 62 |
+
current_batch.append(idx)
|
| 63 |
+
current_sequence_batch.append(self.original_sequences[idx])
|
| 64 |
+
current_length += length
|
| 65 |
+
|
| 66 |
+
if current_batch:
|
| 67 |
+
batches.append([self.tokenized_sequences[i] for i in current_batch])
|
| 68 |
+
sequence_batches.append([self.original_sequences[i] for i in current_batch])
|
| 69 |
+
|
| 70 |
+
token_batch_fn = TokenizeBatch(self.tokenizer)
|
| 71 |
+
processed_batches = [token_batch_fn(batch) for batch in tqdm(batches)]
|
| 72 |
+
|
| 73 |
+
dataset = Dataset.from_dict({
|
| 74 |
+
'attention_mask': [batch['attention_mask'] for batch in processed_batches],
|
| 75 |
+
'input_ids': [batch['input_ids'] for batch in processed_batches],
|
| 76 |
+
'labels': sequence_batches
|
| 77 |
+
})
|
| 78 |
+
|
| 79 |
+
return dataset
|
| 80 |
+
|
| 81 |
+
class DynamicBatchingDataset(Dataset):
|
| 82 |
+
"""
|
| 83 |
+
Process dynamically batched datasets of Huggingface Datasets object. Need special handling since in the previous
|
| 84 |
+
steps, each batch (row in the Datasets object) is already processed for per batch loading
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(self, dataset_dict):
|
| 88 |
+
print('Initializing dataset...')
|
| 89 |
+
self.dataset_dict = {
|
| 90 |
+
'attention_mask': [torch.tensor(item) for item in dataset_dict['attention_mask']],
|
| 91 |
+
'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
|
| 92 |
+
'labels': dataset_dict['labels'] # Store original sequences as it is
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
def __len__(self):
|
| 96 |
+
return len(self.dataset_dict['attention_mask'])
|
| 97 |
+
|
| 98 |
+
def __getitem__(self, idx):
|
| 99 |
+
if isinstance(idx, int):
|
| 100 |
+
return {
|
| 101 |
+
'attention_mask': self.dataset_dict['attention_mask'][idx],
|
| 102 |
+
'input_ids': self.dataset_dict['input_ids'][idx],
|
| 103 |
+
'labels': self.dataset_dict['labels'][idx]
|
| 104 |
+
}
|
| 105 |
+
elif isinstance(idx, list):
|
| 106 |
+
return {
|
| 107 |
+
'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
|
| 108 |
+
'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
|
| 109 |
+
'labels': [self.dataset_dict['labels'][i] for i in idx]
|
| 110 |
+
}
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
|
| 113 |
+
|
| 114 |
+
@staticmethod
|
| 115 |
+
def collate_fn(batch, verbose=False):
|
| 116 |
+
item = batch[0]
|
| 117 |
+
return {
|
| 118 |
+
'input_ids': item['input_ids'],
|
| 119 |
+
'attention_mask': item['attention_mask'],
|
| 120 |
+
'labels': item['labels']
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
def standalone_tokenize_function(sequence):
|
| 124 |
+
global global_tokenizer
|
| 125 |
+
try:
|
| 126 |
+
tokens = global_tokenizer(sequence)
|
| 127 |
+
# The tokenizer already returns lists of integers, so we just need to wrap them in another list
|
| 128 |
+
# to match the expected format [batch_size, sequence_length]
|
| 129 |
+
return {
|
| 130 |
+
'input_ids': [tokens['input_ids']],
|
| 131 |
+
'attention_mask': [tokens['attention_mask']]
|
| 132 |
+
}
|
| 133 |
+
except Exception as e:
|
| 134 |
+
print(f"Error tokenizing sequence '{sequence}': {e}")
|
| 135 |
+
return None
|
| 136 |
+
|
| 137 |
+
class TokenizeBatch:
|
| 138 |
+
def __init__(self, tokenizer):
|
| 139 |
+
self.pad_token_id = tokenizer.pad_token_id
|
| 140 |
+
|
| 141 |
+
def __call__(self, batches):
|
| 142 |
+
data_tokens = [torch.tensor(batch['input_ids'][0]) for batch in batches]
|
| 143 |
+
data_tokens_padded = torch.nn.utils.rnn.pad_sequence(data_tokens, batch_first=True, padding_value=self.pad_token_id)
|
| 144 |
+
attention_masks = (data_tokens_padded != self.pad_token_id).long()
|
| 145 |
+
|
| 146 |
+
return {
|
| 147 |
+
'input_ids': data_tokens_padded,
|
| 148 |
+
'attention_mask': attention_masks,
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
class PretrainSequenceDataModule(pl.LightningDataModule):
|
| 152 |
+
def __init__(self,
|
| 153 |
+
tokenizer,
|
| 154 |
+
input_dataset_path,
|
| 155 |
+
output_dataset_path,
|
| 156 |
+
num_workers,
|
| 157 |
+
batch_size,
|
| 158 |
+
max_sequence_length=512,):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.tokenizer = tokenizer
|
| 161 |
+
self.input_path = input_dataset_path
|
| 162 |
+
self.output_path = output_dataset_path
|
| 163 |
+
self.num_workers = num_workers
|
| 164 |
+
self.batch_size = batch_size
|
| 165 |
+
self.max_sequence_length = max_sequence_length
|
| 166 |
+
|
| 167 |
+
def prepare_data(self):
|
| 168 |
+
if not os.path.exists(self.output_path):
|
| 169 |
+
print("Loading text files")
|
| 170 |
+
with open(f"{self.input_path}/train.txt", 'r') as f:
|
| 171 |
+
train_sequences = [line.strip() for line in f if line.strip()]
|
| 172 |
+
with open(f"{self.input_path}/val.txt", 'r') as f:
|
| 173 |
+
val_sequences = [line.strip() for line in f if line.strip()]
|
| 174 |
+
|
| 175 |
+
print("Processing training data")
|
| 176 |
+
train_dataset = SequenceDataset(train_sequences,
|
| 177 |
+
self.tokenizer,
|
| 178 |
+
self.max_sequence_length)
|
| 179 |
+
print("Processing validation data")
|
| 180 |
+
val_dataset = SequenceDataset(val_sequences,
|
| 181 |
+
self.tokenizer,
|
| 182 |
+
self.max_sequence_length)
|
| 183 |
+
|
| 184 |
+
processed_train = train_dataset.process_sequences(self.batch_size)
|
| 185 |
+
processed_val = val_dataset.process_sequences(self.batch_size)
|
| 186 |
+
|
| 187 |
+
print("Combining datasets")
|
| 188 |
+
combined_dataset = DatasetDict({
|
| 189 |
+
'train': processed_train,
|
| 190 |
+
'val': processed_val,
|
| 191 |
+
})
|
| 192 |
+
|
| 193 |
+
print(f"Saving dataset to {self.output_path}")
|
| 194 |
+
combined_dataset.save_to_disk(self.output_path)
|
| 195 |
+
|
| 196 |
+
def setup(self, stage: str):
|
| 197 |
+
print("Loading processed dataset")
|
| 198 |
+
dataset = load_from_disk(self.output_path)
|
| 199 |
+
self.train_dataset = DynamicBatchingDataset(dataset['train'])
|
| 200 |
+
self.val_dataset = DynamicBatchingDataset(dataset['val'])
|
| 201 |
+
|
| 202 |
+
def train_dataloader(self):
|
| 203 |
+
print("Creating training dataloader")
|
| 204 |
+
return DataLoader(self.train_dataset,
|
| 205 |
+
batch_size=1,
|
| 206 |
+
shuffle=False,
|
| 207 |
+
num_workers=self.num_workers,
|
| 208 |
+
collate_fn=DynamicBatchingDataset.collate_fn,
|
| 209 |
+
pin_memory=True)
|
| 210 |
+
|
| 211 |
+
def val_dataloader(self):
|
| 212 |
+
print("Creating validation dataloader")
|
| 213 |
+
return DataLoader(self.val_dataset,
|
| 214 |
+
batch_size=1,
|
| 215 |
+
shuffle=False,
|
| 216 |
+
num_workers=self.num_workers,
|
| 217 |
+
collate_fn=DynamicBatchingDataset.collate_fn,
|
| 218 |
+
pin_memory=True)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if __name__ == '__main__':
|
| 222 |
+
tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
|
| 223 |
+
'/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
|
| 224 |
+
dm = PretrainSequenceDataModule(
|
| 225 |
+
tokenizer=tokenizer,
|
| 226 |
+
input_dataset_path='/home/yz927/projects/peptune/tokens/11M_smiles',
|
| 227 |
+
output_dataset_path='/home/yz927/projects/peptune/tokenized/11M_smiles_old_tokenizer_no_limit',
|
| 228 |
+
num_workers=8,
|
| 229 |
+
batch_size=2000,
|
| 230 |
+
max_sequence_length=16*1000,
|
| 231 |
+
)
|
| 232 |
+
dm.prepare_data()
|
| 233 |
+
dm.setup('fit')
|
| 234 |
+
dm.train_dataloader()
|
| 235 |
+
dm.val_dataloader()
|
data_preprocessing/data_split.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rdkit import Chem
|
| 2 |
+
from rdkit.Chem import AllChem
|
| 3 |
+
from rdkit import DataStructs
|
| 4 |
+
import numpy as np
|
| 5 |
+
from sklearn.cluster import MiniBatchKMeans
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import selfies as sf
|
| 9 |
+
from multiprocessing import Pool, cpu_count
|
| 10 |
+
from functools import partial
|
| 11 |
+
def generate_fingerprint_batch_selfies(selfies_batch):
|
| 12 |
+
fps = []
|
| 13 |
+
valid_selfies = []
|
| 14 |
+
|
| 15 |
+
for selfies in tqdm(selfies_batch, desc="Generating fingerprints", leave=False):
|
| 16 |
+
try:
|
| 17 |
+
# Convert SELFIES to SMILES then to molecule
|
| 18 |
+
smiles = sf.decoder(selfies)
|
| 19 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 20 |
+
if mol is not None:
|
| 21 |
+
fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, 2048)
|
| 22 |
+
arr = np.zeros((1,))
|
| 23 |
+
DataStructs.ConvertToNumpyArray(fp, arr)
|
| 24 |
+
fps.append(arr)
|
| 25 |
+
valid_selfies.append(selfies)
|
| 26 |
+
except:
|
| 27 |
+
continue
|
| 28 |
+
|
| 29 |
+
return np.array(fps), valid_selfies
|
| 30 |
+
|
| 31 |
+
def process_batch(batch, n_clusters, seed):
|
| 32 |
+
fps, valid_selfies = generate_fingerprint_batch_selfies(batch)
|
| 33 |
+
if len(fps) > 0:
|
| 34 |
+
clusterer = MiniBatchKMeans(n_clusters=n_clusters, random_state=seed)
|
| 35 |
+
clusterer.fit(fps)
|
| 36 |
+
labels = clusterer.predict(fps)
|
| 37 |
+
return list(zip(labels, valid_selfies))
|
| 38 |
+
return []
|
| 39 |
+
|
| 40 |
+
def parallel_clustering_split_selfies(selfies_list, batch_size=10000, n_clusters=1000, train_ratio=0.9, seed=42):
|
| 41 |
+
np.random.seed(seed)
|
| 42 |
+
|
| 43 |
+
# Create batches
|
| 44 |
+
batches = [selfies_list[i:i + batch_size]
|
| 45 |
+
for i in range(0, len(selfies_list), batch_size)]
|
| 46 |
+
|
| 47 |
+
# Initialize parallel processing
|
| 48 |
+
n_cores = 12
|
| 49 |
+
process_batch_partial = partial(process_batch, n_clusters=n_clusters, seed=seed)
|
| 50 |
+
|
| 51 |
+
cluster_assignments = defaultdict(list)
|
| 52 |
+
with Pool(n_cores) as pool:
|
| 53 |
+
results = list(tqdm(
|
| 54 |
+
pool.imap(process_batch_partial, batches),
|
| 55 |
+
total=len(batches),
|
| 56 |
+
desc="Processing batches"
|
| 57 |
+
))
|
| 58 |
+
|
| 59 |
+
# Combine results
|
| 60 |
+
for batch_results in results:
|
| 61 |
+
for label, selfies in batch_results:
|
| 62 |
+
cluster_assignments[label].append(selfies)
|
| 63 |
+
|
| 64 |
+
# Split into train/val
|
| 65 |
+
clusters = list(cluster_assignments.values())
|
| 66 |
+
np.random.shuffle(clusters)
|
| 67 |
+
|
| 68 |
+
train_selfies = []
|
| 69 |
+
val_selfies = []
|
| 70 |
+
total_mols = sum(len(cluster) for cluster in clusters)
|
| 71 |
+
|
| 72 |
+
for cluster in tqdm(clusters, desc="Splitting clusters"):
|
| 73 |
+
if len(train_selfies) / total_mols < train_ratio:
|
| 74 |
+
train_selfies.extend(cluster)
|
| 75 |
+
else:
|
| 76 |
+
val_selfies.extend(cluster)
|
| 77 |
+
|
| 78 |
+
print(f"Final splits: Train={len(train_selfies)}, Validation={len(val_selfies)}")
|
| 79 |
+
return train_selfies, val_selfies
|
| 80 |
+
|
| 81 |
+
try:
|
| 82 |
+
with open('/home/yz927/projects/peptune/tokens/filtered_peptides_selfies.txt', 'r') as f:
|
| 83 |
+
selfies_list = [line.strip() for line in f if line.strip()]
|
| 84 |
+
print(f"Loaded {len(selfies_list)} selfies sequences from file")
|
| 85 |
+
except FileNotFoundError:
|
| 86 |
+
raise FileNotFoundError(f"Could not find the file at file")
|
| 87 |
+
except Exception as e:
|
| 88 |
+
raise Exception(f"Error reading file: {str(e)}")
|
| 89 |
+
|
| 90 |
+
train_selfies, val_selfies = parallel_clustering_split_selfies(
|
| 91 |
+
selfies_list,
|
| 92 |
+
batch_size=10000,
|
| 93 |
+
n_clusters=1000,
|
| 94 |
+
train_ratio=0.8
|
| 95 |
+
)
|
| 96 |
+
with open('/home/yz927/projects/peptune/tokens/11M_selfies/train_selfies.txt', 'w') as f:
|
| 97 |
+
for line in train_selfies:
|
| 98 |
+
f.write(f"{line}\n")
|
| 99 |
+
with open('/home/yz927/projects/peptune/tokens/11M_selfies/val_selfies.txt', 'w') as f:
|
| 100 |
+
for line in val_selfies:
|
| 101 |
+
f.write(f"{line}\n")
|
mcts.png
ADDED
|
Git LFS Details
|
mdlm.png
ADDED
|
peptune.png
ADDED
|