File size: 6,735 Bytes
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import safe as sf
import datamol as dm
from contextlib import suppress
from rdkit import Chem, RDLogger
RDLogger.DisableLog('rdApp.*')
# https://github.com/datamol-io/safe/blob/main/safe/sample.py
# https://github.com/jensengroup/GB_GA/blob/master/crossover.py
def safe_to_smiles(safe_str, fix=True):
    """Decode a SAFE string into a canonical SMILES string.

    Args:
        safe_str: SAFE string whose fragments are separated by ``'.'``.
        fix: When true, every ``'.'``-separated fragment is test-decoded on
            its own first and fragments for which ``sf.decode`` returns
            ``None`` are dropped before the final decode.

    Returns:
        Whatever ``sf.decode(..., canonical=True, ignore_errors=True)``
        returns for the (possibly filtered) string -- presumably ``None``
        when decoding fails; confirm against the ``safe`` documentation.
    """
    if not fix:
        return sf.decode(safe_str, canonical=True, ignore_errors=True)

    decodable = []
    for fragment in safe_str.split('.'):
        # Keep only fragments that decode on their own.
        if sf.decode(fragment, ignore_errors=True) is not None:
            decodable.append(fragment)
    return sf.decode('.'.join(decodable), canonical=True, ignore_errors=True)
def _safe_to_smiles_worker(args):
    """Convert one SAFE string to SMILES, never raising on conversion errors.

    Takes a single tuple so it can be used directly with ``executor.map``.

    Args:
        args: Tuple ``(safe_str, use_bracket_safe, fix)``:
            safe_str: The SAFE (or bracket-SAFE) string to convert.
            use_bracket_safe: If true, ``safe_str`` is first passed through
                ``bracketsafe2safe``.
            fix: Forwarded to ``safe_to_smiles``.

    Returns:
        The result of ``safe_to_smiles``, or ``None`` if any exception is
        raised during conversion.
    """
    safe_str, use_bracket_safe, fix = args
    try:
        if use_bracket_safe:
            # Import only when the conversion is requested, so plain SAFE
            # input does not depend on the bracket-SAFE converter module
            # being importable.
            from mol_utils.bracket_safe_converter import bracketsafe2safe
            safe_str = bracketsafe2safe(safe_str)
        return safe_to_smiles(safe_str, fix=fix)
    except Exception:
        # Best-effort by design: a failed conversion is reported as None.
        return None
def batch_safe_to_smiles(safe_strings, use_bracket_safe=False, fix=True, num_workers=None):
    """
    Convert a batch of SAFE strings to SMILES, using a thread pool for larger batches.

    Every string is converted through ``_safe_to_smiles_worker``, so a string
    that fails to convert yields ``None`` regardless of batch size.

    Args:
        safe_strings: Sized collection of SAFE format strings
        use_bracket_safe: Whether to convert from bracket SAFE format first
        fix: Whether to fix invalid fragments
        num_workers: Number of worker threads (default: min(cpu_count, len(safe_strings), 8));
            values below 1 are treated as 1

    Returns:
        List of SMILES strings (None for invalid molecules), in input order
    """
    from concurrent.futures import ThreadPoolExecutor
    import os

    n = len(safe_strings)
    if n == 0:
        return []

    args_list = [(s, use_bracket_safe, fix) for s in safe_strings]

    # For small batches, use sequential processing (pool overhead not worth it).
    # The same worker is used as in the threaded path so both paths agree on
    # error handling.
    if n <= 4:
        return [_safe_to_smiles_worker(args) for args in args_list]

    if num_workers is None:
        num_workers = min(os.cpu_count() or 4, n, 8)
    # ThreadPoolExecutor rejects max_workers <= 0; fall back to one thread.
    num_workers = max(1, num_workers)

    # Threads rather than processes: no pickling of arguments/results and a
    # lower startup cost. NOTE(review): the original rationale also states
    # that RDKit releases the GIL during computation -- not verified here.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # executor.map preserves input order.
        return list(executor.map(_safe_to_smiles_worker, args_list))
def batch_validate_and_extract(smiles_list, samples_tensor, log_rnd_tensor):
    """
    Keep the truthy entries of a SMILES batch, reduced to their longest fragment.

    Args:
        smiles_list: List of SMILES strings (falsy entries such as None or ''
            are treated as invalid and skipped)
        samples_tensor: Not used by this function
        log_rnd_tensor: Not used by this function

    Returns:
        valid_sequences: For each kept entry, the longest '.'-separated
            fragment by character count (the last one wins on ties)
        valid_indices: Positions of the kept entries in smiles_list
    """
    kept = [(position, entry) for position, entry in enumerate(smiles_list) if entry]

    valid_indices = [position for position, _ in kept]
    # max() returns the first maximal item it sees; scanning in reverse makes
    # that the last longest fragment, matching a stable sort followed by [-1].
    valid_sequences = [max(reversed(entry.split('.')), key=len) for _, entry in kept]
    return valid_sequences, valid_indices
def filter_by_substructure(sequences, substruct):
    """Filter sequences with a substructure query derived from ``substruct``.

    The query is built by standardizing ``substruct`` with
    ``sf.utils.standardize_attach``, parsing it as SMARTS, deleting every
    match of the ``'*'`` atom, and re-parsing the resulting SMILES as SMARTS.

    Args:
        sequences: Passed through to
            ``sf.utils.filter_by_substructure_constraints``.
        substruct: Substructure string; presumably may carry attachment
            points -- confirm against the ``safe`` documentation.

    Returns:
        The result of ``sf.utils.filter_by_substructure_constraints``.
    """
    standardized = sf.utils.standardize_attach(substruct)
    pattern_with_dummies = Chem.MolFromSmarts(standardized)
    dummy_atom = Chem.MolFromSmiles('*')
    stripped = Chem.DeleteSubstructs(pattern_with_dummies, dummy_atom)
    query = Chem.MolFromSmarts(Chem.MolToSmiles(stripped))
    return sf.utils.filter_by_substructure_constraints(sequences, query)
def mix_sequences(prefix_sequences, suffix_sequences, prefix, suffix, num_samples=1):
    """Extract linkers from generated sequences and join them to prefix/suffix.

    Each sequence is sliced with ``sf.utils.MolSlicer`` against the prefix
    (or suffix) query and element ``[1]`` of the slicer output is kept as a
    linker candidate. Candidates are then linked with ``prefix`` and
    ``suffix`` via ``link_fragments``.

    Args:
        prefix_sequences: Sequences to slice against ``prefix``.
        suffix_sequences: Sequences to slice against ``suffix``.
        prefix: SMARTS-parsable string for the prefix fragment.
        suffix: SMARTS-parsable string for the suffix fragment.
        num_samples: Maximum number of linked results to return.

    Returns:
        List of at most ``num_samples`` truthy linked results.
    """
    slicer = sf.utils.MolSlicer(require_ring_system=False)
    prefix_query = dm.from_smarts(prefix)
    suffix_query = dm.from_smarts(suffix)

    def _collect_linkers(sequences, query):
        # Best-effort: a sequence that fails to parse or slice is skipped.
        found = []
        for sequence in sequences:
            with suppress(Exception):
                found.append(slicer(dm.to_mol(sequence), query)[1])
        return found

    candidates = _collect_linkers(prefix_sequences, prefix_query)
    candidates += _collect_linkers(suffix_sequences, suffix_query)
    candidates = [linker for linker in candidates if linker is not None]

    linked = []
    for position, linker in enumerate(candidates):
        linked.extend(slicer.link_fragments(linker, prefix, suffix))
        # Stop once more candidates than requested samples have been
        # processed; the final slice enforces the actual limit.
        if position > num_samples:
            break

    return [result for result in linked if result][:num_samples]
def cut(smiles):
    """Randomly fragment a molecule at acyclic single bonds.

    Three independent attempts are made. Each attempt picks one match of the
    SMARTS ``[*]-;!@[*]`` (single bond not in a ring) at random, breaks that
    bond with ``Chem.FragmentOnBonds`` (dummy atoms labelled 1 are added at
    the cut ends) and collects the SMILES of the resulting fragments.

    Args:
        smiles: SMILES string of the molecule to fragment.

    Returns:
        Set of fragment SMILES strings. Empty when ``smiles`` cannot be
        parsed, when the molecule has no acyclic single bond, or when every
        attempt failed fragment sanitization.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        return set()

    # Compile the pattern once instead of on every attempt.
    nonring_single_bond = Chem.MolFromSmarts('[*]-;!@[*]')

    def cut_nonring(mol):
        matches = mol.GetSubstructMatches(nonring_single_bond)
        if not matches:
            return None

        bis = random.choice(matches)  # single bond not in ring
        bs = [mol.GetBondBetweenAtoms(bis[0], bis[1]).GetIdx()]
        fragments_mol = Chem.FragmentOnBonds(mol, bs, addDummies=True, dummyLabels=[(1, 1)])

        try:
            return Chem.GetMolFrags(fragments_mol, asMols=True, sanitizeFrags=True)
        except ValueError:
            # Fragment sanitization failed; treat this attempt as unusable.
            return None

    frags = set()
    # non-ring cut
    for _ in range(3):
        frags_nonring = cut_nonring(mol)
        if frags_nonring is not None:
            frags.update(Chem.MolToSmiles(f) for f in frags_nonring)
    return frags
class Slicer:
    """Callable that enumerates the non-ring single bonds of a molecule."""

    def __call__(self, mol):
        """Yield one atom-index match per non-ring single bond.

        Args:
            mol: An RDKit molecule, or a SMILES string that is parsed with
                ``Chem.MolFromSmiles`` first.

        Yields:
            Each match of the SMARTS ``[*]-;!@[*]`` as returned by
            ``GetSubstructMatches`` (a tuple of the two matched atom indices).
        """
        if isinstance(mol, str):
            mol = Chem.MolFromSmiles(mol)

        # non-ring single bonds
        yield from mol.GetSubstructMatches(Chem.MolFromSmarts('[*]-;!@[*]'))