# m5-encoder / prepare_data.py
# Author: IlPakoZ - commit 12cd9ef ("Initial upload")
import selfies as sf
from rdkit import Chem
import ast
import numpy as np
def __get_correspondence__(mol, epoch):
    """Write ``mol`` as SMILES and report the order in which its atoms were written.

    Used to re-map per-atom labels (``token_regr``) from the atom order of the
    input molecule to the atom order of the newly written SMILES.

    Args:
        mol: RDKit molecule to serialize.
        epoch: ``0`` selects the canonical SMILES; any other value selects a
            randomized SMILES generated with ``randomSeed=epoch``, so each
            epoch yields a reproducible variant.

    Returns:
        ``(new_smiles, mapping)`` where ``mapping`` is the list parsed from the
        ``_smilesAtomOutputOrder`` property of ``mol``: ``mapping[k]`` is the
        index in ``mol`` of the k-th atom written to ``new_smiles``.
    """
    if epoch != 0:
        # NOTE(review): confirm MolToRandomSmilesVect records
        # _smilesAtomOutputOrder on `mol` itself (and not only on an internal
        # copy); otherwise the property read below may be missing or stale.
        written = Chem.MolToRandomSmilesVect(mol, 1, randomSeed=epoch)[0]
    else:
        written = Chem.MolToSmiles(mol, canonical=True)
    # The property is stored as a string such as "[2,0,1,]".
    atom_order = ast.literal_eval(mol.GetProp('_smilesAtomOutputOrder'))
    return written, atom_order
def get_ring_masks(mol, map_smiles_to_selfies, tokens):
    """Pair every ring of ``mol`` with the SELFIES [Ring] token that closes it.

    The [Ring] token is expected to sit immediately after the ring atom with
    the highest token index: it connects the atom token right before it to an
    earlier atom of the same ring.

    Args:
        mol: RDKit molecule whose atom indices follow the order in which atoms
            appear in the SMILES, which is also the order of the atom tokens.
        map_smiles_to_selfies: NumPy integer array; entry ``k`` is the token
            index of atom ``k``.
        tokens: ``[CLS]``-prefixed list of SELFIES tokens.

    Returns:
        List of ``(ring_token_idx, ring_atom_token_indices)`` tuples, one per
        ring reported by RDKit's ring info.

    Raises:
        AssertionError: if the token after a ring's last atom is not a ring
            token (only checked when assertions are enabled).
    """
    # Ring perception that does not require a sanitized molecule.
    Chem.FastFindRings(mol)
    result = []
    for atom_ring in mol.GetRingInfo().AtomRings():
        members = map_smiles_to_selfies[list(atom_ring)]
        # NOTE(review): two rings whose highest-index atom coincides resolve to
        # the same ring token here - confirm this is intended for fused rings.
        closing_idx = members.max() + 1
        result.append((closing_idx, members))
        assert "Ring" in tokens[closing_idx]
    return result
def __get_attribution_mapping__(tokens):
    """Locate atom, Branch/Ring and ``.`` tokens in a ``[CLS]``-prefixed SELFIES token list.

    Args:
        tokens: ``["[CLS]", *selfies_tokens]``; index 0 is skipped.

    Returns:
        ``(special_token_masks, map_smiles_to_selfies, dot_masks)``:

        * ``special_token_masks`` -- one int16 array per Branch/Ring token,
          holding that token's index followed by the indices of the ``n``
          index tokens that encode its length, where ``n`` is the digit in the
          token (e.g. ``[=Branch2]`` -> 2).
        * ``map_smiles_to_selfies`` -- int64 array; entry ``k`` is the token
          index of the k-th atom token. It is always an integer array (even
          when empty), so it can be used directly for indexing.
        * ``dot_masks`` -- list of ``(dot_idx, indices)`` pairs, one per ``.``
          token. ``indices`` covers the fragment to the left and the fragment
          to the right of that dot, the dot included. The first dot's span
          starts at 0 and the last dot's span ends at ``len(tokens) - 1``.
          These pairs are used to set distance 0 between a dot and the tokens
          of its neighbouring fragments.
    """
    special_token_masks = []
    atom_positions = []
    dots = []
    idx = 1  # index 0 is [CLS]
    while idx < len(tokens):
        token = tokens[idx]
        if token == ".":
            dots.append(idx)
            idx += 1
            continue
        branch_pos = token.find("Branch")
        ring_pos = token.find("Ring")
        if branch_pos >= 0:
            n_index_tokens = int(token[branch_pos + len("Branch")])
        elif ring_pos >= 0:
            n_index_tokens = int(token[ring_pos + len("Ring")])
        else:
            # Real (atom) token.
            atom_positions.append(idx)
            idx += 1
            continue
        # The special token plus the index tokens that follow it form one group;
        # the index tokens are skipped so they are never mistaken for atoms.
        special_token_masks.append(
            np.arange(idx, idx + n_index_tokens + 1, dtype=np.int16)
        )
        idx += n_index_tokens + 1

    # Fragment boundaries, with virtual separators before index 0 and after the
    # last token. Dot k (boundary k + 1) spans from just after boundary k up to
    # (excluding) boundary k + 2.
    bounds = [-1, *dots, len(tokens)]
    dot_masks = [
        (dot, list(range(bounds[k] + 1, bounds[k + 2])))
        for k, dot in enumerate(dots)
    ]
    # An explicit integer dtype keeps an empty result usable as an index array
    # (np.array([]) would otherwise default to float64).
    return special_token_masks, np.array(atom_positions, dtype=np.int64), dot_masks
def __get_positional_encodings__(
    mol,
    smiles_to_selfies,
    context_length,
    special_token_masks,
    double_masks,
    first_padding_token_idx,
):
    """Build the pairwise token "distance" matrix used as positional encoding.

    The entries of the int16 ``(context_length, context_length)`` matrix are
    written in this order, so later writes win:

    1. Everything starts at ``int16.max``, a sentinel meaning "no relation".
    2. Row 0 and column 0 (``[CLS]``) are 0 for positions below
       ``first_padding_token_idx``.
    3. Every pair inside one Branch/Ring group (``special_token_masks``) is -1.
    4. For each ``(i, members)`` in ``double_masks``, entries ``[i, members]``
       and ``[members, i]`` are 0.
    5. The diagonal is 0.
    6. Atom-token pairs receive the graph distance from
       ``Chem.GetDistanceMatrix``, clipped to ``int16.max - 1`` so it never
       equals the sentinel.

    Args:
        mol: RDKit molecule whose atom ``k`` corresponds to token
            ``smiles_to_selfies[k]``.
        smiles_to_selfies: token index of each atom, in atom order.
        context_length: side length of the returned matrix.
        special_token_masks: iterable of 1-D integer arrays of token indices.
        double_masks: iterable of ``(token_idx, indices)`` pairs.
        first_padding_token_idx: positions at or beyond this index do not get
            distance 0 to ``[CLS]``.

    Returns:
        The int16 positional-encoding matrix.
    """
    sentinel = np.iinfo(np.int16).max
    atom_positions = np.array(smiles_to_selfies, dtype=np.int64)
    # Clip before narrowing to int16; the sentinel value itself stays reserved.
    graph_dist = np.minimum(Chem.GetDistanceMatrix(mol), sentinel - 1).astype(np.int16)

    encod = np.full((context_length, context_length), sentinel, dtype=np.int16)

    # [CLS] is at distance 0 from every non-padding position.
    encod[0, :first_padding_token_idx] = 0
    encod[:first_padding_token_idx, 0] = 0

    for group in special_token_masks:
        encod[group[:, np.newaxis], group] = -1

    for anchor, members in double_masks:
        encod[anchor, members] = 0
        encod[members, anchor] = 0

    np.fill_diagonal(encod, 0)

    # Atom-to-atom entries take the real graph distance, overriding the above.
    encod[atom_positions[:, np.newaxis], atom_positions] = graph_dist
    return encod
def get_positional_encodings_and_align(smiles, token_regr, epoch):
    """Convert ``smiles`` to SELFIES and build its positional encodings and token labels.

    Steps:
      1. Re-write the molecule as SMILES (canonical when ``epoch == 0``,
         randomized with seed ``epoch`` otherwise) and get the permutation from
         the input atom order to the new SMILES atom order.
      2. Encode the new SMILES as SELFIES and split it into tokens, prefixed
         with ``"[CLS]"``.
      3. If ``token_regr`` is given, permute it into the new atom order and
         scatter it onto the atom tokens.
      4. Build the pairwise positional-encoding matrix from graph distances,
         Branch/Ring token groups, ring-closure tokens and ``.`` tokens.

    Args:
        smiles: input SMILES, parsed with ``sanitize=False``.
        token_regr: per-atom labels in the atom order of ``smiles`` (a NumPy
            array or array-like), or ``None``. The caller's object is not
            modified. A float dtype is expected, because non-atom positions
            of the output are filled with NaN.
        epoch: selects which SMILES variant is generated.

    Returns:
        ``(selfies, pos_encod, token_regr_selfies)``: the SELFIES string; an
        int16 matrix of shape ``(T, T)``, where ``T`` is the number of tokens
        including ``[CLS]``; and an array of length ``T - 1`` with each atom
        token's label and NaN elsewhere (``None`` when ``token_regr`` is
        ``None``).
    """
    orig_mol = Chem.MolFromSmiles(smiles, sanitize=False)

    # Generate the final SMILES first, so the atom mapping already matches the
    # order in which the token-level labels must appear.
    new_smiles, mapping_to_new = __get_correspondence__(orig_mol, epoch)

    selfies = sf.encoder(new_smiles)
    tokens = ["[CLS]"] + list(sf.split_selfies(selfies))
    special_token_masks, map_smiles_to_selfies, dot_masks = __get_attribution_mapping__(tokens)

    if token_regr is None:
        token_regr_selfies = None
    else:
        # Permute a private copy. Doing it in place would alter the caller's
        # array, and a later call with the same array (e.g. a different epoch)
        # would apply a new permutation on top of the previous one.
        aligned = np.array(token_regr, copy=True)
        aligned[:len(mapping_to_new)] = aligned[mapping_to_new]
        token_regr_selfies = np.full(len(tokens) - 1, np.nan, dtype=aligned.dtype)
        valid = map_smiles_to_selfies < len(tokens)
        # Shift token indices by one: [CLS] has no slot in the label array.
        token_regr_selfies[map_smiles_to_selfies[valid] - 1] = aligned[:np.sum(valid)]

    # Re-parse the new SMILES (unsanitized, like the input) so that atom
    # indices follow the SMILES / SELFIES atom order.
    mol = Chem.MolFromSmiles(new_smiles, sanitize=False)
    ring_masks = get_ring_masks(mol, map_smiles_to_selfies, tokens)
    double_masks = ring_masks + dot_masks
    pos_encod = __get_positional_encodings__(
        mol,
        map_smiles_to_selfies,
        len(tokens),
        special_token_masks,
        double_masks,
        len(tokens),
    )
    return selfies, pos_encod, token_regr_selfies