File size: 5,262 Bytes
12cd9ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import selfies as sf
from rdkit import Chem
import ast
import numpy as np


def __get_correspondence__(mol, epoch):
    """Write ``mol`` as SMILES and return it together with its atom output order.

    Used to carry per-atom labels (``token_regr``) from the original SMILES
    atom order over to the order of the newly written SMILES.

    Args:
        mol: RDKit molecule to serialize.
        epoch: ``0`` selects the canonical SMILES; any other value produces a
            single random SMILES via ``Chem.MolToRandomSmilesVect`` with
            ``randomSeed=epoch``, so a given epoch always yields the same
            variant.

    Returns:
        ``(smiles, order)`` where ``order`` is the list parsed from the
        ``_smilesAtomOutputOrder`` property read from ``mol`` right after the
        SMILES was written.
    """
    if epoch != 0:
        # Reproducible random atom ordering: the epoch doubles as the seed.
        smiles_out = Chem.MolToRandomSmilesVect(mol, 1, randomSeed=epoch)[0]
    else:
        smiles_out = Chem.MolToSmiles(mol, canonical=True)

    # The property holds a string representation of a list, hence literal_eval.
    raw_order = mol.GetProp('_smilesAtomOutputOrder')
    atom_order = ast.literal_eval(raw_order)

    return smiles_out, atom_order

# We already know the [Ring] token connects the token immediately before...

def get_ring_masks(mol, map_smiles_to_selfies, tokens):
    # This is fine, atoms are given indices in the molecule based on the order they appear in the SMILES

    Chem.FastFindRings(mol)
    
    rings = mol.GetRingInfo().AtomRings()
    ring_masks = []
    for i, ring in enumerate(rings):
        selfies_ring = map_smiles_to_selfies[list(ring)]
        ring_idx = selfies_ring.max()+1
        ring_masks.append((ring_idx, selfies_ring))
        assert "Ring" in tokens[ring_idx]

    return ring_masks
    

# Distances are set to 0 for the tokens in the molecules at the right and at the left of . tokens (except padding tokens)
def __get_attribution_mapping__(tokens):
    special_token_masks = []
    map_smiles_to_selfies = []
    dots = []

    idx = 1  # Start after [CLS]

    while idx < len(tokens):
        token = tokens[idx]

        if token == ".":
            dots.append(idx)
            idx += 1
            continue

        branch_idx = token.find("Branch")
        if branch_idx >= 0:
            n = int(token[branch_idx + 6])
            special_token_masks.append(np.arange(idx, idx + n + 1, dtype=np.int16))
            idx += n + 1
            continue
        else:
            ring_idx = token.find("Ring")
            if ring_idx >= 0:
                n = int(token[ring_idx + 4])
                special_token_masks.append(np.arange(idx, idx + n + 1, dtype=np.int16))
                idx += n + 1
                continue

        # Real (atom) token
        map_smiles_to_selfies.append(idx)
        idx += 1

    # Existing dot_masks construction (unchanged)
    dot_masks = []
    last_dots = [-1]
    for dot_idx in dots:
        if len(last_dots) == 2:
            val = last_dots.pop(0)
            dot_masks.append([el for el in range(val + 1, dot_idx, 1)])
        last_dots.append(dot_idx)

    if len(dots) >= 1:
        dot_masks.append([el for el in range(last_dots.pop(0) + 1, len(tokens), 1)])

    return special_token_masks, np.array(map_smiles_to_selfies), list(zip(dots, dot_masks, strict=True))
    
def __get_positional_encodings__(mol, smiles_to_selfies, context_length, special_token_masks, double_masks, first_padding_token_idx):
    """Build the ``(context_length, context_length)`` int16 positional-encoding matrix.

    Entries are written in this order, each step overwriting the previous:

    1. every entry starts at the int16 maximum, reserved for "no relation";
    2. row 0 and column 0 are set to 0 for positions before
       ``first_padding_token_idx``;
    3. each Branch/Ring group in ``special_token_masks`` becomes a square
       block of -1;
    4. for each ``(i, positions)`` in ``double_masks``, the entries between
       token ``i`` and ``positions`` become 0 in both directions;
    5. the diagonal becomes 0;
    6. atom-to-atom entries receive the values of ``Chem.GetDistanceMatrix``
       clipped to ``int16 max - 1``.

    Args:
        mol: RDKit molecule whose atom order matches ``smiles_to_selfies``.
        smiles_to_selfies: token position of each atom, in atom order.
        context_length: side length of the output matrix.
        special_token_masks: iterable of 1-D index arrays (Branch/Ring groups).
        double_masks: iterable of ``(position, positions)`` pairs.
        first_padding_token_idx: first position treated as padding for step 2.

    Returns:
        ``np.ndarray`` of dtype ``int16``.
    """
    atom_pos = np.array(smiles_to_selfies, dtype=np.int64)
    raw_dist = Chem.GetDistanceMatrix(mol)

    # The int16 maximum is kept free as the sentinel value, so real
    # distances are clipped one below it before the cast.
    sentinel = np.iinfo(np.int16).max
    atom_dist = np.minimum(raw_dist, sentinel - 1).astype(np.int16)

    enc = np.full((context_length, context_length), sentinel, dtype=np.int16)

    # Zero the first row/column, but only up to the first padding position.
    enc[0, :first_padding_token_idx] = 0
    enc[:first_padding_token_idx, 0] = 0

    for group in special_token_masks:
        enc[np.ix_(group, group)] = -1

    for anchor, members in double_masks:
        enc[anchor, members] = 0
        enc[members, anchor] = 0

    np.fill_diagonal(enc, 0)

    # Atom distances go last so they override everything above.
    enc[atom_pos[:, np.newaxis], atom_pos] = atom_dist

    return enc
    
def get_positional_encodings_and_align(smiles, token_regr, epoch):
    """Convert ``smiles`` to SELFIES and build its positional encodings and aligned labels.

    The molecule is first re-written as SMILES (canonical for ``epoch == 0``,
    a seeded random variant otherwise), so the per-atom labels are permuted
    into that new atom order before being spread over the SELFIES tokens.

    Args:
        smiles: input SMILES string (parsed with ``sanitize=False``).
        token_regr: NumPy array of per-atom labels in the atom order of
            ``smiles``, or ``None``. The caller's array is not modified.
        epoch: forwarded to ``__get_correspondence__`` to pick the SMILES
            variant.

    Returns:
        ``(selfies, pos_encod, token_regr_selfies)``:

        * ``selfies`` -- SELFIES string of the re-written SMILES;
        * ``pos_encod`` -- int16 matrix from ``__get_positional_encodings__``
          with ``context_length == len(tokens)`` (tokens include ``[CLS]``);
        * ``token_regr_selfies`` -- array of length ``len(tokens) - 1`` (one
          slot per SELFIES token, ``[CLS]`` excluded) with the same dtype as
          ``token_regr``; atom tokens get labels, all other slots are NaN.
          ``None`` when ``token_regr`` is ``None``.
    """
    orig_mol = Chem.MolFromSmiles(smiles, sanitize=False)

    # Re-write the SMILES first so the atom mapping already matches the
    # final token order. The variant is deterministic for a given epoch.
    new_smiles, mapping_to_new = __get_correspondence__(orig_mol, epoch)

    # Convert to SELFIES, simulate tokenization and prepend [CLS].
    selfies = sf.encoder(new_smiles)
    tokens = ["[CLS]"] + list(sf.split_selfies(selfies))

    special_token_masks, map_smiles_to_selfies, dot_masks = __get_attribution_mapping__(tokens)

    if token_regr is None:
        token_regr_selfies = None
    else:
        # Reorder labels into the new SMILES atom order on a copy. Writing the
        # permutation back into the caller's array would compound across
        # repeated calls with the same array (e.g. once per epoch) and
        # silently misalign the labels.
        aligned = token_regr.copy()
        aligned[:len(mapping_to_new)] = token_regr[mapping_to_new]

        token_regr_selfies = np.full(len(tokens) - 1, np.nan, dtype=token_regr.dtype)

        valid = map_smiles_to_selfies < len(tokens)
        # "- 1" because the output has no slot for [CLS] at position 0.
        token_regr_selfies[map_smiles_to_selfies[valid] - 1] = aligned[:np.sum(valid)]

    # Rebuild the molecule from the new SMILES without sanitization so its
    # atom order and structure match what was tokenized.
    mol = Chem.MolFromSmiles(new_smiles, sanitize=False)

    ring_masks = get_ring_masks(mol, map_smiles_to_selfies, tokens)
    double_masks = ring_masks + dot_masks
    pos_encod = __get_positional_encodings__(mol, map_smiles_to_selfies, len(tokens), special_token_masks, double_masks, len(tokens))

    return selfies, pos_encod, token_regr_selfies