File size: 5,262 Bytes
12cd9ef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | import selfies as sf
from rdkit import Chem
import ast
import numpy as np
def __get_correspondence__(mol, epoch):
    """Write ``mol`` as SMILES and return it together with its atom output order.

    Epoch 0 produces the canonical SMILES; any other epoch produces a single
    randomized SMILES seeded with ``epoch``, so every epoch yields a
    reproducible variation of the same molecule.

    Args:
        mol: RDKit molecule to write.
        epoch: 0 for canonical output, otherwise the random seed.

    Returns:
        ``(smiles, order)`` where ``order`` is the list parsed from the
        molecule's ``_smilesAtomOutputOrder`` property after writing.
        Presumably ``order[k]`` is the original atom index written at
        position ``k`` (this is how the caller uses it) — TODO confirm
        against the RDKit documentation.
    """
    if epoch != 0:
        smiles_variant = Chem.MolToRandomSmilesVect(mol, 1, randomSeed=epoch)[0]
    else:
        smiles_variant = Chem.MolToSmiles(mol, canonical=True)
    # The output order is stored on the molecule as a string; parse it back
    # into a Python list.
    atom_order = ast.literal_eval(mol.GetProp('_smilesAtomOutputOrder'))
    return smiles_variant, atom_order
def get_ring_masks(mol, map_smiles_to_selfies, tokens):
    """Pair every ring of ``mol`` with the SELFIES ``[Ring...]`` token closing it.

    The code assumes that the ring-closure token sits immediately after the
    ring atom that appears last in the token sequence, i.e. at the largest
    token position among the ring's atoms plus one.

    Args:
        mol: RDKit molecule whose atom indices follow the SMILES atom order
            for which ``map_smiles_to_selfies`` was built.
        map_smiles_to_selfies: integer NumPy array; entry ``a`` is the
            position in ``tokens`` of the token for atom ``a``.
        tokens: token list that the positions refer to.

    Returns:
        List of ``(ring_token_idx, ring_atom_token_idxs)`` tuples, one per ring
        reported by ``mol.GetRingInfo().AtomRings()``.

    Raises:
        AssertionError: if the token after a ring's last atom does not contain
            ``"Ring"``, i.e. the assumption above does not hold.
    """
    # Atom indices follow the order in which atoms appear in the SMILES, so
    # they index directly into ``map_smiles_to_selfies``.
    # Ring info is computed explicitly; molecules parsed with sanitize=False
    # presumably do not have it populated yet.
    Chem.FastFindRings(mol)
    ring_masks = []
    for ring in mol.GetRingInfo().AtomRings():
        selfies_ring = map_smiles_to_selfies[list(ring)]
        ring_idx = selfies_ring.max() + 1
        # Explicit raise rather than ``assert`` so the consistency check still
        # runs under ``python -O`` instead of silently producing wrong masks.
        # NOTE(review): fused/spiro systems (several closures on one atom, or
        # perceived rings that are not single SMILES closures) may violate the
        # max+1 assumption — confirm on such inputs.
        if "Ring" not in tokens[ring_idx]:
            raise AssertionError(
                f"expected a ring-closure token at position {ring_idx} after ring "
                f"atoms at token positions {selfies_ring.tolist()}, "
                f"found {tokens[ring_idx]!r}"
            )
        ring_masks.append((ring_idx, selfies_ring))
    return ring_masks
def __get_attribution_mapping__(tokens):
    """Classify SELFIES tokens and collect the position masks derived from them.

    ``tokens[0]`` is expected to be ``[CLS]`` and is skipped. Every other token
    is one of:

    * ``"."`` - fragment separator; its position is recorded;
    * a token containing ``BranchN`` or ``RingN`` - a structural token followed
      by ``N`` index tokens; the token and its ``N`` index tokens form one
      special-token mask and are skipped together;
    * anything else - an atom token; its position is appended to the mapping.

    Args:
        tokens: SELFIES tokens prefixed with ``[CLS]``.

    Returns:
        ``(special_token_masks, map_smiles_to_selfies, dot_masks)``:

        * ``special_token_masks``: list of ``int16`` arrays of token positions;
        * ``map_smiles_to_selfies``: ``int64`` array whose ``k``-th entry is the
          position of the ``k``-th atom token (always integer, even when empty,
          so it can be used as an index);
        * ``dot_masks``: list of ``(dot_position, positions)`` pairs. For each
          dot, ``positions`` runs from just after the previous dot (from 0 for
          the first dot) up to just before the next dot (up to the last token
          for the final dot). That is, it covers the fragments on both sides of
          the dot and the dot itself. Downstream these are used to set
          dot-to-fragment distances to 0.
    """
    special_token_masks = []
    atom_positions = []
    dots = []
    n_tokens = len(tokens)
    idx = 1  # Start after [CLS]
    while idx < n_tokens:
        token = tokens[idx]
        if token == ".":
            dots.append(idx)
            idx += 1
            continue
        # "Branch" is checked before "Ring"; the digit right after the keyword
        # is the number of index tokens that follow the structural token.
        for keyword in ("Branch", "Ring"):
            pos = token.find(keyword)
            if pos >= 0:
                n_index_tokens = int(token[pos + len(keyword)])
                special_token_masks.append(
                    np.arange(idx, idx + n_index_tokens + 1, dtype=np.int16)
                )
                idx += n_index_tokens + 1
                break
        else:
            # Real (atom) token.
            atom_positions.append(idx)
            idx += 1

    # Fragment boundaries: a virtual separator before position 0, every dot,
    # and the end of the sequence. Dot k spans from boundary k to boundary k+2.
    bounds = [-1, *dots, n_tokens]
    dot_masks = [
        (dot, list(range(bounds[k] + 1, bounds[k + 2])))
        for k, dot in enumerate(dots)
    ]
    return (
        special_token_masks,
        np.array(atom_positions, dtype=np.int64),
        dot_masks,
    )
def __get_positional_encodings__(mol, smiles_to_selfies, context_length, special_token_masks, double_masks, first_padding_token_idx):
    """Build the square ``int16`` token-distance matrix for one molecule.

    Entries are written in this order; later steps overwrite earlier ones:

    1. every entry starts at ``int16.max``, the value reserved for special
       (non-graph) distances;
    2. row 0 and column 0 become 0 for every position before
       ``first_padding_token_idx``;
    3. for each mask in ``special_token_masks``, all pairs inside the mask
       become ``-1``;
    4. for each ``(idx, mask)`` in ``double_masks``, the entries between
       ``idx`` and every position in ``mask`` become 0, in both directions;
    5. the diagonal becomes 0;
    6. atom-token pairs receive the graph distance from
       ``Chem.GetDistanceMatrix(mol)``, capped at ``int16.max - 1``.

    Args:
        mol: RDKit molecule whose atom ``a`` sits at token position
            ``smiles_to_selfies[a]``.
        smiles_to_selfies: token position of each atom.
        context_length: side length of the returned matrix.
        special_token_masks: 1-D integer arrays of token positions.
        double_masks: iterable of ``(idx, mask)`` pairs.
        first_padding_token_idx: positions at or beyond this index do not get
            the zero row/column from step 2.

    Returns:
        ``np.ndarray`` of shape ``(context_length, context_length)`` and dtype
        ``int16``.
    """
    reserved = np.iinfo(np.int16).max
    atom_tokens = np.array(smiles_to_selfies, dtype=np.int64)
    # Cap graph distances one below the reserved value so the two never clash.
    graph_dist = np.minimum(Chem.GetDistanceMatrix(mol), reserved - 1).astype(np.int16)

    encoding = np.full((context_length, context_length), reserved, dtype=np.int16)
    encoding[0, :first_padding_token_idx] = 0
    encoding[:first_padding_token_idx, 0] = 0
    for mask in special_token_masks:
        encoding[np.ix_(mask, mask)] = -1
    for anchor, mask in double_masks:
        encoding[anchor, mask] = 0
        encoding[mask, anchor] = 0
    np.fill_diagonal(encoding, 0)
    encoding[np.ix_(atom_tokens, atom_tokens)] = graph_dist
    return encoding
def get_positional_encodings_and_align(smiles, token_regr, epoch):
    """Encode ``smiles`` as SELFIES and build its distance encoding and aligned labels.

    Args:
        smiles: input SMILES; parsed with ``sanitize=False``.
        token_regr: optional per-atom labels (array-like) in the atom order
            of ``smiles``. The input is not modified.
        epoch: 0 for the canonical SMILES, otherwise the seed of a randomized
            SMILES variant.

    Returns:
        ``(selfies, pos_encod, token_regr_selfies)``:

        * ``selfies``: the SELFIES string of the (canonical or randomized)
          SMILES;
        * ``pos_encod``: square ``int16`` matrix over ``[CLS]`` + SELFIES
          tokens;
        * ``token_regr_selfies``: ``None`` when ``token_regr`` is ``None``.
          Otherwise an array of length ``n_selfies_tokens`` (``[CLS]``
          excluded) holding each atom's label at its token position and NaN
          elsewhere. Non-floating label dtypes are promoted to ``float64`` so
          that NaN is representable.
    """
    # NOTE(review): Chem.MolFromSmiles returns None for unparsable input,
    # which is not handled here.
    orig_mol = Chem.MolFromSmiles(smiles, sanitize=False)
    # Write the final (canonical or randomized) SMILES first, so the label
    # permutation matches the atom order the SELFIES is built from.
    new_smiles, mapping_to_new = __get_correspondence__(orig_mol, epoch)

    # Convert to SELFIES, simulate tokenization and prepend the [CLS] token.
    selfies = sf.encoder(new_smiles)
    tokens = ["[CLS]"] + list(sf.split_selfies(selfies))
    special_token_masks, map_smiles_to_selfies, dot_masks = __get_attribution_mapping__(tokens)

    if token_regr is None:
        token_regr_selfies = None
    else:
        labels = np.asarray(token_regr)
        # Permute on a copy: updating the caller's array in place would make
        # labels that are reused across epochs drift, because each mapping is
        # relative to the ORIGINAL atom order.
        aligned = labels.copy()
        aligned[:len(mapping_to_new)] = labels[mapping_to_new]
        # NaN cannot be stored in integer/bool arrays (the cast yields garbage).
        if np.issubdtype(labels.dtype, np.inexact):
            out_dtype = labels.dtype
        else:
            out_dtype = np.float64
        token_regr_selfies = np.full(len(tokens) - 1, np.nan, dtype=out_dtype)
        # Atom token positions count [CLS]; subtract one because the output
        # excludes it. Cast so an empty mapping is still a valid index.
        atom_positions = np.asarray(map_smiles_to_selfies, dtype=np.int64)
        token_regr_selfies[atom_positions - 1] = aligned[:len(atom_positions)]

    # Re-parse the written SMILES (unsanitized, to preserve the structure as
    # written) so atom indices follow its atom order.
    mol = Chem.MolFromSmiles(new_smiles, sanitize=False)
    ring_masks = get_ring_masks(mol, map_smiles_to_selfies, tokens)
    double_masks = ring_masks + dot_masks
    n_tokens = len(tokens)
    # No padding is added here, so the first padding index equals the length.
    pos_encod = __get_positional_encodings__(
        mol, map_smiles_to_selfies, n_tokens, special_token_masks, double_masks, n_tokens
    )
    return selfies, pos_encod, token_regr_selfies
|