File size: 6,735 Bytes
8019be0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import random
import safe as sf
import datamol as dm
from contextlib import suppress
from rdkit import Chem, RDLogger
RDLogger.DisableLog('rdApp.*')

# https://github.com/datamol-io/safe/blob/main/safe/sample.py
# https://github.com/jensengroup/GB_GA/blob/master/crossover.py
def safe_to_smiles(safe_str, fix=True):
    """Decode a SAFE string into a canonical SMILES string.

    Args:
        safe_str: SAFE-encoded molecule; fragments are separated by '.'.
        fix: When True, first drop every '.'-separated fragment that
            ``sf.decode`` cannot decode on its own, then decode the rest.

    Returns:
        Canonical SMILES, or whatever ``sf.decode(..., ignore_errors=True)``
        yields on failure.
    """
    candidate = safe_str
    if fix:
        kept = []
        for fragment in safe_str.split('.'):
            if sf.decode(fragment, ignore_errors=True) is None:
                continue  # undecodable fragment: discard it
            kept.append(fragment)
        candidate = '.'.join(kept)
    return sf.decode(candidate, canonical=True, ignore_errors=True)


def _safe_to_smiles_worker(args):
    """Convert one SAFE string to SMILES, returning None on any failure.

    Args:
        args: ``(safe_str, use_bracket_safe, fix)`` packed into one tuple so the
            function can be handed directly to ``Executor.map``.
            - safe_str: the SAFE (or bracket-SAFE) string.
            - use_bracket_safe: convert from bracket-SAFE to SAFE first.
            - fix: forwarded to ``safe_to_smiles``.

    Returns:
        The SMILES string produced by ``safe_to_smiles``, or None if any step
        raised.
    """
    safe_str, use_bracket_safe, fix = args
    try:
        if use_bracket_safe:
            # Imported lazily and only when required, so plain SAFE conversion
            # does not depend on mol_utils being importable.
            from mol_utils.bracket_safe_converter import bracketsafe2safe
            safe_str = bracketsafe2safe(safe_str)
        return safe_to_smiles(safe_str, fix=fix)
    except Exception:
        # Best-effort by design: a sample that fails to convert is reported as
        # invalid (None) rather than aborting the whole batch.
        return None


def batch_safe_to_smiles(safe_strings, use_bracket_safe=False, fix=True, num_workers=None):
    """
    Convert a batch of SAFE strings to SMILES, using a thread pool for larger batches.

    Every string is converted by ``_safe_to_smiles_worker``, so small and large
    batches behave identically: a string that fails to convert yields ``None``
    instead of raising.

    Args:
        safe_strings: Sized sequence of SAFE format strings.
        use_bracket_safe: Whether to convert from bracket SAFE format first.
        fix: Whether to drop undecodable fragments before decoding.
        num_workers: Number of worker threads
            (default: min(os.cpu_count() or 4, len(safe_strings), 8)).

    Returns:
        List of SMILES strings in input order (None for invalid molecules).
    """
    from concurrent.futures import ThreadPoolExecutor
    import os

    n = len(safe_strings)
    if n == 0:
        return []

    args_list = [(s, use_bracket_safe, fix) for s in safe_strings]

    # For tiny batches the pool start-up cost outweighs any gain; still go
    # through the worker so error handling matches the threaded path.
    if n <= 4:
        return [_safe_to_smiles_worker(a) for a in args_list]

    if num_workers is None:
        num_workers = min(os.cpu_count() or 4, n, 8)

    # Threads rather than processes: no pickling of inputs/outputs and a low
    # start-up cost. NOTE(review): any parallel speed-up depends on the
    # underlying decoding releasing the GIL — confirm by profiling.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        return list(executor.map(_safe_to_smiles_worker, args_list))


def batch_validate_and_extract(smiles_list, samples_tensor, log_rnd_tensor):
    """
    Keep the truthy SMILES of a batch, each reduced to its largest fragment.

    Fragment size is measured by SMILES string length. When several fragments
    share the maximal length, the last one in the string is chosen.

    Args:
        smiles_list: List of SMILES strings (None or '' marks an invalid sample).
        samples_tensor: Accepted for interface compatibility; not read here.
        log_rnd_tensor: Accepted for interface compatibility; not read here.

    Returns:
        valid_sequences: Largest fragment of each valid SMILES.
        valid_indices: Positions in ``smiles_list`` of those valid SMILES.
    """
    kept = [(position, smi) for position, smi in enumerate(smiles_list) if smi]
    valid_indices = [position for position, _ in kept]
    # max() returns the first maximal item; scanning the fragments in reverse
    # therefore selects the last longest fragment.
    valid_sequences = [max(reversed(smi.split('.')), key=len) for _, smi in kept]
    return valid_sequences, valid_indices


def filter_by_substructure(sequences, substruct):
    """Keep only the sequences that satisfy a substructure constraint.

    The query is built by standardizing attachment points, removing dummy
    atoms ('*'), and round-tripping through SMILES back to a SMARTS query.

    Args:
        sequences: Molecules accepted by
            ``sf.utils.filter_by_substructure_constraints``.
        substruct: SMARTS/SMILES of the required substructure, possibly with
            attachment points.

    Returns:
        The result of ``sf.utils.filter_by_substructure_constraints``.
    """
    standardized = sf.utils.standardize_attach(substruct)
    query_mol = Chem.MolFromSmarts(standardized)
    dummy_atom = Chem.MolFromSmiles('*')
    core = Chem.DeleteSubstructs(query_mol, dummy_atom)
    pattern = Chem.MolFromSmarts(Chem.MolToSmiles(core))
    return sf.utils.filter_by_substructure_constraints(sequences, pattern)


def _slice_linkers(slicer, sequences, query):
    """Return element [1] of ``slicer(mol, query)`` for every sequence that slices cleanly."""
    linkers = []
    for seq in sequences:
        # Best-effort: sequences that fail to parse or slice are skipped.
        with suppress(Exception):
            linkers.append(slicer(dm.to_mol(seq), query)[1])
    return linkers


def mix_sequences(prefix_sequences, suffix_sequences, prefix, suffix, num_samples=1):
    """Re-link linkers harvested from two molecule sets between ``prefix`` and ``suffix``.

    Each sequence is sliced with a ``sf.utils.MolSlicer`` using the SMARTS query
    built from ``prefix`` (for ``prefix_sequences``) or ``suffix`` (for
    ``suffix_sequences``). Element [1] of each slicing result is used as a
    linker and joined with ``mol_linker_slicer.link_fragments(linker, prefix, suffix)``.

    Args:
        prefix_sequences: Molecules (anything ``dm.to_mol`` accepts) sliced with ``prefix``.
        suffix_sequences: Molecules sliced with ``suffix``.
        prefix: SMARTS/SMILES of the head fragment.
        suffix: SMARTS/SMILES of the tail fragment.
        num_samples: Maximum number of linked results to return.

    Returns:
        List of at most ``num_samples`` truthy results of ``link_fragments``.
        Only the first ``num_samples + 2`` linkers are tried.
    """
    mol_linker_slicer = sf.utils.MolSlicer(require_ring_system=False)
    prefix_query = dm.from_smarts(prefix)
    suffix_query = dm.from_smarts(suffix)

    linkers = (_slice_linkers(mol_linker_slicer, prefix_sequences, prefix_query)
               + _slice_linkers(mol_linker_slicer, suffix_sequences, suffix_query))
    linkers = [x for x in linkers if x is not None]

    linked = []
    for n_linked, linker in enumerate(linkers):
        # Filter failed (falsy) links as they are added. Previously the filter
        # ran after the early break, so failures from the last linker leaked
        # into the returned list.
        linked.extend(x for x in mol_linker_slicer.link_fragments(linker, prefix, suffix) if x)
        if n_linked > num_samples:
            break
    return linked[:num_samples]
    

def cut(smiles, num_cuts=3):
    """
    Randomly cut non-ring single bonds of a molecule and collect the fragments.

    Each attempt picks one random match of the SMARTS ``[*]-;!@[*]`` (a single
    bond not in a ring) and breaks it with ``Chem.FragmentOnBonds``, adding
    dummy atoms. It then adds the SMILES of the resulting fragments to the
    output set.

    Args:
        smiles: SMILES string of the molecule to cut.
        num_cuts: Number of random cut attempts (default 3, the previous
            hard-coded value).

    Returns:
        Set of fragment SMILES strings. It is empty when ``smiles`` cannot be
        parsed or the molecule has no non-ring single bond.
    """
    # Compiled once per call instead of twice per attempt.
    nonring_single_bond = Chem.MolFromSmarts('[*]-;!@[*]')

    def cut_nonring(mol):
        matches = mol.GetSubstructMatches(nonring_single_bond)
        if not matches:
            return None

        begin_idx, end_idx = random.choice(matches)
        bond_ids = [mol.GetBondBetweenAtoms(begin_idx, end_idx).GetIdx()]
        fragments_mol = Chem.FragmentOnBonds(mol, bond_ids, addDummies=True, dummyLabels=[(1, 1)])

        try:
            return Chem.GetMolFrags(fragments_mol, asMols=True, sanitizeFrags=True)
        except ValueError:
            # Fragments that fail sanitization are discarded for this attempt.
            return None

    mol = Chem.MolFromSmiles(smiles)
    frags = set()
    if mol is None:
        # Unparseable SMILES: previously crashed with AttributeError.
        return frags

    for _ in range(num_cuts):
        frags_nonring = cut_nonring(mol)
        if frags_nonring is not None:
            frags |= {Chem.MolToSmiles(f) for f in frags_nonring}
    return frags


class Slicer:
    """Enumerate the non-ring single bonds of a molecule.

    Calling an instance returns a generator over the matches of the SMARTS
    ``[*]-;!@[*]`` (a single bond not in a ring), as returned by
    ``GetSubstructMatches``, i.e. tuples of atom indices.
    """

    def __call__(self, mol):
        """Yield the atom-index matches of non-ring single bonds in *mol*.

        Args:
            mol: An RDKit molecule, or a SMILES string that is parsed first.
        """
        molecule = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
        pattern = Chem.MolFromSmarts('[*]-;!@[*]')
        yield from molecule.GetSubstructMatches(pattern)