import ast

import numpy as np
import selfies as sf
from rdkit import Chem


def __get_correspondence__(mol, epoch):
    """Write `mol` as a SMILES string and return it together with the atom
    output order (position i in the written SMILES -> atom index in `mol`).

    Epoch 0 uses the canonical SMILES; later epochs draw a randomized SMILES
    seeded by the epoch, so every epoch sees a different but reproducible
    atom ordering.
    """
    if epoch == 0:
        new_smiles = Chem.MolToSmiles(mol, canonical=True)
    else:
        new_smiles = Chem.MolToRandomSmilesVect(mol, 1, randomSeed=epoch)[0]

    # RDKit stores the atom output order of the last SMILES write as a
    # stringified list; parse it back into a Python list.
    output_order = mol.GetProp('_smilesAtomOutputOrder')
    mapping = ast.literal_eval(output_order)

    return new_smiles, mapping


def get_ring_masks(mol, map_smiles_to_selfies, tokens):
    """For every ring in `mol`, return a tuple (ring_token_idx, ring_atoms):
    the position of the [Ring...] token that closes the ring and the SELFIES
    token indices of the atoms belonging to it.
    """
    # The mol may be unsanitized, so make sure ring info is available.
    Chem.FastFindRings(mol)

    rings = mol.GetRingInfo().AtomRings()
    ring_masks = []
    for ring in rings:
        selfies_ring = map_smiles_to_selfies[list(ring)]
        # The ring-closure token immediately follows the last ring atom.
        ring_idx = selfies_ring.max() + 1
        assert "Ring" in tokens[ring_idx]
        ring_masks.append((ring_idx, selfies_ring))

    return ring_masks
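
# Worked example for get_ring_masks (illustrative, assuming selfies v2):
# cyclopropane "C1CC1" encodes as "[C][C][C][Ring1][Ring1]"; with [CLS]
# prepended, its three atoms sit at token indices [1, 2, 3], so
# ring_idx = 4 points at the first [Ring1] token and the function
# returns [(4, array([1, 2, 3]))].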


def __get_attribution_mapping__(tokens):
    """Scan the SELFIES token list (with a leading [CLS]) and return:

    * special_token_masks: an index range per [Branch...]/[Ring...] token,
      covering the token and the overloaded symbols encoding its length,
    * map_smiles_to_selfies: for the i-th atom in the SMILES, the index of
      its SELFIES token,
    * dot_masks: for each "." token, the token indices of the fragments it
      separates, as (dot_idx, mask) pairs.
    """
    special_token_masks = []
    map_smiles_to_selfies = []
    dots = []

    idx = 1  # skip [CLS]
    while idx < len(tokens):
        token = tokens[idx]

        if token == ".":
            dots.append(idx)
            idx += 1
            continue

        # A [...BranchN] token is followed by N overloaded symbols that encode
        # the branch length; none of these N + 1 tokens is an atom.
        branch_idx = token.find("Branch")
        if branch_idx >= 0:
            n = int(token[branch_idx + 6])
            special_token_masks.append(np.arange(idx, idx + n + 1, dtype=np.int16))
            idx += n + 1
            continue

        # Likewise, a [...RingN] token is followed by N ring-size symbols.
        ring_idx = token.find("Ring")
        if ring_idx >= 0:
            n = int(token[ring_idx + 4])
            special_token_masks.append(np.arange(idx, idx + n + 1, dtype=np.int16))
            idx += n + 1
            continue

        # Anything else is an atom token; atoms appear in the same order in
        # the SMILES and in the SELFIES string.
        map_smiles_to_selfies.append(idx)
        idx += 1

    # For every dot, collect the tokens between the previous and the next dot,
    # i.e. the two fragments the dot joins.
    dot_masks = []
    last_dots = [-1]
    for dot_idx in dots:
        if len(last_dots) == 2:
            val = last_dots.pop(0)
            dot_masks.append(list(range(val + 1, dot_idx)))
        last_dots.append(dot_idx)

    if len(dots) >= 1:
        # The final dot's mask extends to the end of the token list.
        dot_masks.append(list(range(last_dots.pop(0) + 1, len(tokens))))

    return special_token_masks, np.array(map_smiles_to_selfies), list(zip(dots, dot_masks, strict=True))
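
# Worked example (illustrative, assuming selfies v2): acetic acid "CC(=O)O"
# encodes as "[C][C][=Branch1][C][=O][O]", so with [CLS] prepended the token
# list is ["[CLS]", "[C]", "[C]", "[=Branch1]", "[C]", "[=O]", "[O]"] and
# __get_attribution_mapping__(tokens) returns
#   special_token_masks   -> [array([3, 4])]      # [=Branch1] + its length symbol
#   map_smiles_to_selfies -> array([1, 2, 5, 6])  # SMILES atom i -> token index
#   dot_masks             -> []                   # no "." fragments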


def __get_positional_encodings__(mol, smiles_to_selfies, context_length,
                                 special_token_masks, double_masks,
                                 first_padding_token_idx):
    """Build a (context_length, context_length) int16 matrix of pairwise token
    distances: topological bond distances between atom tokens, 0 between
    [CLS]/dot/ring-closure tokens and the tokens they connect, -1 inside
    Branch/Ring index groups, and int16 max everywhere else (padding).
    """
    ats = np.array(smiles_to_selfies, dtype=np.int64)
    distance = Chem.GetDistanceMatrix(mol)

    # Clip distances so they fit into int16; the maximum value itself is
    # reserved as the "unrelated / padding" marker.
    limit = np.iinfo(np.int16).max
    distance = np.minimum(distance, limit - 1).astype(np.int16)

    pos_encod = np.full((context_length, context_length), limit, dtype=np.int16)

    # [CLS] sits at distance 0 from every non-padding token.
    pos_encod[0, :first_padding_token_idx] = 0
    pos_encod[:first_padding_token_idx, 0] = 0

    # Branch/Ring tokens and their length symbols form groups marked -1.
    for m in special_token_masks:
        pos_encod[m[:, None], m] = -1

    # Ring-closure and dot tokens sit at distance 0 from the tokens they
    # tie together.
    for i, m in double_masks:
        pos_encod[i, m] = 0
        pos_encod[m, i] = 0

    np.fill_diagonal(pos_encod, 0)

    # Finally, fill in the atom-atom topological distances.
    pos_encod[ats[:, None], ats] = distance

    return pos_encod
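
# Worked example (illustrative, assuming selfies v2): for cyclopropane,
# "C1CC1" -> "[C][C][C][Ring1][Ring1]", so with [CLS] there are 6 tokens,
# ats = [1, 2, 3], special_token_masks = [array([4, 5])], and
# double_masks = [(4, array([1, 2, 3]))]. With M = np.iinfo(np.int16).max
# the function returns
#   [[ 0  0  0  0  0  0]
#    [ 0  0  1  1  0  M]
#    [ 0  1  0  1  0  M]
#    [ 0  1  1  0  0  M]
#    [ 0  0  0  0  0 -1]
#    [ 0  M  M  M -1  0]]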


def get_positional_encodings_and_align(smiles, token_regr, epoch):
    """Re-randomize the SMILES for this epoch, encode it as SELFIES, and
    return the SELFIES string, its positional-encoding matrix, and the
    per-atom regression targets realigned to SELFIES token order (None if
    `token_regr` is None).
    """
    orig_mol = Chem.MolFromSmiles(smiles, sanitize=False)

    new_smiles, mapping_to_new = __get_correspondence__(orig_mol, epoch)

    selfies = sf.encoder(new_smiles)
    tokens = ["[CLS]"] + list(sf.split_selfies(selfies))

    special_token_masks, map_smiles_to_selfies, dot_masks = __get_attribution_mapping__(tokens)

    if token_regr is not None:
        # Permute the per-atom targets from the original atom order into the
        # atom order of the freshly written SMILES (in place).
        token_regr[:len(mapping_to_new)] = token_regr[mapping_to_new]

        # Scatter the atom targets onto their SELFIES positions ([CLS]
        # excluded); Branch/Ring/length-symbol positions stay NaN.
        token_regr_selfies = np.full(len(tokens) - 1, np.nan, dtype=token_regr.dtype)
        valid = map_smiles_to_selfies < len(tokens)
        token_regr_selfies[map_smiles_to_selfies[valid] - 1] = token_regr[:np.sum(valid)]
    else:
        token_regr_selfies = None

    mol = Chem.MolFromSmiles(new_smiles, sanitize=False)

    ring_masks = get_ring_masks(mol, map_smiles_to_selfies, tokens)
    double_masks = ring_masks + dot_masks
    pos_encod = __get_positional_encodings__(mol, map_smiles_to_selfies, len(tokens),
                                             special_token_masks, double_masks, len(tokens))

    return selfies, pos_encod, token_regr_selfies
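

# Usage sketch (illustrative): the SMILES, the attribution values, and the
# epoch below are hypothetical inputs, not part of the module.
if __name__ == "__main__":
    example_smiles = "C1CCCCC1O.CC"                # a ring plus a second fragment
    attributions = np.arange(9, dtype=np.float32)  # one value per heavy atom
    selfies_str, pos_encod, attr_selfies = get_positional_encodings_and_align(
        example_smiles, attributions, epoch=1
    )
    print(selfies_str)
    print(pos_encod.shape)  # (n_tokens, n_tokens), including [CLS]
    print(attr_selfies)     # NaN at Branch/Ring/length-symbol positions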