{ "cells": [ { "cell_type": "markdown", "id": "362977fe", "metadata": {}, "source": [ "# MoleGen - a generative machine learning model for chemical molecules\n", "\n", "In this post, I will replicate the model and results from the following 2019 paper: [A Two-Step Graph Convolutional Decoder for Molecule Generation](https://arxiv.org/pdf/1906.03412)." ] }, { "cell_type": "markdown", "id": "6aeae7e6", "metadata": {}, "source": [ "## Prereqs" ] }, { "cell_type": "code", "execution_count": null, "id": "66ada5ec", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install torch_geometric torch pandas matplotlib rdkit scikit-learn" ] }, { "cell_type": "markdown", "id": "e2b6a5bd", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "ce431cb6", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from rdkit import Chem\n", "from rdkit.Chem import Draw\n", "from rdkit.Chem import rdmolops\n", "from tqdm.auto import tqdm\n", "from torch_geometric.data import Data\n", "import matplotlib.pyplot as plt\n", "import torch\n", "from torch_geometric.utils import dense_to_sparse\n", "from rdkit.Chem.rdchem import BondType\n", "from torch_geometric.loader import DataLoader\n", "from torch.utils.data import Dataset\n", "from collections import defaultdict\n", "from sklearn.manifold import TSNE\n", "\n", "tqdm.pandas() # enable progress bars in pandas" ] }, { "cell_type": "markdown", "id": "c454d1ef", "metadata": {}, "source": [ "## Preparation - paper read\n", "\n", "We start by reading the paper. On a first pass, we can make the following observations:\n", "\n", "- The authors used the ZINC dataset\n", "- The very first input to the model is a canonical SMILES representation, which needs to be converted into a graph ([see page 13/26 of this paper](https://arxiv.org/pdf/1610.02415) for more info on the encoding process). 
We will also need to extract the position information\n", "- A graph convolutional network (GCN) encoder aggregates information from the neighbors of each node and edge\n", "- A simple MLP then converts the encoder output into a matrix where each row is essentially a one-hot vector over the number of atoms of a given atom type; we pick the count at the index with the maximum value\n", "- To decode, we then start from a fully connected graph over these atoms (I'm assuming the edge types are initialized randomly) and predict the bonds.\n", "\n", "\n", "As a first step, we want to look at our data and build the encoder." ] }, { "cell_type": "markdown", "id": "45ba5638", "metadata": {}, "source": [ "## Dataset exploration\n", "\n", "Next, we will load the ZINC dataset and explore it. We will only load a subset at first.\n", "\n", "Note that we can't use the built-in ZINC dataset from PyTorch Geometric because it drops the SMILES representation, and we don't know the mapping from the integer encoding to atom and edge types. I think this is explained [here](https://pubs.acs.org/doi/full/10.1021/acs.jcim.5b00559) or [here](https://zinc15.docking.org/catalogs/home/), though. In any case, we don't need the entire ZINC dataset, only a small subset.\n", "\n", "We will use the [ZINC-250k dataset](https://www.kaggle.com/datasets/basu369victor/zinc250k), which is available on Kaggle. Note that for debugging we will only load 1000 molecules.\n", "\n", "We will use [RDKit](https://www.rdkit.org/docs/GettingStartedInPython.html) to convert these SMILES strings into canonical form and to extract the connectivity for PyTorch Geometric. But first, we have to load our data; we will use Pandas dataframes for this. Note that most of these libraries are already pre-installed on Google Colab." ] }, { "cell_type": "code", "execution_count": 13, "id": "59d708d1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1000, 4)\n" ] }, { "data": { "text/html": [ "
|   | smiles | logP | qed | SAS |
|---|---|---|---|---|
| 0 | CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1\n | 5.05060 | 0.702012 | 2.084095 |
| 1 | C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1\n | 3.11370 | 0.928975 | 3.432004 |
| 2 | N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)... | 4.96778 | 0.599682 | 2.470633 |
| 3 | CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c... | 4.00022 | 0.690944 | 2.822753 |
| 4 | N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])[C@H](C#... | 3.60956 | 0.789027 | 4.035182 |