Sophia Tang committed
Commit 67bf0b2 · 2 Parent(s): 92f7053 08e4d7a
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ mcts.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,43 @@
- ---
- license: cc-by-nc-nd-4.0
- ---
+ ---
+ extra_gated_fields:
+   Name: text
+   Company: text
+   Country: country
+   Specific date: date_picker
+   I want to use this model for:
+     type: select
+     options:
+       - Research
+       - Education
+       - label: Other
+         value: other
+ extra_gated_prompt: "PepTune License: https://duke.box.com/s/5ghseh23rpsyou66kg60qr89sxt5twyu"
+ extra_gated_heading: Acknowledge license to access the repository
+ extra_gated_button_content: Acknowledge license
+ ---
+
+ <div align="center">
+   <img src="peptune.png" alt="peptune" width="300" height="300">
+ </div>
+
+ # PepTune: *De Novo* Generation of Therapeutic Peptides with Multi-Objective-Guided Discrete Diffusion
+
+ Peptide therapeutics, a major class of medicines, have achieved remarkable success across diseases like diabetes and cancer, with landmark examples such as GLP-1 receptor agonists revolutionizing the treatment of type-2 diabetes and obesity. Despite their success, designing peptides that satisfy multiple conflicting objectives, such as binding affinity, solubility, and membrane permeability, remains a major challenge. Classical drug development and target structure-based design methods are ineffective for such tasks, as they fail to optimize global functional properties critical for therapeutic efficacy. Existing generative frameworks are largely limited to continuous spaces, unconditioned outputs, or single-objective guidance, making them unsuitable for discrete sequence optimization across multiple properties. To address this, we present **PepTune**, a multi-objective discrete diffusion model for the simultaneous generation and optimization of therapeutic peptide SMILES. Built on the Masked Discrete Language Model (MDLM) framework, PepTune ensures valid peptide structures with state-dependent masking schedules and penalty-based objectives. To guide the diffusion process, we propose a Monte Carlo Tree Search (MCTS)-based strategy that balances exploration and exploitation to iteratively refine Pareto-optimal sequences. MCTS integrates classifier-based rewards with search-tree expansion, overcoming gradient estimation challenges and data sparsity inherent to discrete spaces. Using PepTune, we generate diverse, chemically modified peptides optimized for multiple therapeutic properties, including target binding affinity, membrane permeability, solubility, hemolysis, and non-fouling characteristics, on various disease-relevant targets. Altogether, our results demonstrate that MCTS-guided discrete diffusion is a powerful and modular approach for multi-objective sequence design in discrete state spaces.
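
The abstract above refers to iteratively refining Pareto-optimal sequences across several property objectives. As a point of reference only, the sketch below shows a minimal Pareto non-dominance check over per-sequence score vectors; the function names and the score layout are illustrative assumptions, not part of this repository.

```python
import numpy as np

def dominates(a: np.ndarray, b: np.ndarray) -> bool:
    """True if score vector `a` Pareto-dominates `b` (>= on all objectives, > on at least one)."""
    return bool(np.all(a >= b) and np.any(a > b))

def pareto_front(scores: np.ndarray) -> np.ndarray:
    """Indices of non-dominated rows; `scores` is (num_sequences, num_objectives), higher is better."""
    keep = []
    for i, s in enumerate(scores):
        if not any(dominates(scores[j], s) for j in range(len(scores)) if j != i):
            keep.append(i)
    return np.array(keep)

# Toy example: three candidate peptides scored on (binding, solubility, permeability).
scores = np.array([[0.9, 0.2, 0.5],
                   [0.4, 0.8, 0.6],
                   [0.3, 0.1, 0.2]])
print(pareto_front(scores))  # -> [0 1]; the third candidate is dominated by the first
```
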
+
+ ## We build our training framework on top of the [Masked Discrete Language Model (MDLM)](https://huggingface.co/kuleshov-group/mdlm-owt)
+ ![Masked Discrete Language Model Framework](mdlm.png)
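
The MDLM framework referenced above trains a model to reconstruct tokens that have been replaced by a mask token according to a noise schedule. The snippet below is a minimal conceptual sketch of that masking-and-reconstruction step, assuming a generic `model(input_ids)` call and a `mask_token_id`; it omits the schedule-dependent loss reweighting of the actual MDLM objective and is not the PepTune training code.

```python
import torch
import torch.nn.functional as F

def mask_tokens(input_ids: torch.Tensor, t: float, mask_token_id: int):
    """Forward noising: independently replace each token with [MASK] with probability t (0 < t <= 1)."""
    noise = torch.rand_like(input_ids, dtype=torch.float)
    is_masked = noise < t
    corrupted = torch.where(is_masked, torch.full_like(input_ids, mask_token_id), input_ids)
    return corrupted, is_masked

def masked_diffusion_loss(model, input_ids, mask_token_id, t):
    """Cross-entropy on masked positions only (assumes at least one position is masked)."""
    corrupted, is_masked = mask_tokens(input_ids, t, mask_token_id)
    logits = model(corrupted)  # assumed shape: (batch, seq_len, vocab)
    return F.cross_entropy(logits[is_masked], input_ids[is_masked])
```
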
+
+ ## We optimize desired therapeutic properties of generated sequences using Monte Carlo Tree Search
+ ![Monte Carlo Tree Search Schematic View](mcts.png)
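
As a rough illustration of the guidance loop described in the abstract (selection that balances exploration and exploitation, expansion, classifier-based rewards, and backpropagation of statistics), here is a generic UCT-style selection and update rule. The node fields and the scalar reward are assumptions for illustration, not the PepTune implementation.

```python
import math
from dataclasses import dataclass, field

@dataclass
class Node:
    visits: int = 0
    total_reward: float = 0.0            # aggregated multi-objective reward for sequences under this node
    children: list = field(default_factory=list)

def uct_score(child: Node, parent_visits: int, c: float = 1.4) -> float:
    """Upper-confidence bound: exploitation (mean reward) plus an exploration bonus."""
    if child.visits == 0:
        return float("inf")               # always try unvisited children first
    mean_reward = child.total_reward / child.visits
    return mean_reward + c * math.sqrt(math.log(max(1, parent_visits)) / child.visits)

def select_child(parent: Node) -> Node:
    return max(parent.children, key=lambda ch: uct_score(ch, parent.visits))

def backpropagate(path: list, reward: float) -> None:
    """After scoring an expanded sequence, update statistics along the selected path."""
    for node in path:
        node.visits += 1
        node.total_reward += reward
```

In a multi-objective setting the scalar `reward` would typically be derived from the property classifiers' scores, for example a weighted sum or a Pareto-based rank.
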
+
+ ## Inference API, datasets, and sequences will be freely accessible to the academic community via a non-commercial license upon publication and provisional patent filing
+
+ ## Interactive Demo
+
+ You can try out our peptide visualizer directly in your browser; other property classifiers will be added soon:
+
+ <https://huggingface.co/spaces/ChatterjeeLab/SMILES2PEPTIDE>
+
+ ## Usage
+
+ To use this repository, you agree to abide by the [PepTune License](https://duke.box.com/s/5ghseh23rpsyou66kg60qr89sxt5twyu).
data_preprocessing/data.py ADDED
@@ -0,0 +1,235 @@
+ import sys
+ import torch
+ from datasets import Dataset, DatasetDict, load_from_disk
+ from torch.utils.data import DataLoader
+ import os
+ from multiprocessing import Pool
+ from tqdm import tqdm
+ import lightning.pytorch as pl
+
+ sys.path.append('/home/yz927/projects/peptune/scripts/')
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+
+ # Shared tokenizer handle for worker processes spawned by multiprocessing.Pool
+ global_tokenizer = None
+
+
+ def init_pool(tokenizer):
+     global global_tokenizer
+     global_tokenizer = tokenizer
+
+ class SequenceDataset:
+     """Tokenizes raw SMILES strings in parallel and packs them into length-sorted batches."""
+     def __init__(self, sequences, tokenizer, max_sequence_length, num_cores=8):
+         self.sequences = sequences
+         self.tokenizer = tokenizer
+         self.max_sequence_length = max_sequence_length
+         self.num_cores = num_cores
+         self.tokenized_sequences = []
+         self.original_sequences = []
+
+     def tokenize_sequences(self):
+         print(f"Starting parallel tokenization using {self.num_cores} cores")
+         with Pool(processes=self.num_cores, initializer=init_pool, initargs=(self.tokenizer,)) as pool:
+             results = list(tqdm(
+                 pool.imap(standalone_tokenize_function, self.sequences),
+                 total=len(self.sequences)
+             ))
+
+         # Keep only sequences that tokenized successfully and fit within the length limit
+         for result, seq in zip(results, self.sequences):
+             if result is not None and len(result['input_ids'][0]) <= self.max_sequence_length:
+                 self.tokenized_sequences.append(result)
+                 self.original_sequences.append(seq)
+
+     def process_sequences(self, batch_size):
+         self.tokenize_sequences()
+
+         # Sort by tokenized length so each batch groups sequences of similar size
+         lengths = [(len(seq['input_ids'][0]), i) for i, seq in enumerate(self.tokenized_sequences)]
+         lengths.sort()
+
+         batches = []
+         sequence_batches = []
+         current_batch = []
+         current_sequence_batch = []
+         current_length = 0
+
+         for length, idx in tqdm(lengths):
+             if current_length + length > self.max_sequence_length or len(current_batch) == batch_size:
+                 if current_batch:
+                     batches.append([self.tokenized_sequences[i] for i in current_batch])
+                     sequence_batches.append([self.original_sequences[i] for i in current_batch])
+                 current_batch = [idx]
+                 current_sequence_batch = [self.original_sequences[idx]]
+                 current_length = length
+             else:
+                 current_batch.append(idx)
+                 current_sequence_batch.append(self.original_sequences[idx])
+                 current_length += length
+
+         if current_batch:
+             batches.append([self.tokenized_sequences[i] for i in current_batch])
+             sequence_batches.append([self.original_sequences[i] for i in current_batch])
+
+         # Pad each pre-formed batch and store it as a single row of the resulting dataset
+         token_batch_fn = TokenizeBatch(self.tokenizer)
+         processed_batches = [token_batch_fn(batch) for batch in tqdm(batches)]
+
+         dataset = Dataset.from_dict({
+             'attention_mask': [batch['attention_mask'] for batch in processed_batches],
+             'input_ids': [batch['input_ids'] for batch in processed_batches],
+             'labels': sequence_batches
+         })
+
+         return dataset
+
+ class DynamicBatchingDataset(torch.utils.data.Dataset):
+     """
+     Wraps a Hugging Face Datasets object whose rows are already full, dynamically sized batches.
+     Needs special handling because each row was pre-batched in the previous step, so the
+     DataLoader is used with batch_size=1 and simply unwraps one pre-formed batch per step.
+     (Subclasses torch.utils.data.Dataset explicitly, since the bare name `Dataset` is shadowed
+     by the Hugging Face import at the top of this file.)
+     """
+
+     def __init__(self, dataset_dict):
+         print('Initializing dataset...')
+         self.dataset_dict = {
+             'attention_mask': [torch.tensor(item) for item in dataset_dict['attention_mask']],
+             'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
+             'labels': dataset_dict['labels']  # store the original sequences as-is
+         }
+
+     def __len__(self):
+         return len(self.dataset_dict['attention_mask'])
+
+     def __getitem__(self, idx):
+         if isinstance(idx, int):
+             return {
+                 'attention_mask': self.dataset_dict['attention_mask'][idx],
+                 'input_ids': self.dataset_dict['input_ids'][idx],
+                 'labels': self.dataset_dict['labels'][idx]
+             }
+         elif isinstance(idx, list):
+             return {
+                 'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
+                 'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
+                 'labels': [self.dataset_dict['labels'][i] for i in idx]
+             }
+         else:
+             raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
+
+     @staticmethod
+     def collate_fn(batch, verbose=False):
+         # batch_size=1, so the single item is itself a fully padded batch
+         item = batch[0]
+         return {
+             'input_ids': item['input_ids'],
+             'attention_mask': item['attention_mask'],
+             'labels': item['labels']
+         }
+
+ def standalone_tokenize_function(sequence):
+     global global_tokenizer
+     try:
+         tokens = global_tokenizer(sequence)
+         # The tokenizer already returns lists of integers, so we just wrap them in another list
+         # to match the expected [batch_size, sequence_length] format
+         return {
+             'input_ids': [tokens['input_ids']],
+             'attention_mask': [tokens['attention_mask']]
+         }
+     except Exception as e:
+         print(f"Error tokenizing sequence '{sequence}': {e}")
+         return None
+
+
+ class TokenizeBatch:
+     """Pads the sequences in one batch to a common length and builds the attention masks."""
+     def __init__(self, tokenizer):
+         self.pad_token_id = tokenizer.pad_token_id
+
+     def __call__(self, batches):
+         data_tokens = [torch.tensor(batch['input_ids'][0]) for batch in batches]
+         data_tokens_padded = torch.nn.utils.rnn.pad_sequence(data_tokens, batch_first=True, padding_value=self.pad_token_id)
+         attention_masks = (data_tokens_padded != self.pad_token_id).long()
+
+         return {
+             'input_ids': data_tokens_padded,
+             'attention_mask': attention_masks,
+         }
+
+ class PretrainSequenceDataModule(pl.LightningDataModule):
+     def __init__(self,
+                  tokenizer,
+                  input_dataset_path,
+                  output_dataset_path,
+                  num_workers,
+                  batch_size,
+                  max_sequence_length=512):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.input_path = input_dataset_path
+         self.output_path = output_dataset_path
+         self.num_workers = num_workers
+         self.batch_size = batch_size
+         self.max_sequence_length = max_sequence_length
+
+     def prepare_data(self):
+         # Tokenize and batch the raw text files once; later runs reuse the saved dataset
+         if not os.path.exists(self.output_path):
+             print("Loading text files")
+             with open(f"{self.input_path}/train.txt", 'r') as f:
+                 train_sequences = [line.strip() for line in f if line.strip()]
+             with open(f"{self.input_path}/val.txt", 'r') as f:
+                 val_sequences = [line.strip() for line in f if line.strip()]
+
+             print("Processing training data")
+             train_dataset = SequenceDataset(train_sequences,
+                                             self.tokenizer,
+                                             self.max_sequence_length)
+             print("Processing validation data")
+             val_dataset = SequenceDataset(val_sequences,
+                                           self.tokenizer,
+                                           self.max_sequence_length)
+
+             processed_train = train_dataset.process_sequences(self.batch_size)
+             processed_val = val_dataset.process_sequences(self.batch_size)
+
+             print("Combining datasets")
+             combined_dataset = DatasetDict({
+                 'train': processed_train,
+                 'val': processed_val,
+             })
+
+             print(f"Saving dataset to {self.output_path}")
+             combined_dataset.save_to_disk(self.output_path)
+
+     def setup(self, stage: str):
+         print("Loading processed dataset")
+         dataset = load_from_disk(self.output_path)
+         self.train_dataset = DynamicBatchingDataset(dataset['train'])
+         self.val_dataset = DynamicBatchingDataset(dataset['val'])
+
+     def train_dataloader(self):
+         print("Creating training dataloader")
+         # batch_size=1 because each dataset row is already a full pre-padded batch
+         return DataLoader(self.train_dataset,
+                           batch_size=1,
+                           shuffle=False,
+                           num_workers=self.num_workers,
+                           collate_fn=DynamicBatchingDataset.collate_fn,
+                           pin_memory=True)
+
+     def val_dataloader(self):
+         print("Creating validation dataloader")
+         return DataLoader(self.val_dataset,
+                           batch_size=1,
+                           shuffle=False,
+                           num_workers=self.num_workers,
+                           collate_fn=DynamicBatchingDataset.collate_fn,
+                           pin_memory=True)
+
+
+ if __name__ == '__main__':
+     tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
+                                      '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
+     dm = PretrainSequenceDataModule(
+         tokenizer=tokenizer,
+         input_dataset_path='/home/yz927/projects/peptune/tokens/11M_smiles',
+         output_dataset_path='/home/yz927/projects/peptune/tokenized/11M_smiles_old_tokenizer_no_limit',
+         num_workers=8,
+         batch_size=2000,
+         max_sequence_length=16 * 1000,
+     )
+     dm.prepare_data()
+     dm.setup('fit')
+     dm.train_dataloader()
+     dm.val_dataloader()
data_preprocessing/data_split.py ADDED
@@ -0,0 +1,101 @@
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+ from rdkit import DataStructs
+ import numpy as np
+ from sklearn.cluster import MiniBatchKMeans
+ from collections import defaultdict
+ from tqdm import tqdm
+ import selfies as sf
+ from multiprocessing import Pool, cpu_count
+ from functools import partial
+
+
+ def generate_fingerprint_batch_selfies(selfies_batch):
+     fps = []
+     valid_selfies = []
+
+     for selfies in tqdm(selfies_batch, desc="Generating fingerprints", leave=False):
+         try:
+             # Convert SELFIES to SMILES, then to an RDKit molecule
+             smiles = sf.decoder(selfies)
+             mol = Chem.MolFromSmiles(smiles)
+             if mol is not None:
+                 fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, 2048)
+                 arr = np.zeros((1,))
+                 DataStructs.ConvertToNumpyArray(fp, arr)
+                 fps.append(arr)
+                 valid_selfies.append(selfies)
+         except Exception:
+             # Skip sequences that fail to decode or parse
+             continue
+
+     return np.array(fps), valid_selfies
+
+
+ def process_batch(batch, n_clusters, seed):
+     fps, valid_selfies = generate_fingerprint_batch_selfies(batch)
+     if len(fps) > 0:
+         clusterer = MiniBatchKMeans(n_clusters=n_clusters, random_state=seed)
+         clusterer.fit(fps)
+         labels = clusterer.predict(fps)
+         return list(zip(labels, valid_selfies))
+     return []
+
+ def parallel_clustering_split_selfies(selfies_list, batch_size=10000, n_clusters=1000, train_ratio=0.9, seed=42):
+     np.random.seed(seed)
+
+     # Create batches
+     batches = [selfies_list[i:i + batch_size]
+                for i in range(0, len(selfies_list), batch_size)]
+
+     # Initialize parallel processing
+     n_cores = 12
+     process_batch_partial = partial(process_batch, n_clusters=n_clusters, seed=seed)
+
+     cluster_assignments = defaultdict(list)
+     with Pool(n_cores) as pool:
+         results = list(tqdm(
+             pool.imap(process_batch_partial, batches),
+             total=len(batches),
+             desc="Processing batches"
+         ))
+
+     # Combine results
+     for batch_results in results:
+         for label, selfies in batch_results:
+             cluster_assignments[label].append(selfies)
+
+     # Split whole clusters into train/val so that structurally similar molecules stay on the same side
+     clusters = list(cluster_assignments.values())
+     np.random.shuffle(clusters)
+
+     train_selfies = []
+     val_selfies = []
+     total_mols = sum(len(cluster) for cluster in clusters)
+
+     for cluster in tqdm(clusters, desc="Splitting clusters"):
+         if len(train_selfies) / total_mols < train_ratio:
+             train_selfies.extend(cluster)
+         else:
+             val_selfies.extend(cluster)
+
+     print(f"Final splits: Train={len(train_selfies)}, Validation={len(val_selfies)}")
+     return train_selfies, val_selfies
+
+
+ try:
+     with open('/home/yz927/projects/peptune/tokens/filtered_peptides_selfies.txt', 'r') as f:
+         selfies_list = [line.strip() for line in f if line.strip()]
+     print(f"Loaded {len(selfies_list)} selfies sequences from file")
+ except FileNotFoundError:
+     raise FileNotFoundError("Could not find the input SELFIES file at /home/yz927/projects/peptune/tokens/filtered_peptides_selfies.txt")
+ except Exception as e:
+     raise Exception(f"Error reading file: {str(e)}")
+
+ train_selfies, val_selfies = parallel_clustering_split_selfies(
+     selfies_list,
+     batch_size=10000,
+     n_clusters=1000,
+     train_ratio=0.8
+ )
+
+ with open('/home/yz927/projects/peptune/tokens/11M_selfies/train_selfies.txt', 'w') as f:
+     for line in train_selfies:
+         f.write(f"{line}\n")
+ with open('/home/yz927/projects/peptune/tokens/11M_selfies/val_selfies.txt', 'w') as f:
+     for line in val_selfies:
+         f.write(f"{line}\n")
mcts.png ADDED
Git LFS Details
  • SHA256: e63bdc835269660e4b7bda69973bd60611b61045f25c5c07a9baa277e31d2acd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
mdlm.png ADDED
peptune.png ADDED