| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import random |
| import safe as sf |
| import datamol as dm |
| from contextlib import suppress |
| from rdkit import Chem, RDLogger |
| RDLogger.DisableLog('rdApp.*') |
|
|
| |
| |
def safe_to_smiles(safe_str, fix=True):
    """Decode a SAFE string into a canonical SMILES string.

    Args:
        safe_str: SAFE string; fragments are separated by '.'.
        fix: When true, every '.'-separated fragment is first decoded on its
            own and fragments whose decode result is None are dropped before
            the whole string is decoded.

    Returns:
        Whatever ``sf.decode(..., canonical=True, ignore_errors=True)``
        returns for the (possibly filtered) string.
    """
    if fix:
        decodable = []
        for fragment in safe_str.split('.'):
            # A None result is treated as "this fragment cannot be decoded".
            if sf.decode(fragment, ignore_errors=True) is None:
                continue
            decodable.append(fragment)
        safe_str = '.'.join(decodable)
    return sf.decode(safe_str, canonical=True, ignore_errors=True)
|
|
|
|
def _safe_to_smiles_worker(args):
    """Convert a single SAFE string to SMILES; used as the executor.map callable.

    Args:
        args: A ``(safe_str, use_bracket_safe, fix)`` tuple. It is packed into
            one argument so the function can be mapped over a list of jobs.

    Returns:
        The result of ``safe_to_smiles`` for the string, or None if the
        bracket conversion or the decoding raised.
    """
    safe_str, use_bracket_safe, fix = args
    try:
        if use_bracket_safe:
            # Import only when the converter is actually needed, so that
            # plain SAFE inputs never fail because of this optional import.
            from mol_utils.bracket_safe_converter import bracketsafe2safe
            safe_str = bracketsafe2safe(safe_str)
        return safe_to_smiles(safe_str, fix=fix)
    except Exception:
        # Best-effort conversion: a failing item must not abort the batch.
        return None
|
|
|
|
def batch_safe_to_smiles(safe_strings, use_bracket_safe=False, fix=True, num_workers=None):
    """
    Convert a batch of SAFE strings to SMILES, using a thread pool for larger batches.

    Batches of at most four strings are converted sequentially in the calling
    thread; larger batches are mapped over a ThreadPoolExecutor. Both paths go
    through ``_safe_to_smiles_worker``, so a string whose conversion raises
    yields None in either case.

    Args:
        safe_strings: List of SAFE format strings
        use_bracket_safe: Whether to convert from bracket SAFE format first
        fix: Whether to fix invalid fragments
        num_workers: Number of worker threads (default: min(cpu_count, len(safe_strings), 8));
            values below 1 are raised to 1

    Returns:
        List of SMILES strings in input order (None for strings that failed to convert)
    """
    n = len(safe_strings)
    if n == 0:
        return []

    args_list = [(s, use_bracket_safe, fix) for s in safe_strings]

    # Tiny batches: skip the pool entirely and convert in the calling thread.
    if n <= 4:
        return [_safe_to_smiles_worker(job) for job in args_list]

    from concurrent.futures import ThreadPoolExecutor
    import os

    if num_workers is None:
        num_workers = min(os.cpu_count() or 4, n, 8)
    else:
        # ThreadPoolExecutor rejects max_workers <= 0 with ValueError.
        num_workers = max(1, num_workers)

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # executor.map yields results in the order of args_list.
        return list(executor.map(_safe_to_smiles_worker, args_list))
|
|
|
|
def batch_validate_and_extract(smiles_list, samples_tensor, log_rnd_tensor):
    """
    Keep the non-empty SMILES entries and reduce each to its longest fragment.

    Args:
        smiles_list: List of SMILES strings; None and empty strings are skipped
        samples_tensor: Not used by this function (kept for the call signature)
        log_rnd_tensor: Not used by this function (kept for the call signature)

    Returns:
        valid_sequences: For each kept entry, the longest '.'-separated
            fragment by character count (the last one when several tie)
        valid_indices: Positions of the kept entries within smiles_list
    """
    kept = []
    for position, candidate in enumerate(smiles_list):
        if not candidate:
            continue
        fragments = candidate.split('.')
        # Stable sort by length: among equally long fragments the one that
        # appears last in the string ends up at the end of the list.
        fragments.sort(key=len)
        kept.append((fragments[-1], position))

    valid_sequences = [fragment for fragment, _ in kept]
    valid_indices = [position for _, position in kept]
    return valid_sequences, valid_indices
|
|
|
|
def filter_by_substructure(sequences, substruct):
    """Filter sequences with a substructure query built from ``substruct``.

    The query is prepared in three steps: ``substruct`` is passed through
    ``sf.utils.standardize_attach``, parsed as SMARTS, and stripped of every
    match of the '*' molecule; the stripped molecule is then written out as
    SMILES and re-parsed as SMARTS.

    Args:
        sequences: Candidates handed unchanged to
            ``sf.utils.filter_by_substructure_constraints``.
        substruct: Substructure specification string.

    Returns:
        The result of ``sf.utils.filter_by_substructure_constraints``.
    """
    standardized = sf.utils.standardize_attach(substruct)
    query = Chem.MolFromSmarts(standardized)
    wildcard = Chem.MolFromSmiles('*')
    stripped = Chem.DeleteSubstructs(query, wildcard)
    pattern = Chem.MolFromSmarts(Chem.MolToSmiles(stripped))
    return sf.utils.filter_by_substructure_constraints(sequences, pattern)
|
|
|
|
def mix_sequences(prefix_sequences, suffix_sequences, prefix, suffix, num_samples=1):
    """Build molecules that join ``prefix`` and ``suffix`` through extracted linkers.

    Each generated sequence is sliced against its query (prefix sequences
    against ``prefix``, suffix sequences against ``suffix``) and the second
    element of the slicer output is collected as a linker candidate
    (NOTE(review): presumably the linker piece; confirm against
    ``sf.utils.MolSlicer``). The candidates are then passed to
    ``MolSlicer.link_fragments`` together with ``prefix`` and ``suffix``.

    Args:
        prefix_sequences: Sequences generated from the prefix fragment.
        suffix_sequences: Sequences generated from the suffix fragment.
        prefix: Prefix fragment string (also parsed as a SMARTS query).
        suffix: Suffix fragment string (also parsed as a SMARTS query).
        num_samples: Maximum number of linked results to return.

    Returns:
        A list of at most ``num_samples`` truthy results of ``link_fragments``.
    """
    mol_linker_slicer = sf.utils.MolSlicer(require_ring_system=False)

    def _extract_linkers(sequences, query):
        # Slice every sequence against `query`; a sequence that fails at any
        # step (parsing, slicing, indexing) is skipped.
        extracted = []
        for sequence in sequences:
            with suppress(Exception):
                mol = dm.to_mol(sequence)
                sliced = mol_linker_slicer(mol, query)
                extracted.append(sliced[1])
        return extracted

    prefix_query = dm.from_smarts(prefix)
    suffix_query = dm.from_smarts(suffix)

    candidates = (_extract_linkers(prefix_sequences, prefix_query)
                  + _extract_linkers(suffix_sequences, suffix_query))
    linkers = [linker for linker in candidates if linker is not None]

    linked = []
    for position, linker in enumerate(linkers):
        linked.extend(mol_linker_slicer.link_fragments(linker, prefix, suffix))
        # Stops once the linker index exceeds num_samples, i.e. at most
        # num_samples + 2 linkers are processed.
        if position > num_samples:
            break

    # Filter after the loop so the early exit above cannot let falsy
    # entries through to the caller.
    linked = [molecule for molecule in linked if molecule]
    return linked[:num_samples]
| |
|
|
def cut(smiles):
    """Randomly fragment a molecule at acyclic single bonds.

    Up to three attempts are made; each picks one random bond matching
    '[*]-;!@[*]' (a single bond that is not in a ring), breaks the molecule
    there with dummy atoms labelled 1 at the cut ends, and adds the SMILES of
    the resulting pieces to the result.

    Args:
        smiles: SMILES string of the molecule to fragment.

    Returns:
        A set of fragment SMILES strings. The set is empty when the SMILES
        cannot be parsed, when the molecule has no matching bond, or when
        every attempt failed sanitization.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparseable SMILES: nothing to fragment.
        return set()

    # The molecule does not change between attempts, so compile the pattern
    # and collect its matches once.
    pattern = Chem.MolFromSmarts('[*]-;!@[*]')
    matches = mol.GetSubstructMatches(pattern)
    if not matches:
        return set()

    frags = set()
    for _ in range(3):
        begin_idx, end_idx = random.choice(matches)
        bond_idx = mol.GetBondBetweenAtoms(begin_idx, end_idx).GetIdx()
        fragmented = Chem.FragmentOnBonds(
            mol, [bond_idx], addDummies=True, dummyLabels=[(1, 1)]
        )
        try:
            pieces = Chem.GetMolFrags(fragmented, asMols=True, sanitizeFrags=True)
        except ValueError:
            # Pieces that fail sanitization are discarded for this attempt.
            continue
        frags.update(Chem.MolToSmiles(piece) for piece in pieces)
    return frags
|
|
|
|
class Slicer:
    """Callable that enumerates the acyclic single bonds of a molecule."""

    def __call__(self, mol):
        """Yield each match of '[*]-;!@[*]' (a single bond not in a ring).

        Args:
            mol: An RDKit molecule, or a SMILES string that is parsed first.

        Yields:
            One substructure match per matching bond, as returned by
            ``GetSubstructMatches``.
        """
        molecule = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
        pattern = Chem.MolFromSmarts('[*]-;!@[*]')
        yield from molecule.GetSubstructMatches(pattern)