# File: A2D2/a2d2_mol/mol_utils/utils_chem.py
# Author: Sophia — commit 8019be0 ("initial commit"), 6.74 kB
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import safe as sf
import datamol as dm
from contextlib import suppress
from rdkit import Chem, RDLogger
# Silence every RDKit log channel ('rdApp.*') for the whole process.
# NOTE(review): presumably done because many generated SMILES/SAFE strings fail
# to parse and would otherwise flood stderr with RDKit warnings — confirm.
RDLogger.DisableLog('rdApp.*')
# Reference implementations that parts of this module follow:
# https://github.com/datamol-io/safe/blob/main/safe/sample.py
# https://github.com/jensengroup/GB_GA/blob/master/crossover.py
def safe_to_smiles(safe_str, fix=True):
    """Decode a SAFE string into a canonical SMILES string.

    Args:
        safe_str: SAFE string whose fragments are separated by '.'.
        fix: If True, discard every '.'-separated fragment that ``sf.decode``
            cannot decode on its own before decoding the remainder.

    Returns:
        The canonical SMILES produced by ``sf.decode``. Decoding runs with
        ``ignore_errors=True``, so a failed decode is expected to yield None
        rather than raise.
    """
    if not fix:
        return sf.decode(safe_str, canonical=True, ignore_errors=True)
    decodable = []
    for fragment in safe_str.split('.'):
        # Keep only fragments that decode in isolation.
        if sf.decode(fragment, ignore_errors=True) is not None:
            decodable.append(fragment)
    return sf.decode('.'.join(decodable), canonical=True, ignore_errors=True)
def _safe_to_smiles_worker(args):
    """Convert one SAFE string to SMILES; the per-item worker of ``batch_safe_to_smiles``.

    Args:
        args: ``(safe_str, use_bracket_safe, fix)`` tuple, packed into one
            argument so the function can be used directly with ``Executor.map``.

    Returns:
        The SMILES returned by ``safe_to_smiles``, or None if the bracket-SAFE
        conversion or the decode raised an exception.

    Raises:
        ImportError: if ``use_bracket_safe`` is True and the bracket-SAFE
            converter module cannot be imported (a setup error, not an
            invalid molecule).
    """
    safe_str, use_bracket_safe, fix = args
    if use_bracket_safe:
        # Import only when the conversion is requested: a missing converter
        # module must not turn every plain-SAFE conversion into None.
        from mol_utils.bracket_safe_converter import bracketsafe2safe
    try:
        if use_bracket_safe:
            safe_str = bracketsafe2safe(safe_str)
        return safe_to_smiles(safe_str, fix=fix)
    except Exception:
        # Best effort: any failure while converting/decoding marks the molecule invalid.
        return None
def batch_safe_to_smiles(safe_strings, use_bracket_safe=False, fix=True, num_workers=None):
    """
    Convert a batch of SAFE strings to SMILES, using a thread pool for larger batches.

    Every string goes through ``_safe_to_smiles_worker``, so a string whose
    conversion fails yields None regardless of batch size.

    Args:
        safe_strings: Sized sequence of SAFE format strings (``len`` is taken)
        use_bracket_safe: Whether to convert from bracket SAFE format first
        fix: Whether to drop undecodable fragments before decoding
        num_workers: Number of worker threads
            (default: min(os.cpu_count() or 4, len(safe_strings), 8))

    Returns:
        List of SMILES strings in input order (None for invalid molecules)
    """
    from concurrent.futures import ThreadPoolExecutor
    import os

    n = len(safe_strings)
    if n == 0:
        return []
    args_list = [(s, use_bracket_safe, fix) for s in safe_strings]
    # For tiny batches the pool start-up cost outweighs any gain; run inline,
    # but through the same worker so failures become None exactly as in the
    # pooled path (previously they raised here).
    if n <= 4:
        return [_safe_to_smiles_worker(a) for a in args_list]
    if num_workers is None:
        num_workers = min(os.cpu_count() or 4, n, 8)
    # Threads avoid pickling and process start-up costs.
    # NOTE(review): the real parallel speed-up depends on how much of the
    # decode work releases the GIL — measure before tuning.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        return list(executor.map(_safe_to_smiles_worker, args_list))
def batch_validate_and_extract(smiles_list, samples_tensor, log_rnd_tensor):
    """
    Keep the truthy SMILES of a batch, each reduced to its longest '.'-separated fragment.

    No chemistry validation happens here: an entry is kept iff it is truthy
    (None and '' are dropped).

    Args:
        smiles_list: List of SMILES strings; None marks an invalid sample
        samples_tensor: Token-ID tensor for the batch; accepted but not read here
        log_rnd_tensor: Per-sample log values for the batch; accepted but not read here

    Returns:
        Tuple ``(valid_sequences, valid_indices)`` where
        valid_sequences: longest fragment of each kept SMILES (on a length tie,
            the later fragment wins)
        valid_indices: positions in ``smiles_list`` of the kept entries
    """
    kept = [(position, smi) for position, smi in enumerate(smiles_list) if smi]
    # max() returns the first maximum, so scan fragments right-to-left to
    # reproduce the "last longest fragment" pick of a stable sort.
    valid_sequences = [max(reversed(smi.split('.')), key=len) for _, smi in kept]
    valid_indices = [position for position, _ in kept]
    return valid_sequences, valid_indices
def filter_by_substructure(sequences, substruct):
    """Filter molecules by a substructure constraint with its attachment points removed.

    The constraint is standardized with ``sf.utils.standardize_attach``, parsed
    as SMARTS, stripped of '*' dummy atoms, and round-tripped through SMILES
    into a fresh SMARTS query for ``sf.utils.filter_by_substructure_constraints``.

    Args:
        sequences: Molecules to filter, forwarded unchanged to
            ``sf.utils.filter_by_substructure_constraints``.
        substruct: SMARTS string of the required substructure.

    Returns:
        Whatever ``sf.utils.filter_by_substructure_constraints`` returns.
    """
    pattern = Chem.MolFromSmarts(sf.utils.standardize_attach(substruct))
    stripped = Chem.DeleteSubstructs(pattern, Chem.MolFromSmiles('*'))
    query = Chem.MolFromSmarts(Chem.MolToSmiles(stripped))
    return sf.utils.filter_by_substructure_constraints(sequences, query)
def mix_sequences(prefix_sequences, suffix_sequences, prefix, suffix, num_samples=1):
    """Link ``prefix`` and ``suffix`` through linkers harvested from existing molecules.

    Each prefix sequence is sliced against the ``prefix`` query and each suffix
    sequence against the ``suffix`` query with ``sf.utils.MolSlicer``; element
    ``[1]`` of each slice result is used as a linker. Every linker is then
    joined to ``prefix`` and ``suffix`` via ``MolSlicer.link_fragments``.

    Args:
        prefix_sequences: Molecules (anything ``dm.to_mol`` accepts) to slice
            against ``prefix``.
        suffix_sequences: Molecules to slice against ``suffix``.
        prefix: SMARTS pattern of the prefix fragment; also passed to
            ``link_fragments``.
        suffix: SMARTS pattern of the suffix fragment; also passed to
            ``link_fragments``.
        num_samples: Maximum number of linked results to return.

    Returns:
        List of at most ``num_samples`` truthy results from ``link_fragments``.
    """
    slicer = sf.utils.MolSlicer(require_ring_system=False)
    prefix_query = dm.from_smarts(prefix)
    suffix_query = dm.from_smarts(suffix)

    def _harvest_linkers(sequences, query):
        # Best effort: a sequence that fails to parse or slice is skipped.
        harvested = []
        for seq in sequences:
            with suppress(Exception):
                harvested.append(slicer(dm.to_mol(seq), query)[1])
        return harvested

    linkers = [
        linker
        for linker in (
            _harvest_linkers(prefix_sequences, prefix_query)
            + _harvest_linkers(suffix_sequences, suffix_query)
        )
        if linker is not None
    ]

    linked = []
    for n_linked, linker in enumerate(linkers):
        linked.extend(slicer.link_fragments(linker, prefix, suffix))
        # n_linked is 0-based: stop after num_samples + 2 linkers to bound
        # the number of link_fragments calls.
        if n_linked > num_samples:
            break
    # Filter once after the loop so falsy results from the iteration that
    # triggered the break are dropped as well.
    return [mol for mol in linked if mol][:num_samples]
def cut(smiles):
    """Randomly cut acyclic single bonds of a molecule and collect the fragments.

    Three independent attempts are made; each picks one random acyclic single
    bond, breaks it with ``Chem.FragmentOnBonds`` (both new dummy atoms get
    label 1), and adds the SMILES of the resulting fragments to the result.

    Args:
        smiles: SMILES string of the molecule to cut.

    Returns:
        Set of fragment SMILES strings; empty if ``smiles`` cannot be parsed,
        has no acyclic single bond, or every fragmentation fails sanitization.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparsable SMILES: nothing to cut (previously crashed with AttributeError).
        return set()
    # Single bond ('-') that is not a ring bond ('!@'); compiled and searched
    # once per call instead of twice per attempt.
    nonring_single = Chem.MolFromSmarts('[*]-;!@[*]')
    matches = mol.GetSubstructMatches(nonring_single)
    if not matches:
        return set()

    frags = set()
    for _ in range(3):
        begin_idx, end_idx = random.choice(matches)
        bond_idx = mol.GetBondBetweenAtoms(begin_idx, end_idx).GetIdx()
        pieces = Chem.FragmentOnBonds(mol, [bond_idx], addDummies=True, dummyLabels=[(1, 1)])
        try:
            frag_mols = Chem.GetMolFrags(pieces, asMols=True, sanitizeFrags=True)
        except ValueError:
            # Fragments that fail sanitization are skipped for this attempt.
            continue
        frags.update(Chem.MolToSmiles(f) for f in frag_mols)
    return frags
class Slicer:
    """Callable that enumerates the acyclic single bonds of a molecule."""

    def __call__(self, mol):
        """Lazily yield the atom-index pair of every acyclic single bond in ``mol``.

        Args:
            mol: RDKit molecule, or a SMILES string that is parsed first.

        Yields:
            Atom-index tuples from ``GetSubstructMatches`` for the SMARTS
            ``[*]-;!@[*]`` (single bond, not in a ring).
        """
        molecule = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
        yield from molecule.GetSubstructMatches(Chem.MolFromSmarts('[*]-;!@[*]'))