| import os |
| import cv2 |
| import time |
| import random |
| import re |
| import string |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence |
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
|
|
| from .indigo import Indigo |
| from .indigo.renderer import IndigoRenderer |
|
|
| from .augment import SafeRotate, CropWhite, PadWhite, SaltAndPepperNoise |
| from .utils import FORMAT_INFO |
| from .tokenizer import PAD_ID |
| from .chemistry import get_num_atoms, normalize_nodes |
| from .constants import RGROUP_SYMBOLS, SUBSTITUTIONS, ELEMENTS, COLORS |
|
|
# Restrict OpenCV to a single internal thread.
# NOTE(review): presumably to avoid thread oversubscription when samples are
# produced by several DataLoader workers - confirm against the training setup.
cv2.setNumThreads(1)


# Probabilities driving the Indigo-based augmentation helpers in this file.
# NOTE: "HYGROGEN" and "DEARMOTIZE" are misspellings of "HYDROGEN" and
# "DEAROMATIZE"; the names are kept because other functions reference them.
# Chance that add_explicit_hydrogen makes one atom's implicit hydrogens explicit.
INDIGO_HYGROGEN_PROB = 0.2
# Chance that add_functional_group attempts any abbreviation at all.
INDIGO_FUNCTIONAL_GROUP_PROB = 0.8
# Chance that add_rand_condensed attaches a random condensed-formula label.
INDIGO_CONDENSED_PROB = 0.5
# Chance that add_rgroup attaches an R-group symbol.
INDIGO_RGROUP_PROB = 0.5
# Chance that add_comment configures a rendered comment.
INDIGO_COMMENT_PROB = 0.3
# Chance that generate_indigo_image dearomatizes the molecule (otherwise it aromatizes).
INDIGO_DEARMOTIZE_PROB = 0.8
# Threshold for each of the three independent draws made by add_color.
INDIGO_COLOR_PROB = 0.2
|
|
|
|
def get_transforms(input_size, augment=True, rotate=True, debug=False):
    """Assemble the albumentations pipeline used for molecule images.

    Args:
        input_size: Side length passed to ``A.Resize`` for both dimensions.
        augment: When true, the noise/crop/blur augmentations are included.
        rotate: When true (and ``augment`` is true), ``SafeRotate`` is prepended.
        debug: When true, the final gray/normalize/to-tensor stage is skipped.

    Returns:
        An ``A.Compose`` that also transforms keypoints given in ``xy`` format
        and keeps keypoints that fall outside the image.
    """
    pipeline = []
    if augment and rotate:
        pipeline.append(SafeRotate(limit=90, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255)))
    pipeline.append(CropWhite(pad=5))

    if augment:
        pipeline.extend((
            A.CropAndPad(percent=[-0.01, 0.00], keep_size=False, p=0.5),
            PadWhite(pad_ratio=0.4, p=0.2),
            A.Downscale(scale_min=0.2, scale_max=0.5, interpolation=3),
            A.Blur(),
            A.GaussNoise(),
            SaltAndPepperNoise(num_dots=20, p=0.5),
        ))

    pipeline.append(A.Resize(input_size, input_size))

    if not debug:
        pipeline.extend((
            A.ToGray(p=1),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ))

    keypoint_params = A.KeypointParams(format='xy', remove_invisible=False)
    return A.Compose(pipeline, keypoint_params=keypoint_params)
|
|
|
|
def add_functional_group(indigo, mol, debug=False):
    """Randomly replace matched substructures with an abbreviation pseudo-atom.

    With probability ``1 - INDIGO_FUNCTIONAL_GROUP_PROB`` the molecule is
    returned untouched. Otherwise every entry of ``SUBSTITUTIONS`` is tried in
    random order: each SMARTS match is accepted with ``sub.probability``
    (always when ``debug`` is true), the matched atoms are removed, and a new
    atom labelled with one of ``sub.abbrvs`` is bonded to their former
    neighbours.

    Args:
        indigo: Indigo session used to load SMARTS and create matchers.
        mol: Molecule to modify in place.
        debug: Accept every match regardless of its probability.

    Returns:
        The same ``mol`` object.
    """
    if random.random() > INDIGO_FUNCTIONAL_GROUP_PROB:
        return mol

    candidates = list(SUBSTITUTIONS)
    random.shuffle(candidates)
    for sub in candidates:
        query = indigo.loadSmarts(sub.smarts)
        matcher = indigo.substructureMatcher(mol)
        # Atoms already replaced for this substitution; overlapping matches are skipped.
        consumed_ids = set()
        for match in matcher.iterateMatches(query):
            # The random draw comes first so it happens even when debug is set.
            accepted = random.random() < sub.probability or debug
            if not accepted:
                continue

            group_atoms = [match.mapAtom(query_atom) for query_atom in query.iterateAtoms()]
            group_ids = {atom.index() for atom in group_atoms}
            if not consumed_ids.isdisjoint(group_ids):
                continue

            superatom = mol.addAtom(random.choice(sub.abbrvs))
            for atom in group_atoms:
                for neighbor in atom.iterateNeighbors():
                    neighbor_id = neighbor.index()
                    if neighbor_id in group_ids:
                        continue
                    if neighbor.symbol() == 'H':
                        # Explicit hydrogens are not part of the match; remove them with the group.
                        group_ids.add(neighbor_id)
                    else:
                        superatom.addBond(neighbor, neighbor.bond().bondOrder())

            for atom_id in group_ids:
                mol.getAtom(atom_id).remove()
            consumed_ids |= group_ids
    return mol
|
|
|
|
def add_explicit_hydrogen(indigo, mol):
    """Randomly turn one atom's implicit hydrogens into explicit ``H`` atoms.

    Atoms for which ``countImplicitHydrogens()`` raises are skipped. With
    probability ``INDIGO_HYGROGEN_PROB`` one atom that has implicit hydrogens
    is picked and that many ``H`` atoms are single-bonded to it.

    Args:
        indigo: Unused; kept for a signature parallel to the other helpers.
        mol: Molecule to modify in place.

    Returns:
        The same ``mol`` object.
    """
    candidates = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
        except Exception:
            # Skip atoms Indigo cannot evaluate, but (unlike a bare ``except``)
            # let KeyboardInterrupt / SystemExit propagate.
            continue
        if hs > 0:
            candidates.append((atom, hs))

    if candidates and random.random() < INDIGO_HYGROGEN_PROB:
        atom, hs = random.choice(candidates)
        for _ in range(hs):
            mol.addAtom('H').addBond(atom, 1)
    return mol
|
|
|
|
def add_rgroup(indigo, mol, smiles):
    """Randomly attach an R-group symbol to an atom that has implicit hydrogens.

    Nothing is added when ``smiles`` already contains ``*`` or when no atom
    has implicit hydrogens. Otherwise, with probability
    ``INDIGO_RGROUP_PROB``, one such atom receives a new single-bonded atom
    labelled with a random entry of ``RGROUP_SYMBOLS``.

    Args:
        indigo: Unused; kept for a signature parallel to the other helpers.
        mol: Molecule to modify in place.
        smiles: SMILES string of ``mol``; only checked for ``*``.

    Returns:
        The same ``mol`` object.
    """
    candidates = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
        except Exception:
            # Skip atoms Indigo cannot evaluate, but (unlike a bare ``except``)
            # let KeyboardInterrupt / SystemExit propagate.
            continue
        if hs > 0:
            candidates.append(atom)

    if candidates and '*' not in smiles:
        if random.random() < INDIGO_RGROUP_PROB:
            # Choosing the atom directly consumes the same single random draw
            # as choosing an index into the list.
            atom = random.choice(candidates)
            symbol = random.choice(RGROUP_SYMBOLS)
            mol.addAtom(symbol).addBond(atom, 1)
    return mol
|
|
|
|
def get_rand_symb():
    """Return a random symbol for a synthetic condensed formula.

    Starts from a random entry of ``ELEMENTS``; each of the following happens
    independently with probability 0.1: a lowercase letter is appended, an
    uppercase letter is appended, and the whole symbol is replaced by a
    parenthesised random condensed formula.
    """
    symbol = random.choice(ELEMENTS)
    for alphabet in (string.ascii_lowercase, string.ascii_uppercase):
        if random.random() < 0.1:
            symbol += random.choice(alphabet)
    if random.random() < 0.1:
        symbol = f'({gen_rand_condensed()})'
    return symbol
|
|
|
|
def get_rand_num():
    """Return a random count suffix for a condensed-formula symbol.

    Returns:
        ``''`` in most cases, otherwise a digit string ``'2'``..``'9'`` or a
        two-digit string ``'12'``..``'19'``.
    """
    if random.random() >= 0.9:
        return '1' + str(random.randint(2, 9))
    if random.random() < 0.8:
        return ''
    return str(random.randint(2, 9))
|
|
|
|
def gen_rand_condensed():
    """Build a random condensed-formula string of one to five symbol/count pairs.

    The first pair is always emitted; before each further pair generation
    stops with probability 0.8.
    """
    pieces = [get_rand_symb(), get_rand_num()]
    for _ in range(4):
        if random.random() < 0.8:
            break
        pieces += [get_rand_symb(), get_rand_num()]
    return ''.join(pieces)
|
|
|
|
def add_rand_condensed(indigo, mol):
    """Randomly attach a synthetic condensed-formula label to the molecule.

    With probability ``INDIGO_CONDENSED_PROB`` one atom that has implicit
    hydrogens receives a new single-bonded atom whose label comes from
    ``gen_rand_condensed()``. Atoms for which ``countImplicitHydrogens()``
    raises are skipped.

    Args:
        indigo: Unused; kept for a signature parallel to the other helpers.
        mol: Molecule to modify in place.

    Returns:
        The same ``mol`` object.
    """
    candidates = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
        except Exception:
            # Skip atoms Indigo cannot evaluate, but (unlike a bare ``except``)
            # let KeyboardInterrupt / SystemExit propagate.
            continue
        if hs > 0:
            candidates.append(atom)

    if candidates and random.random() < INDIGO_CONDENSED_PROB:
        atom = random.choice(candidates)
        symbol = gen_rand_condensed()
        mol.addAtom(symbol).addBond(atom, 1)
    return mol
|
|
|
|
def generate_output_smiles(indigo, mol):
    """Reload the molecule from its SMILES and derive the label SMILES.

    ``mol.smiles()`` is used as the source text. When it contains ``*``, each
    ``*`` in the part before the first space is replaced, in order, by the
    non-empty ``;``-separated labels found between ``$`` signs in the rest of
    the string, wrapped in square brackets. Otherwise everything after the
    first space is dropped.

    Args:
        indigo: Indigo session used to reload the molecule.
        mol: Molecule to serialise.

    Returns:
        Tuple ``(reloaded_mol, label_smiles)``.
    """
    smiles = mol.smiles()
    mol = indigo.loadMolecule(smiles)

    if '*' not in smiles:
        # Drops any space-separated extension; a no-op when there is none.
        return mol, smiles.split(' ')[0]

    part_a, part_b = smiles.split(' ', maxsplit=1)
    labels_text = re.search(r'\$.*\$', part_b).group(0)[1:-1]
    symbols = [token for token in labels_text.split(';') if len(token) > 0]

    pieces = []
    next_symbol = 0
    for char in part_a:
        if char == '*':
            pieces.append(f'[{symbols[next_symbol]}]')
            next_symbol += 1
        else:
            pieces.append(char)
    return mol, ''.join(pieces)
|
|
|
|
def add_comment(indigo):
    """Randomly configure a short rendered comment on the Indigo session.

    With probability ``INDIGO_COMMENT_PROB`` the comment text (a number 1-20
    followed by one ASCII letter), its font size, alignment, position and
    offset are set through ``indigo.setOption``. Otherwise nothing changes.
    """
    if random.random() >= INDIGO_COMMENT_PROB:
        return

    text = str(random.randint(1, 20)) + random.choice(string.ascii_letters)
    indigo.setOption('render-comment', text)

    font_size = random.randint(40, 60)
    indigo.setOption('render-comment-font-size', font_size)

    alignment = random.choice([0, 0.5, 1])
    indigo.setOption('render-comment-alignment', alignment)

    position = random.choice(['top', 'bottom'])
    indigo.setOption('render-comment-position', position)

    offset = random.randint(2, 30)
    indigo.setOption('render-comment-offset', offset)
|
|
|
|
def add_color(indigo, mol):
    """Randomly enable colouring and atom highlighting for rendering.

    Three independent draws against ``INDIGO_COLOR_PROB`` decide whether to
    enable element colouring, whether to pick a random base colour from
    ``COLORS``, and whether to enter highlight mode. In highlight mode the
    highlight colour and highlight thickness are each enabled with
    probability 0.5 and every atom is highlighted with probability 0.1.

    Args:
        indigo: Indigo session whose render options are modified.
        mol: Molecule whose atoms may be highlighted in place.

    Returns:
        The same ``mol`` object.
    """
    if random.random() < INDIGO_COLOR_PROB:
        indigo.setOption('render-coloring', True)

    if random.random() < INDIGO_COLOR_PROB:
        base_color = random.choice(list(COLORS.values()))
        indigo.setOption('render-base-color', base_color)

    if random.random() >= INDIGO_COLOR_PROB:
        return mol

    if random.random() < 0.5:
        indigo.setOption('render-highlight-color-enabled', True)
        highlight_color = random.choice(list(COLORS.values()))
        indigo.setOption('render-highlight-color', highlight_color)
    if random.random() < 0.5:
        indigo.setOption('render-highlight-thickness-enabled', True)
    for atom in mol.iterateAtoms():
        if random.random() < 0.1:
            atom.highlight()
    return mol
|
|
|
|
def get_graph(mol, image, shuffle_nodes=False, pseudo_coords=False):
    """Extract atom coordinates, symbols and a bond matrix from a molecule.

    Args:
        mol: Molecule exposing ``layout``, ``iterateAtoms`` and
            ``iterateBonds``; ``layout()`` is called first.
        image: Rendered image; only its shape is read, and only when
            ``pseudo_coords`` is true.
        shuffle_nodes: Randomly permute the atom order.
        pseudo_coords: Use ``atom.xyz()`` and rescale through
            ``normalize_nodes`` to the image width/height instead of using
            ``atom.coords()`` directly.

    Returns:
        Dict with ``coords``, ``symbols``, ``edges`` (``n x n`` integer array)
        and ``num_atoms``. An edge holds the bond order, or, for bonds whose
        stereo code is 5 or 6, that code in the source->destination cell and
        ``11 - code`` in the reverse cell.
    """
    mol.layout()
    atoms = list(mol.iterateAtoms())
    if shuffle_nodes:
        random.shuffle(atoms)

    coords = []
    symbols = []
    # Maps the molecule's own atom index to the position in the output lists.
    index_map = {}
    for position, atom in enumerate(atoms):
        if pseudo_coords:
            x, y, _z = atom.xyz()
        else:
            x, y = atom.coords()
        coords.append([x, y])
        symbols.append(atom.symbol())
        index_map[atom.index()] = position

    if pseudo_coords:
        coords = normalize_nodes(np.array(coords))
        height, width, _ = image.shape
        coords[:, 0] = coords[:, 0] * width
        coords[:, 1] = coords[:, 1] * height

    num_atoms = len(symbols)
    edges = np.zeros((num_atoms, num_atoms), dtype=int)
    for bond in mol.iterateBonds():
        src = index_map[bond.source().index()]
        dst = index_map[bond.destination().index()]
        order = bond.bondOrder()
        stereo = bond.bondStereo()
        if stereo in [5, 6]:
            # Stereo bonds are directional: the reverse cell gets the complementary code.
            edges[src, dst] = stereo
            edges[dst, src] = 11 - stereo
        else:
            edges[src, dst] = order
            edges[dst, src] = order

    return {
        'coords': coords,
        'symbols': symbols,
        'edges': edges,
        'num_atoms': num_atoms,
    }
|
|
|
|
def generate_indigo_image(smiles, mol_augment=True, default_option=False, shuffle_nodes=False, pseudo_coords=False,
                          include_condensed=True, debug=False):
    """Render a molecule image from SMILES with Indigo, with optional augmentation.

    Args:
        smiles: Input SMILES string.
        mol_augment: Apply molecule-level augmentations (aromaticity change,
            comment, explicit hydrogens, R-groups, condensed labels,
            functional-group abbreviations, colouring) and regenerate the
            output SMILES from the augmented molecule.
        default_option: Keep the fixed render options instead of randomising
            thickness, font, label mode, hydrogens, stereo style and atom ids.
        shuffle_nodes: Forwarded to ``get_graph``.
        pseudo_coords: Forwarded to ``get_graph``.
        include_condensed: Allow ``add_rand_condensed`` during augmentation.
        debug: Re-raise failures instead of returning a blank image; also
            forwarded to ``add_functional_group``.

    Returns:
        Tuple ``(img, smiles, graph, success)``. On failure ``img`` is a
        10x10 white float32 array, ``graph`` is ``{}`` and ``success`` is
        ``False``; ``smiles`` is whatever value it held when the error occurred.

    Raises:
        Exception: Only when ``debug`` is true; the original error is re-raised.
    """
    indigo = Indigo()
    renderer = IndigoRenderer(indigo)
    indigo.setOption('render-output-format', 'png')
    indigo.setOption('render-background-color', '1,1,1')
    indigo.setOption('render-stereo-style', 'none')
    indigo.setOption('render-label-mode', 'hetero')
    indigo.setOption('render-font-family', 'Arial')
    if not default_option:
        thickness = random.uniform(0.5, 2)
        indigo.setOption('render-relative-thickness', thickness)
        # The upper bound keeps relative thickness + line width at or below 4.
        indigo.setOption('render-bond-line-width', random.uniform(1, 4 - thickness))
        if random.random() < 0.5:
            indigo.setOption('render-font-family', random.choice(['Arial', 'Times', 'Courier', 'Helvetica']))
        indigo.setOption('render-label-mode', random.choice(['hetero', 'terminal-hetero']))
        indigo.setOption('render-implicit-hydrogens-visible', random.choice([True, False]))
        if random.random() < 0.1:
            indigo.setOption('render-stereo-style', 'old')
        if random.random() < 0.2:
            indigo.setOption('render-atom-ids-visible', True)

    try:
        mol = indigo.loadMolecule(smiles)
        if mol_augment:
            if random.random() < INDIGO_DEARMOTIZE_PROB:
                mol.dearomatize()
            else:
                mol.aromatize()
            smiles = mol.canonicalSmiles()
            add_comment(indigo)
            mol = add_explicit_hydrogen(indigo, mol)
            mol = add_rgroup(indigo, mol, smiles)
            if include_condensed:
                mol = add_rand_condensed(indigo, mol)
            mol = add_functional_group(indigo, mol, debug)
            mol = add_color(indigo, mol)
            mol, smiles = generate_output_smiles(indigo, mol)

        buf = renderer.renderToBuffer(mol)
        # Decode the PNG buffer into a 3-channel image array.
        img = cv2.imdecode(np.asarray(bytearray(buf), dtype=np.uint8), 1)
        graph = get_graph(mol, img, shuffle_nodes, pseudo_coords)
        success = True
    except Exception:
        if debug:
            # Bare re-raise keeps the original exception type and traceback
            # (previously a fresh, message-less Exception was raised instead).
            raise
        img = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
        graph = {}
        success = False
    return img, smiles, graph, success
|
|
|
|
class TrainDataset(Dataset):
    """Dataset yielding ``(idx, image, ref)`` samples for molecule images.

    The image is either rendered on the fly from the SMILES string
    (``dynamic_indigo``) or read from ``file_path`` on disk. ``ref`` is a dict
    of label tensors whose keys depend on ``args.formats``.
    """

    def __init__(self, args, df, tokenizer, split='train', dynamic_indigo=False):
        """Store the frame and configuration and build the image transform.

        Args:
            args: Namespace read for ``data_path``, ``formats``, ``input_size``,
                ``augment``, ``coords_file``, ``pseudo_coords`` and, later,
                several per-sample options.
            df: DataFrame with optional ``file_path``, ``SMILES`` and
                ``edges`` columns.
            tokenizer: Mapping from format name to tokenizer.
            split: Labels and augmentation are enabled only for ``'train'``.
            dynamic_indigo: Render images with Indigo (train split only).
        """
        super().__init__()
        self.df = df
        self.args = args
        self.tokenizer = tokenizer
        if 'file_path' in df.columns:
            self.file_paths = df['file_path'].values
            # Check for emptiness first so an empty frame does not crash on [0].
            if len(self.file_paths) > 0 and not self.file_paths[0].startswith(args.data_path):
                self.file_paths = [os.path.join(args.data_path, path) for path in df['file_path']]
        self.smiles = df['SMILES'].values if 'SMILES' in df.columns else None
        self.formats = args.formats
        self.labelled = (split == 'train')
        if self.labelled:
            self.labels = {}
            for format_ in self.formats:
                if format_ in ['atomtok', 'inchi']:
                    field = FORMAT_INFO[format_]['name']
                    if field in df.columns:
                        self.labels[format_] = df[field].values
        self.transform = get_transforms(args.input_size,
                                        augment=(self.labelled and args.augment))
        self.dynamic_indigo = (dynamic_indigo and split == 'train')
        if self.labelled and not dynamic_indigo and args.coords_file is not None:
            if args.coords_file == 'aux_file':
                # Coordinates live in the frame itself and are treated as pseudo coordinates.
                self.coords_df = df
                self.pseudo_coords = True
            else:
                self.coords_df = pd.read_csv(args.coords_file)
                self.pseudo_coords = False
        else:
            self.coords_df = None
            self.pseudo_coords = args.pseudo_coords

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.df)

    def image_transform(self, image, coords=None, renormalize=False):
        """Apply the transform pipeline to an image and optional keypoints.

        Args:
            image: BGR image array; converted to RGB before transforming.
            coords: Keypoints in pixel ``xy``; ``None`` or empty means none.
            renormalize: Rescale transformed keypoints with
                ``normalize_nodes`` instead of dividing by image width/height.

        Returns:
            ``(image, coords)`` with coords clipped to ``[0, 1]`` when
            keypoints were given, otherwise just ``image``.
        """
        if coords is None:
            # Replaces the former mutable ``[]`` default.
            coords = []
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        augmented = self.transform(image=image, keypoints=coords)
        image = augmented['image']
        if len(coords) > 0:
            coords = np.array(augmented['keypoints'])
            if renormalize:
                coords = normalize_nodes(coords, flip_y=False)
            else:
                # assumes image is channel-first (C, H, W) after the transform - TODO confirm
                _, height, width = image.shape
                coords[:, 0] = coords[:, 0] / width
                coords[:, 1] = coords[:, 1] / height
            coords = np.array(coords).clip(0, 1)
            return image, coords
        return image

    def __getitem__(self, idx):
        """Return ``getitem(idx)``, writing the error to a log file on failure.

        Raises:
            Exception: Whatever ``getitem`` raised, re-raised unchanged.
        """
        try:
            return self.getitem(idx)
        except Exception as e:
            log_path = os.path.join(self.args.save_path, f'error_dataset_{int(time.time())}.log')
            try:
                with open(log_path, 'w') as f:
                    f.write(str(e))
            except OSError:
                # Logging is best-effort; it must never mask the real failure.
                pass
            raise

    def getitem(self, idx):
        """Build one sample.

        Returns:
            ``(idx, image, ref)``; ``image`` is ``None`` and ``ref`` is ``{}``
            when dynamic rendering failed.
        """
        ref = {}
        if self.dynamic_indigo:
            begin = time.time()
            image, smiles, graph, success = generate_indigo_image(
                self.smiles[idx], mol_augment=self.args.mol_augment, default_option=self.args.default_option,
                shuffle_nodes=self.args.shuffle_nodes, pseudo_coords=self.pseudo_coords,
                include_condensed=self.args.include_condensed)
            end = time.time()
            if idx < 30 and self.args.save_image:
                # Dump the first few rendered images for visual inspection.
                path = os.path.join(self.args.save_path, 'images')
                os.makedirs(path, exist_ok=True)
                cv2.imwrite(os.path.join(path, f'{idx}.png'), image)
            if not success:
                return idx, None, {}
            image, coords = self.image_transform(image, graph['coords'], renormalize=self.pseudo_coords)
            graph['coords'] = coords
            ref['time'] = end - begin
            if 'atomtok' in self.formats:
                max_len = FORMAT_INFO['atomtok']['max_len']
                label = self.tokenizer['atomtok'].text_to_sequence(smiles, tokenized=False)
                ref['atomtok'] = torch.LongTensor(label[:max_len])
            if 'edges' in self.formats and 'atomtok_coords' not in self.formats and 'chartok_coords' not in self.formats:
                ref['edges'] = torch.tensor(graph['edges'])
            if 'atomtok_coords' in self.formats:
                self._process_atomtok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
                                             mask_ratio=self.args.mask_ratio)
            if 'chartok_coords' in self.formats:
                self._process_chartok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
                                             mask_ratio=self.args.mask_ratio)
            return idx, image, ref
        else:
            file_path = self.file_paths[idx]
            image = cv2.imread(file_path)
            if image is None:
                # Unreadable file: substitute a small white image and keep going.
                image = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
                print(file_path, 'not found!')
            if self.coords_df is not None:
                h, w, _ = image.shape
                # NOTE(security): eval() executes whatever expression the cell holds;
                # only safe for trusted files. Consider ast.literal_eval.
                coords = np.array(eval(self.coords_df.loc[idx, 'node_coords']))
                if self.pseudo_coords:
                    coords = normalize_nodes(coords)
                coords[:, 0] = coords[:, 0] * w
                coords[:, 1] = coords[:, 1] * h
                image, coords = self.image_transform(image, coords, renormalize=self.pseudo_coords)
            else:
                image = self.image_transform(image)
                coords = None
            if self.labelled:
                smiles = self.smiles[idx]
                if 'atomtok' in self.formats:
                    max_len = FORMAT_INFO['atomtok']['max_len']
                    label = self.tokenizer['atomtok'].text_to_sequence(smiles, False)
                    ref['atomtok'] = torch.LongTensor(label[:max_len])
                if 'atomtok_coords' in self.formats:
                    if coords is not None:
                        self._process_atomtok_coords(idx, ref, smiles, coords, mask_ratio=0)
                    else:
                        self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1)
                if 'chartok_coords' in self.formats:
                    if coords is not None:
                        self._process_chartok_coords(idx, ref, smiles, coords, mask_ratio=0)
                    else:
                        self._process_chartok_coords(idx, ref, smiles, mask_ratio=1)
            if self.args.predict_coords and ('atomtok_coords' in self.formats or 'chartok_coords' in self.formats):
                # Overwrites any labels written above with fully masked coordinates.
                smiles = self.smiles[idx]
                if 'atomtok_coords' in self.formats:
                    self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1)
                if 'chartok_coords' in self.formats:
                    self._process_chartok_coords(idx, ref, smiles, mask_ratio=1)
            return idx, image, ref

    def _process_atomtok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
        """Fill ``ref`` with the ``atomtok_coords`` labels for sample ``idx``."""
        self._process_coords_format('atomtok_coords', idx, ref, smiles, coords, edges, mask_ratio)

    def _process_chartok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
        """Fill ``ref`` with the ``chartok_coords`` labels for sample ``idx``."""
        self._process_coords_format('chartok_coords', idx, ref, smiles, coords, edges, mask_ratio)

    def _process_coords_format(self, format_, idx, ref, smiles, coords, edges, mask_ratio):
        """Shared implementation of the two ``*_coords`` label builders.

        Writes ``ref[format_]``, ``ref['atom_indices']``, ``ref['edges']`` and,
        for tokenizers with ``continuous_coords``, ``ref['coords']``.

        Args:
            format_: ``'atomtok_coords'`` or ``'chartok_coords'``.
            idx: Row label used to look up ``edges`` in ``self.df``.
            ref: Output dict, modified in place.
            smiles: SMILES string; any non-string value is treated as ``""``.
            coords: Atom coordinates, or ``None``.
            edges: Square bond matrix, or ``None`` to fall back to the
                ``edges`` column (or a ``-100``-filled matrix).
            mask_ratio: Forwarded to ``tokenizer.smiles_to_sequence``.
        """
        max_len = FORMAT_INFO[format_]['max_len']
        tokenizer = self.tokenizer[format_]
        if not isinstance(smiles, str):
            # Covers None and NaN cells; str subclasses are kept as valid input.
            smiles = ""
        label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio)
        ref[format_] = torch.LongTensor(label[:max_len])
        # Drop atoms whose token position was cut off by the length limit.
        indices = [i for i in indices if i < max_len]
        ref['atom_indices'] = torch.LongTensor(indices)
        num_atoms = len(indices)
        if tokenizer.continuous_coords:
            if coords is not None:
                ref['coords'] = torch.tensor(coords)
            else:
                ref['coords'] = torch.ones(num_atoms, 2) * -1.
        if edges is not None:
            ref['edges'] = torch.tensor(edges)[:num_atoms, :num_atoms]
        elif 'edges' in self.df.columns:
            # NOTE(security): eval() executes whatever expression the cell holds;
            # only safe for trusted files. Consider ast.literal_eval.
            edge_list = eval(self.df.loc[idx, 'edges'])
            ref['edges'] = self._edges_from_list(edge_list, num_atoms)
        else:
            ref['edges'] = torch.ones(num_atoms, num_atoms, dtype=torch.long) * (-100)

    @staticmethod
    def _edges_from_list(edge_list, num_atoms):
        """Convert ``(u, v, type)`` triples into a ``num_atoms`` square long tensor.

        Types up to 4 are written symmetrically; larger types are written as
        ``type`` at ``[u, v]`` and ``11 - type`` at ``[v, u]``. Triples that
        reference an atom index outside the matrix are ignored.
        """
        edges = torch.zeros((num_atoms, num_atoms), dtype=torch.long)
        for u, v, t in edge_list:
            if u < num_atoms and v < num_atoms:
                edges[u, v] = t
                edges[v, u] = t if t <= 4 else 11 - t
        return edges
|
|
|
|
class AuxTrainDataset(Dataset):
    """Concatenation of a main training dataset and an auxiliary one.

    Indices below ``len(train_dataset)`` address the main dataset; the
    remaining indices address the auxiliary dataset.
    """

    def __init__(self, args, train_df, aux_df, tokenizer):
        """Wrap both frames in ``TrainDataset``; only the main one may render dynamically."""
        super().__init__()
        self.train_dataset = TrainDataset(args, train_df, tokenizer, dynamic_indigo=args.dynamic_indigo)
        self.aux_dataset = TrainDataset(args, aux_df, tokenizer, dynamic_indigo=False)

    def __len__(self):
        """Return the combined number of samples."""
        main_size = len(self.train_dataset)
        aux_size = len(self.aux_dataset)
        return main_size + aux_size

    def __getitem__(self, idx):
        """Return the sample at ``idx`` from whichever dataset owns that index."""
        offset = len(self.train_dataset)
        if idx >= offset:
            return self.aux_dataset[idx - offset]
        return self.train_dataset[idx]
|
|
|
|
def pad_images(imgs):
    """Zero-pad image tensors to a common height/width and stack them.

    Padding is added on the right and bottom only, so each image keeps its
    content at the top-left corner.

    Args:
        imgs: Sequence of tensors whose last two dimensions are height and width.

    Returns:
        A single tensor stacking the padded images along a new first dimension.
    """
    max_width = max((img.shape[-1] for img in imgs), default=0)
    max_height = max((img.shape[-2] for img in imgs), default=0)
    padded = []
    for img in imgs:
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        extra_right = max_width - img.shape[-1]
        extra_bottom = max_height - img.shape[-2]
        padded.append(F.pad(img, [0, extra_right, 0, extra_bottom], value=0))
    return torch.stack(padded)
|
|
|
|
def bms_collate(batch):
    """Collate ``(idx, image, ref)`` samples into batched tensors.

    Samples whose image is ``None`` are dropped. The set of label keys is
    taken from the first remaining sample.

    Args:
        batch: List of ``(idx, image, ref)`` tuples.

    Returns:
        Tuple ``(ids, images, refs)``. For each sequence-type key, ``refs[key]``
        is ``[padded_sequences, lengths]`` with ``PAD_ID`` padding and lengths
        shaped ``(batch, 1)``. ``refs['coords']`` is padded with ``-1`` and
        ``refs['edges']`` with ``-100`` when those keys are present.
    """
    batch = [sample for sample in batch if sample[1] is not None]
    formats = list(batch[0][2].keys())
    sequence_keys = ['atomtok', 'inchi', 'nodes', 'atomtok_coords', 'chartok_coords', 'atom_indices']
    seq_formats = [key for key in formats if key in sequence_keys]

    ids = [sample[0] for sample in batch]
    imgs = [sample[1] for sample in batch]

    refs = {}
    for key in seq_formats:
        sequences = [sample[2][key] for sample in batch]
        lengths = torch.LongTensor([len(seq) for seq in sequences]).reshape(-1, 1)
        padded = pad_sequence(sequences, batch_first=True, padding_value=PAD_ID)
        refs[key] = [padded, lengths]

    if 'coords' in formats:
        coords_list = [sample[2]['coords'] for sample in batch]
        refs['coords'] = pad_sequence(coords_list, batch_first=True, padding_value=-1.)

    if 'edges' in formats:
        edge_mats = [sample[2]['edges'] for sample in batch]
        max_len = max(len(mat) for mat in edge_mats)
        padded_edges = []
        for mat in edge_mats:
            extra = max_len - len(mat)
            # Grow each square matrix on the right and bottom to the largest size.
            padded_edges.append(F.pad(mat, (0, extra, 0, extra), value=-100))
        refs['edges'] = torch.stack(padded_edges, dim=0)

    return ids, pad_images(imgs), refs
|
|