import ast

import numpy as np
import selfies as sf
from rdkit import Chem


# Get molecule old smiles to permuted smiles correspondence for token_regr
def __get_correspondence__(mol, epoch):
    """Write `mol` as SMILES and report the atom order used for that string.

    For ``epoch == 0`` the canonical SMILES is produced; for any other epoch a
    single random SMILES is generated with ``randomSeed=epoch``.

    Args:
        mol: RDKit molecule to serialise.
        epoch: Selects canonical (0) or seeded random output (non-zero).

    Returns:
        A ``(smiles, order)`` pair, where ``order`` is the parsed value of the
        molecule's ``_smilesAtomOutputOrder`` property.
    """
    if epoch == 0:
        permuted_smiles = Chem.MolToSmiles(mol, canonical=True)
    else:
        random_variants = Chem.MolToRandomSmilesVect(mol, 1, randomSeed=epoch)
        permuted_smiles = random_variants[0]

    # The property is read after the SMILES has been written; it is stored as
    # a string, hence the literal_eval.
    # NOTE(review): presumably RDKit fills it in while writing the SMILES.
    atom_order = ast.literal_eval(mol.GetProp('_smilesAtomOutputOrder'))
    return permuted_smiles, atom_order


# We already know the [Ring] token connects the token immediately before...
def get_ring_masks(mol, map_smiles_to_selfies, tokens):
    """Pair each ring of `mol` with the token expected to close it.

    Args:
        mol: RDKit molecule; ring information is computed on it in place.
        map_smiles_to_selfies: NumPy array giving, for every atom index, the
            position of its token in `tokens`.
        tokens: Token sequence in which the ring tokens are looked up.

    Returns:
        A list of ``(ring_idx, selfies_ring)`` tuples: the token positions of
        the ring's atoms and the position right after the largest of them.

    Raises:
        AssertionError: If the token at ``ring_idx`` does not contain "Ring".
    """
    # Atom indices follow the order in which atoms appear in the SMILES, so
    # they can be used directly to index the atom-to-token map.
    Chem.FastFindRings(mol)
    atom_rings = mol.GetRingInfo().AtomRings()

    ring_masks = []
    for atom_ring in atom_rings:
        token_positions = map_smiles_to_selfies[list(atom_ring)]
        closing_token_idx = token_positions.max() + 1
        ring_masks.append((closing_token_idx, token_positions))
        assert "Ring" in tokens[closing_token_idx]
    return ring_masks


# Distances are set to 0 for the tokens in the molecules at the right and at the left of .
tokens (except padding tokens) def __get_attribution_mapping__(tokens): special_token_masks = [] map_smiles_to_selfies = [] dots = [] idx = 1 # Start after [CLS] while idx < len(tokens): token = tokens[idx] if token == ".": dots.append(idx) idx += 1 continue branch_idx = token.find("Branch") if branch_idx >= 0: n = int(token[branch_idx + 6]) special_token_masks.append(np.arange(idx, idx + n + 1, dtype=np.int16)) idx += n + 1 continue else: ring_idx = token.find("Ring") if ring_idx >= 0: n = int(token[ring_idx + 4]) special_token_masks.append(np.arange(idx, idx + n + 1, dtype=np.int16)) idx += n + 1 continue # Real (atom) token map_smiles_to_selfies.append(idx) idx += 1 # Existing dot_masks construction (unchanged) dot_masks = [] last_dots = [-1] for dot_idx in dots: if len(last_dots) == 2: val = last_dots.pop(0) dot_masks.append([el for el in range(val + 1, dot_idx, 1)]) last_dots.append(dot_idx) if len(dots) >= 1: dot_masks.append([el for el in range(last_dots.pop(0) + 1, len(tokens), 1)]) return special_token_masks, np.array(map_smiles_to_selfies), list(zip(dots, dot_masks, strict=True)) def __get_positional_encodings__(mol, smiles_to_selfies, context_length, special_token_masks, double_masks, first_padding_token_idx): ats = np.array(smiles_to_selfies, dtype=np.int64) distance = Chem.GetDistanceMatrix(mol) # Distance of encodings is capped at the int16 upper bound minus 1 # (because the int16 upper bound value is reserved for special distances) limit = np.iinfo(np.int16).max distance = np.minimum(distance, limit-1).astype(np.int16) pos_encod = np.full((context_length, context_length), limit, dtype=np.int16) # Set first row and column to 0 only for non-padding tokens (positions in ats) pos_encod[0, :first_padding_token_idx] = 0 pos_encod[:first_padding_token_idx, 0] = 0 for m in special_token_masks: pos_encod[m[:, None], m] = -1 for i, m in double_masks: pos_encod[i, m] = 0 pos_encod[m, i] = 0 np.fill_diagonal(pos_encod, 0) # Use advanced indexing for distance 
assignment pos_encod[ats[:, None], ats] = distance return pos_encod def get_positional_encodings_and_align(smiles, token_regr, epoch): orig_mol = Chem.MolFromSmiles(smiles, sanitize = False) # Converts SMILES to the final SMILES so that the mapping is already correct for the token-level labels. # Generates a predictable variation of the SMILES. new_smiles, mapping_to_new = __get_correspondence__(orig_mol, epoch) # Convert to SELFIES, simulate tokenization and add [CLS] token at the beginning selfies = sf.encoder(new_smiles) tokens = ["[CLS]"] + list(sf.split_selfies(selfies)) special_token_masks, map_smiles_to_selfies, dot_masks = __get_attribution_mapping__(tokens) # Align token labels to SELFIES tokens if token_regr is not None: # Align token labels to the new SMILES token_regr[:len(mapping_to_new)] = token_regr[mapping_to_new] token_regr_selfies = np.full(len(tokens)-1, np.nan, dtype=token_regr.dtype) valid = map_smiles_to_selfies < len(tokens) token_regr_selfies[map_smiles_to_selfies[valid] - 1] = token_regr[:np.sum(valid)] else: token_regr_selfies = None # Generate molecule from the new SMILES (remove sanitization to preserve the original structure) mol = Chem.MolFromSmiles(new_smiles, sanitize = False) ring_masks = get_ring_masks(mol, map_smiles_to_selfies, tokens) double_masks = ring_masks + dot_masks pos_encod = __get_positional_encodings__(mol, map_smiles_to_selfies, len(tokens), special_token_masks, double_masks, len(tokens)) return selfies, pos_encod, token_regr_selfies