| | import argparse |
| | from typing import List |
| |
|
| | import cv2 |
| | import torch |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | from matplotlib.backends.backend_agg import FigureCanvasAgg |
| |
|
| | from .dataset import get_transforms |
| | from .model import Encoder, Decoder |
| | from .chemistry import convert_graph_to_smiles |
| | from .tokenizer import get_tokenizer |
| |
|
| |
|
| | BOND_TYPES = ["", "single", "double", "triple", "aromatic", "solid wedge", "dashed wedge"] |
| |
|
| |
|
| | def safe_load(module, module_states): |
| | def remove_prefix(state_dict): |
| | return {k.replace('module.', ''): v for k, v in state_dict.items()} |
| | missing_keys, unexpected_keys = module.load_state_dict(remove_prefix(module_states), strict=False) |
| | return |
| |
|
| |
|
| | class MolScribe: |
| |
|
| | def __init__(self, model_path, device=None): |
| | """ |
| | MolScribe Interface |
| | :param model_path: path of the model checkpoint. |
| | :param device: torch device, defaults to be CPU. |
| | """ |
| | model_states = torch.load(model_path, map_location=torch.device('cpu')) |
| | args = self._get_args(model_states['args']) |
| | if device is None: |
| | device = torch.device('cpu') |
| | self.device = device |
| | self.tokenizer = get_tokenizer(args) |
| | self.encoder, self.decoder = self._get_model(args, self.tokenizer, self.device, model_states) |
| | self.transform = get_transforms(args.input_size, augment=False) |
| |
|
| | def _get_args(self, args_states=None): |
| | parser = argparse.ArgumentParser() |
| | |
| | parser.add_argument('--encoder', type=str, default='swin_base') |
| | parser.add_argument('--decoder', type=str, default='transformer') |
| | parser.add_argument('--trunc_encoder', action='store_true') |
| | parser.add_argument('--no_pretrained', action='store_true') |
| | parser.add_argument('--use_checkpoint', action='store_true', default=True) |
| | parser.add_argument('--dropout', type=float, default=0.5) |
| | parser.add_argument('--embed_dim', type=int, default=256) |
| | parser.add_argument('--enc_pos_emb', action='store_true') |
| | group = parser.add_argument_group("transformer_options") |
| | group.add_argument("--dec_num_layers", help="No. of layers in transformer decoder", type=int, default=6) |
| | group.add_argument("--dec_hidden_size", help="Decoder hidden size", type=int, default=256) |
| | group.add_argument("--dec_attn_heads", help="Decoder no. of attention heads", type=int, default=8) |
| | group.add_argument("--dec_num_queries", type=int, default=128) |
| | group.add_argument("--hidden_dropout", help="Hidden dropout", type=float, default=0.1) |
| | group.add_argument("--attn_dropout", help="Attention dropout", type=float, default=0.1) |
| | group.add_argument("--max_relative_positions", help="Max relative positions", type=int, default=0) |
| | parser.add_argument('--continuous_coords', action='store_true') |
| | parser.add_argument('--compute_confidence', action='store_true') |
| | |
| | parser.add_argument('--input_size', type=int, default=384) |
| | parser.add_argument('--vocab_file', type=str, default=None) |
| | parser.add_argument('--coord_bins', type=int, default=64) |
| | parser.add_argument('--sep_xy', action='store_true', default=True) |
| |
|
| | args = parser.parse_args([]) |
| | if args_states: |
| | for key, value in args_states.items(): |
| | args.__dict__[key] = value |
| | return args |
| |
|
| | def _get_model(self, args, tokenizer, device, states): |
| | encoder = Encoder(args, pretrained=False) |
| | args.encoder_dim = encoder.n_features |
| | decoder = Decoder(args, tokenizer) |
| |
|
| | safe_load(encoder, states['encoder']) |
| | safe_load(decoder, states['decoder']) |
| | |
| |
|
| | encoder.to(device) |
| | decoder.to(device) |
| | encoder.eval() |
| | decoder.eval() |
| | return encoder, decoder |
| |
|
| | def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=16): |
| | device = self.device |
| | predictions = [] |
| | self.decoder.compute_confidence = return_confidence |
| |
|
| | for idx in range(0, len(input_images), batch_size): |
| | batch_images = input_images[idx:idx+batch_size] |
| | images = [self.transform(image=image, keypoints=[])['image'] for image in batch_images] |
| | images = torch.stack(images, dim=0).to(device) |
| | with torch.no_grad(): |
| | features, hiddens = self.encoder(images) |
| | batch_predictions = self.decoder.decode(features, hiddens) |
| | predictions += batch_predictions |
| |
|
| | return self.convert_graph_to_output(predictions, input_images, return_confidence, return_atoms_bonds) |
| |
|
| |
|
| | def convert_graph_to_output(self, predictions, input_images, return_confidence=True, return_atoms_bonds=True): |
| | node_coords = [pred['chartok_coords']['coords'] for pred in predictions] |
| | node_symbols = [pred['chartok_coords']['symbols'] for pred in predictions] |
| | edges = [pred['edges'] for pred in predictions] |
| | |
| | smiles_list, molblock_list, r_success = convert_graph_to_smiles( |
| | node_coords, node_symbols, edges, images=input_images) |
| |
|
| | outputs = [] |
| | for smiles, molblock, pred in zip(smiles_list, molblock_list, predictions): |
| | pred_dict = {"smiles": smiles, "molfile": molblock, "oringinal_coords": pred['chartok_coords']['coords'], "original_symbols": pred['chartok_coords']['symbols'], "orignal_edges": pred['edges']} |
| | if return_confidence: |
| | pred_dict["confidence"] = pred["overall_score"] |
| | if return_atoms_bonds: |
| | coords = pred['chartok_coords']['coords'] |
| | symbols = pred['chartok_coords']['symbols'] |
| | |
| | |
| | |
| | atom_list = [] |
| | for i, (symbol, coord) in enumerate(zip(symbols, coords)): |
| | atom_dict = {"atom_symbol": symbol, "x": round(coord[0],3), "y": round(coord[1],3)} |
| | if return_confidence: |
| | atom_dict["confidence"] = pred['chartok_coords']['atom_scores'][i] |
| | atom_list.append(atom_dict) |
| | pred_dict["atoms"] = atom_list |
| | |
| | bond_list = [] |
| | num_atoms = len(symbols) |
| | for i in range(num_atoms-1): |
| | for j in range(i+1, num_atoms): |
| | bond_type_int = pred['edges'][i][j] |
| | if bond_type_int != 0: |
| | bond_type_str = BOND_TYPES[bond_type_int] |
| | bond_dict = {"bond_type": bond_type_str, "endpoint_atoms": (i, j)} |
| | if return_confidence: |
| | bond_dict["confidence"] = pred["edge_scores"][i][j] |
| | bond_list.append(bond_dict) |
| | pred_dict["bonds"] = bond_list |
| | outputs.append(pred_dict) |
| | return outputs |
| |
|
| | def predict_image(self, image, return_atoms_bonds=False, return_confidence=False): |
| | return self.predict_images([ |
| | image], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0] |
| |
|
| | def predict_image_files(self, image_files: List, return_atoms_bonds=False, return_confidence=False): |
| | input_images = [] |
| | for path in image_files: |
| | image = cv2.imread(path) |
| | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
| | input_images.append(image) |
| | return self.predict_images( |
| | input_images, return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence) |
| |
|
| | def predict_image_file(self, image_file: str, return_atoms_bonds=False, return_confidence=False): |
| | return self.predict_image_files( |
| | [image_file], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0] |
| |
|
| | def draw_prediction(self, prediction, image, notebook=False): |
| | if "atoms" not in prediction or "bonds" not in prediction: |
| | raise ValueError("atoms and bonds information are not provided.") |
| | h, w, _ = image.shape |
| | h, w = np.array([h, w]) * 400 / max(h, w) |
| | image = cv2.resize(image, (int(w), int(h))) |
| | fig, ax = plt.subplots(1, 1) |
| | ax.axis('off') |
| | ax.set_xlim(-0.05 * w, w * 1.05) |
| | ax.set_ylim(1.05 * h, -0.05 * h) |
| | plt.imshow(image, alpha=0.) |
| | x = [a['x'] * w for a in prediction['atoms']] |
| | y = [a['y'] * h for a in prediction['atoms']] |
| | markersize = min(w, h) / 3 |
| | plt.scatter(x, y, marker='o', s=markersize, color='lightskyblue', zorder=10) |
| | for i, atom in enumerate(prediction['atoms']): |
| | symbol = atom['atom_symbol'].lstrip('[').rstrip(']') |
| | plt.annotate(symbol, xy=(x[i], y[i]), ha='center', va='center', color='black', zorder=100) |
| | for bond in prediction['bonds']: |
| | u, v = bond['endpoint_atoms'] |
| | x1, y1, x2, y2 = x[u], y[u], x[v], y[v] |
| | bond_type = bond['bond_type'] |
| | if bond_type == 'single': |
| | color = 'tab:green' |
| | ax.plot([x1, x2], [y1, y2], color, linewidth=4) |
| | elif bond_type == 'aromatic': |
| | color = 'tab:purple' |
| | ax.plot([x1, x2], [y1, y2], color, linewidth=4) |
| | elif bond_type == 'double': |
| | color = 'tab:green' |
| | ax.plot([x1, x2], [y1, y2], color=color, linewidth=7) |
| | ax.plot([x1, x2], [y1, y2], color='w', linewidth=1.5, zorder=2.1) |
| | elif bond_type == 'triple': |
| | color = 'tab:green' |
| | x1s, x2s = 0.8 * x1 + 0.2 * x2, 0.2 * x1 + 0.8 * x2 |
| | y1s, y2s = 0.8 * y1 + 0.2 * y2, 0.2 * y1 + 0.8 * y2 |
| | ax.plot([x1s, x2s], [y1s, y2s], color=color, linewidth=9) |
| | ax.plot([x1, x2], [y1, y2], color='w', linewidth=5, zorder=2.05) |
| | ax.plot([x1, x2], [y1, y2], color=color, linewidth=2, zorder=2.1) |
| | else: |
| | length = 10 |
| | width = 10 |
| | color = 'tab:green' |
| | if bond_type == 'solid wedge': |
| | ax.annotate('', xy=(x1, y1), xytext=(x2, y2), |
| | arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2) |
| | else: |
| | ax.annotate('', xy=(x2, y2), xytext=(x1, y1), |
| | arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2) |
| | fig.tight_layout() |
| | if not notebook: |
| | canvas = FigureCanvasAgg(fig) |
| | canvas.draw() |
| | buf = canvas.buffer_rgba() |
| | result_image = np.asarray(buf) |
| | plt.close(fig) |
| | return result_image |
| |
|