| | from rdkit import Chem |
| | from rdkit.Chem import AllChem |
| | from rdkit import DataStructs |
| | import numpy as np |
| | from sklearn.cluster import MiniBatchKMeans |
| | from collections import defaultdict |
| | from tqdm import tqdm |
| | import selfies as sf |
| | from multiprocessing import Pool, cpu_count |
| | from functools import partial |
def generate_fingerprint_batch_selfies(selfies_batch):
    """Decode a batch of SELFIES strings and compute Morgan fingerprints.

    Parameters
    ----------
    selfies_batch : list[str]
        SELFIES strings to decode.

    Returns
    -------
    tuple[np.ndarray, list[str]]
        Fingerprint matrix of shape (n_valid, 2048) and the SELFIES strings
        that decoded to a valid molecule, in matching order. SELFIES that
        fail to decode or parse are silently skipped.
    """
    fps = []
    valid_selfies = []

    for selfies in tqdm(selfies_batch, desc="Generating fingerprints", leave=False):
        try:
            smiles = sf.decoder(selfies)
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                # Morgan fingerprint: radius 2, 2048 bits (ECFP4-like).
                fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, 2048)
                # ConvertToNumpyArray resizes arr to the fingerprint length,
                # so the initial (1,) shape is just a placeholder.
                arr = np.zeros((1,))
                DataStructs.ConvertToNumpyArray(fp, arr)
                fps.append(arr)
                valid_selfies.append(selfies)
        except Exception:
            # Bug fix: the original bare `except:` also swallowed
            # KeyboardInterrupt/SystemExit; skip only ordinary errors from
            # decoding/parsing a malformed SELFIES.
            continue

    return np.array(fps), valid_selfies
| |
|
def process_batch(batch, n_clusters, seed):
    """Fingerprint one batch of SELFIES and cluster it with MiniBatchKMeans.

    Parameters
    ----------
    batch : list[str]
        SELFIES strings to fingerprint and cluster.
    n_clusters : int
        Requested number of clusters (clamped to the number of valid molecules).
    seed : int
        Random state for reproducible clustering.

    Returns
    -------
    list[tuple[int, str]]
        (cluster_label, selfies) pairs; empty list if nothing in the batch
        decoded to a valid molecule.
    """
    fps, valid_selfies = generate_fingerprint_batch_selfies(batch)
    if len(fps) == 0:
        return []
    # Bug fix: MiniBatchKMeans raises ValueError when n_clusters > n_samples,
    # which can happen for a short final batch or a mostly-invalid one —
    # clamp the cluster count to the sample count.
    k = min(n_clusters, len(fps))
    clusterer = MiniBatchKMeans(n_clusters=k, random_state=seed)
    # fit_predict replaces the redundant fit() + predict() double pass.
    labels = clusterer.fit_predict(fps)
    return list(zip(labels, valid_selfies))
| |
|
def parallel_clustering_split_selfies(selfies_list, batch_size=10000, n_clusters=1000,
                                      train_ratio=0.9, seed=42, n_cores=12):
    """Split SELFIES into train/validation sets by per-batch clustering.

    Batches of `batch_size` SELFIES are clustered independently in worker
    processes; whole clusters are then assigned to the train split until it
    holds at least `train_ratio` of all valid molecules, and the remaining
    clusters go to validation. Keeping clusters intact reduces near-duplicate
    leakage between the splits.

    Parameters
    ----------
    selfies_list : list[str]
        All SELFIES strings to split.
    batch_size : int
        Number of SELFIES per worker batch.
    n_clusters : int
        Clusters requested per batch (clamped inside process_batch).
    train_ratio : float
        Target fraction of molecules in the train split.
    seed : int
        Seed for shuffling and clustering.
    n_cores : int
        Maximum worker processes (capped at the machine's CPU count).

    Returns
    -------
    tuple[list[str], list[str]]
        (train_selfies, val_selfies).
    """
    np.random.seed(seed)

    # Chunk the input so each worker clusters an independent batch.
    batches = [selfies_list[i:i + batch_size]
               for i in range(0, len(selfies_list), batch_size)]

    # Fix: worker count was hard-coded to 12 while cpu_count was imported but
    # unused — keep 12 as the default but never exceed the available CPUs.
    workers = max(1, min(n_cores, cpu_count()))
    process_batch_partial = partial(process_batch, n_clusters=n_clusters, seed=seed)

    cluster_assignments = defaultdict(list)
    with Pool(workers) as pool:
        results = list(tqdm(
            pool.imap(process_batch_partial, batches),
            total=len(batches),
            desc="Processing batches"
        ))

    # Group SELFIES by their (per-batch) cluster label.
    for batch_results in results:
        for label, selfies in batch_results:
            cluster_assignments[label].append(selfies)

    # Shuffle clusters so the greedy assignment below is not biased by label order.
    clusters = list(cluster_assignments.values())
    np.random.shuffle(clusters)

    train_selfies = []
    val_selfies = []
    total_mols = sum(len(cluster) for cluster in clusters)

    # Bug fix: guard against ZeroDivisionError when nothing decoded validly.
    if total_mols == 0:
        print("Final splits: Train=0, Validation=0")
        return train_selfies, val_selfies

    # Greedily fill train with whole clusters until the target ratio is met.
    for cluster in tqdm(clusters, desc="Splitting clusters"):
        if len(train_selfies) / total_mols < train_ratio:
            train_selfies.extend(cluster)
        else:
            val_selfies.extend(cluster)

    print(f"Final splits: Train={len(train_selfies)}, Validation={len(val_selfies)}")
    return train_selfies, val_selfies
| |
|
# Input corpus and output split locations.
INPUT_PATH = '/home/yz927/projects/peptune/tokens/filtered_peptides_selfies.txt'
TRAIN_PATH = '/home/yz927/projects/peptune/tokens/11M_selfies/train_selfies.txt'
VAL_PATH = '/home/yz927/projects/peptune/tokens/11M_selfies/val_selfies.txt'

# Bug fix: the __main__ guard is required because this script uses
# multiprocessing.Pool — under the spawn start method every worker re-imports
# this module, and unguarded top-level code would re-execute (or recurse).
if __name__ == "__main__":
    try:
        with open(INPUT_PATH, 'r') as f:
            # One SELFIES per line; drop blanks.
            selfies_list = [line.strip() for line in f if line.strip()]
        print(f"Loaded {len(selfies_list)} selfies sequences from file")
    except FileNotFoundError:
        # Bug fix: the original f-string had no placeholder and printed the
        # literal text "at file" — interpolate the actual path.
        raise FileNotFoundError(f"Could not find the file at {INPUT_PATH}")
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise Exception(f"Error reading file: {str(e)}") from e

    train_selfies, val_selfies = parallel_clustering_split_selfies(
        selfies_list,
        batch_size=10000,
        n_clusters=1000,
        train_ratio=0.8
    )
    with open(TRAIN_PATH, 'w') as f:
        for line in train_selfies:
            f.write(f"{line}\n")
    with open(VAL_PATH, 'w') as f:
        for line in val_selfies:
            f.write(f"{line}\n")