| """ | |
| Title: Message-passing neural network (MPNN) for molecular property prediction | |
| Author: [akensert](http://github.com/akensert) | |
| Date created: 2021/08/16 | |
| Last modified: 2021/12/27 | |
| Description: Implementation of an MPNN to predict blood-brain barrier permeability. | |
| Accelerator: GPU | |
| """ | |
| """ | |
| ## Introduction | |
| In this tutorial, we will implement a type of graph neural network (GNN) known as | |
| _ message passing neural network_ (MPNN) to predict graph properties. Specifically, we will | |
| implement an MPNN to predict a molecular property known as | |
| _blood-brain barrier permeability_ (BBBP). | |
| Motivation: as molecules are naturally represented as an undirected graph `G = (V, E)`, | |
| where `V` is a set or vertices (nodes; atoms) and `E` a set of edges (bonds), GNNs (such | |
| as MPNN) are proving to be a useful method for predicting molecular properties. | |
| Until now, more traditional methods, such as random forests, support vector machines, etc., | |
| have been commonly used to predict molecular properties. In contrast to GNNs, these | |
| traditional approaches often operate on precomputed molecular features such as | |
| molecular weight, polarity, charge, number of carbon atoms, etc. Although these | |
| molecular features prove to be good predictors for various molecular properties, it is | |
| hypothesized that operating on these more "raw", "low-level", features could prove even | |
| better. | |
### References

In recent years, a lot of effort has been put into developing neural networks for
graph data, including molecular graphs. For a summary of graph neural networks, see e.g.,
[A Comprehensive Survey on Graph Neural Networks](https://arxiv.org/abs/1901.00596) and
[Graph Neural Networks: A Review of Methods and Applications](https://arxiv.org/abs/1812.08434);
and for further reading on the specific
graph neural network implemented in this tutorial see
[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) and
[DeepChem's MPNNModel](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#mpnnmodel).
"""
| """ | |
| ## Setup | |
| ### Install RDKit and other dependencies | |
| (Text below taken from | |
| [this tutorial](https://keras.io/examples/generative/wgan-graphs/)). | |
| [RDKit](https://www.rdkit.org/) is a collection of cheminformatics and machine-learning | |
| software written in C++ and Python. In this tutorial, RDKit is used to conveniently and | |
| efficiently transform | |
| [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) to | |
| molecule objects, and then from those obtain sets of atoms and bonds. | |
| SMILES expresses the structure of a given molecule in the form of an ASCII string. | |
| The SMILES string is a compact encoding which, for smaller molecules, is relatively | |
| human-readable. Encoding molecules as a string both alleviates and facilitates database | |
| and/or web searching of a given molecule. RDKit uses algorithms to | |
| accurately transform a given SMILES to a molecule object, which can then | |
| be used to compute a great number of molecular properties/features. | |
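For example (a minimal sketch; the ethanol SMILES `"CCO"` below is just an illustrative
input), converting a SMILES to a molecule object and back looks as follows:

```
from rdkit import Chem

molecule = Chem.MolFromSmiles("CCO")  # ethanol
print(molecule.GetNumAtoms())         # 3 (heavy atoms)
print(Chem.MolToSmiles(molecule))     # "CCO"
```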
Note that RDKit is commonly installed via [Conda](https://www.rdkit.org/docs/Install.html).
However, thanks to
[rdkit_platform_wheels](https://github.com/kuelumbus/rdkit_platform_wheels), rdkit
can now (for the sake of this tutorial) be installed easily via pip, as follows:

```
pip -q install rdkit-pypi
```

And for easy and efficient reading of csv files and visualization, the below needs to be
installed:

```
pip -q install pandas
pip -q install Pillow
pip -q install matplotlib
pip -q install pydot
sudo apt-get -qq install graphviz
```
"""
| """ | |
| ### Import packages | |
| """ | |
import os

# Temporarily suppress TensorFlow logs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import MolsToGridImage

# Temporarily suppress warnings and RDKit logs
warnings.filterwarnings("ignore")
RDLogger.DisableLog("rdApp.*")

np.random.seed(42)
tf.random.set_seed(42)
| """ | |
| ## Dataset | |
| Information about the dataset can be found in | |
| [A Bayesian Approach to in Silico Blood-Brain Barrier Penetration Modeling](https://pubs.acs.org/doi/10.1021/ci300124c) | |
| and [MoleculeNet: A Benchmark for Molecular Machine Learning](https://arxiv.org/abs/1703.00564). | |
| The dataset will be downloaded from [MoleculeNet.org](https://moleculenet.org/datasets-1). | |
| ### About | |
| The dataset contains **2,050** molecules. Each molecule come with a **name**, **label** | |
| and **SMILES** string. | |
| The blood-brain barrier (BBB) is a membrane separating the blood from the brain | |
| extracellular fluid, hence blocking out most drugs (molecules) from reaching | |
| the brain. Because of this, the BBBP has been important to study for the development of | |
| new drugs that aim to target the central nervous system. The labels for this | |
| data set are binary (1 or 0) and indicate the permeability of the molecules. | |
| """ | |
csv_path = keras.utils.get_file(
    "BBBP.csv", "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
)

df = pd.read_csv(csv_path, usecols=[1, 2, 3])
df.iloc[96:104]
| """ | |
| ### Define features | |
| To encode features for atoms and bonds (which we will need later), | |
| we'll define two classes: `AtomFeaturizer` and `BondFeaturizer` respectively. | |
| To reduce the lines of code, i.e., to keep this tutorial short and concise, | |
| only about a handful of (atom and bond) features will be considered: \[atom features\] | |
| [symbol (element)](https://en.wikipedia.org/wiki/Chemical_element), | |
| [number of valence electrons](https://en.wikipedia.org/wiki/Valence_electron), | |
| [number of hydrogen bonds](https://en.wikipedia.org/wiki/Hydrogen), | |
| [orbital hybridization](https://en.wikipedia.org/wiki/Orbital_hybridisation), | |
| \[bond features\] | |
| [(covalent) bond type](https://en.wikipedia.org/wiki/Covalent_bond), and | |
| [conjugation](https://en.wikipedia.org/wiki/Conjugated_system). | |
| """ | |
class Featurizer:
    def __init__(self, allowable_sets):
        self.dim = 0
        self.features_mapping = {}
        for k, s in allowable_sets.items():
            s = sorted(list(s))
            self.features_mapping[k] = dict(zip(s, range(self.dim, len(s) + self.dim)))
            self.dim += len(s)

    def encode(self, inputs):
        output = np.zeros((self.dim,))
        for name_feature, feature_mapping in self.features_mapping.items():
            feature = getattr(self, name_feature)(inputs)
            if feature not in feature_mapping:
                continue
            output[feature_mapping[feature]] = 1.0
        return output
class AtomFeaturizer(Featurizer):
    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)

    def symbol(self, atom):
        return atom.GetSymbol()

    def n_valence(self, atom):
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        return atom.GetHybridization().name.lower()
class BondFeaturizer(Featurizer):
    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)
        self.dim += 1

    def encode(self, bond):
        output = np.zeros((self.dim,))
        if bond is None:
            output[-1] = 1.0
            return output
        output = super().encode(bond)
        return output

    def bond_type(self, bond):
        return bond.GetBondType().name.lower()

    def conjugated(self, bond):
        return bond.GetIsConjugated()
atom_featurizer = AtomFeaturizer(
    allowable_sets={
        "symbol": {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
        "n_valence": {0, 1, 2, 3, 4, 5, 6},
        "n_hydrogens": {0, 1, 2, 3, 4},
        "hybridization": {"s", "sp", "sp2", "sp3"},
    }
)

bond_featurizer = BondFeaturizer(
    allowable_sets={
        "bond_type": {"single", "double", "triple", "aromatic"},
        "conjugated": {True, False},
    }
)
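"""
As a quick sanity check (a minimal sketch; the benzene SMILES below is just an
illustrative input), we can encode a single atom and bond of benzene. Each encoding
is a multi-hot vector with at most one active slot per feature group:
"""

benzene = Chem.MolFromSmiles("c1ccccc1")
print("atom encoding:", atom_featurizer.encode(benzene.GetAtomWithIdx(0)))
print("bond encoding:", bond_featurizer.encode(benzene.GetBondWithIdx(0)))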
| """ | |
| ### Generate graphs | |
| Before we can generate complete graphs from SMILES, we need to implement the following functions: | |
| 1. `molecule_from_smiles`, which takes as input a SMILES and returns a molecule object. | |
| This is all handled by RDKit. | |
| 2. `graph_from_molecule`, which takes as input a molecule object and returns a graph, | |
| represented as a three-tuple (atom_features, bond_features, pair_indices). For this we | |
| will make use of the classes defined previously. | |
| Finally, we can now implement the function `graphs_from_smiles`, which applies function (1) | |
| and subsequently (2) on all SMILES of the training, validation and test datasets. | |
| Notice: although scaffold splitting is recommended for this data set (see | |
| [here](https://arxiv.org/abs/1703.00564)), for simplicity, simple random splittings were | |
| performed. | |
| """ | |
def molecule_from_smiles(smiles):
    # MolFromSmiles(m, sanitize=True) should be equivalent to
    # MolFromSmiles(m, sanitize=False) -> SanitizeMol(m) -> AssignStereochemistry(m, ...)
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)

    # If sanitization is unsuccessful, catch the error, and try again without
    # the sanitization step that caused the error
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule
def graph_from_molecule(molecule):
    # Initialize graph
    atom_features = []
    bond_features = []
    pair_indices = []

    for atom in molecule.GetAtoms():
        atom_features.append(atom_featurizer.encode(atom))

        # Add self-loops
        pair_indices.append([atom.GetIdx(), atom.GetIdx()])
        bond_features.append(bond_featurizer.encode(None))

        for neighbor in atom.GetNeighbors():
            bond = molecule.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx())
            pair_indices.append([atom.GetIdx(), neighbor.GetIdx()])
            bond_features.append(bond_featurizer.encode(bond))

    return np.array(atom_features), np.array(bond_features), np.array(pair_indices)
def graphs_from_smiles(smiles_list):
    # Initialize graphs
    atom_features_list = []
    bond_features_list = []
    pair_indices_list = []

    for smiles in smiles_list:
        molecule = molecule_from_smiles(smiles)
        atom_features, bond_features, pair_indices = graph_from_molecule(molecule)

        atom_features_list.append(atom_features)
        bond_features_list.append(bond_features)
        pair_indices_list.append(pair_indices)

    # Convert lists to ragged tensors for tf.data.Dataset later on
    return (
        tf.ragged.constant(atom_features_list, dtype=tf.float32),
        tf.ragged.constant(bond_features_list, dtype=tf.float32),
        tf.ragged.constant(pair_indices_list, dtype=tf.int64),
    )
# Shuffle array of indices ranging from 0 to 2049
permuted_indices = np.random.permutation(np.arange(df.shape[0]))

# Train set: 80% of data
train_index = permuted_indices[: int(df.shape[0] * 0.8)]
x_train = graphs_from_smiles(df.iloc[train_index].smiles)
y_train = df.iloc[train_index].p_np

# Valid set: 19% of data
valid_index = permuted_indices[int(df.shape[0] * 0.8) : int(df.shape[0] * 0.99)]
x_valid = graphs_from_smiles(df.iloc[valid_index].smiles)
y_valid = df.iloc[valid_index].p_np

# Test set: 1% of data
test_index = permuted_indices[int(df.shape[0] * 0.99) :]
x_test = graphs_from_smiles(df.iloc[test_index].smiles)
y_test = df.iloc[test_index].p_np
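"""
As a small sanity check of the split (a minimal sketch), we can print the number of
molecules in each subset:
"""

print(f"Train/valid/test sizes: {len(train_index)}/{len(valid_index)}/{len(test_index)}")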
| """ | |
| ### Test the functions | |
| """ | |
| print(f"Name:\t{df.name[100]}\nSMILES:\t{df.smiles[100]}\nBBBP:\t{df.p_np[100]}") | |
| molecule = molecule_from_smiles(df.iloc[100].smiles) | |
| print("Molecule:") | |
| molecule | |
| """ | |
| """ | |
| graph = graph_from_molecule(molecule) | |
| print("Graph (including self-loops):") | |
| print("\tatom features\t", graph[0].shape) | |
| print("\tbond features\t", graph[1].shape) | |
| print("\tpair indices\t", graph[2].shape) | |
| """ | |
| ### Create a `tf.data.Dataset` | |
| In this tutorial, the MPNN implementation will take as input (per iteration) a single graph. | |
| Therefore, given a batch of (sub)graphs (molecules), we need to merge them into a | |
| single graph (we'll refer to this graph as *global graph*). | |
| This global graph is a disconnected graph where each subgraph is | |
| completely separated from the other subgraphs. | |
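As a toy illustration of the merging (with made-up numbers, independent of the code
below): for a batch of two molecules with 3 and 2 atoms respectively, the pair indices
of the second molecule are shifted by 3 (the number of atoms preceding it) so that they
index into the merged node list:

```
num_atoms = [3, 2]
increment = cumsum(num_atoms[:-1])   # [3]
# pair indices of molecule 1, e.g. [[0, 1], [1, 0]],
# become [[3, 4], [4, 3]] in the global graph
```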
| """ | |
| def prepare_batch(x_batch, y_batch): | |
| """Merges (sub)graphs of batch into a single global (disconnected) graph""" | |
| atom_features, bond_features, pair_indices = x_batch | |
| # Obtain number of atoms and bonds for each graph (molecule) | |
| num_atoms = atom_features.row_lengths() | |
| num_bonds = bond_features.row_lengths() | |
| # Obtain partition indices (molecule_indicator), which will be used to | |
| # gather (sub)graphs from global graph in model later on | |
| molecule_indices = tf.range(len(num_atoms)) | |
| molecule_indicator = tf.repeat(molecule_indices, num_atoms) | |
| # Merge (sub)graphs into a global (disconnected) graph. Adding 'increment' to | |
| # 'pair_indices' (and merging ragged tensors) actualizes the global graph | |
| gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:]) | |
| increment = tf.cumsum(num_atoms[:-1]) | |
| increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)]) | |
| pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor() | |
| pair_indices = pair_indices + increment[:, tf.newaxis] | |
| atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor() | |
| bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor() | |
| return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch | |
| def MPNNDataset(X, y, batch_size=32, shuffle=False): | |
| dataset = tf.data.Dataset.from_tensor_slices((X, (y))) | |
| if shuffle: | |
| dataset = dataset.shuffle(1024) | |
| return dataset.batch(batch_size).map(prepare_batch, -1).prefetch(-1) | |
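"""
A quick smoke test of the pipeline (a minimal sketch using the training data from
above): take one batch and inspect the merged tensor shapes.
"""

for (atom_features, bond_features, pair_indices, molecule_indicator), y in MPNNDataset(
    x_train, y_train
).take(1):
    print("atom features:", atom_features.shape)
    print("bond features:", bond_features.shape)
    print("pair indices:", pair_indices.shape)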
| """ | |
| ## Model | |
| The MPNN model can take on various shapes and forms. In this tutorial, we will implement an | |
| MPNN based on the original paper | |
| [Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) and | |
| [DeepChem's MPNNModel](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#mpnnmodel). | |
| The MPNN of this tutorial consists of three stages: message passing, readout and | |
| classification. | |
| ### Message passing | |
| The message passing step itself consists of two parts: | |
| 1. The *edge network*, which passes messages from 1-hop neighbors `w_{i}` of `v` | |
| to `v`, based on the edge features between them (`e_{vw_{i}}`), | |
| resulting in an updated node (state) `v'`. `w_{i}` denotes the `i:th` neighbor of | |
| `v`. | |
| 2. The *gated recurrent unit* (GRU), which takes as input the most recent node state | |
| and updates it based on previous node states. In | |
| other words, the most recent node state serves as the input to the GRU, while the previous | |
| node states are incorporated within the memory state of the GRU. This allows information | |
| to travel from one node state (e.g., `v`) to another (e.g., `v''`). | |
| Importantly, step (1) and (2) are repeated for `k steps`, and where at each step `1...k`, | |
| the radius (or number of hops) of aggregated information from `v` increases by 1. | |
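Schematically, following the
[original paper](https://arxiv.org/abs/1704.01212), one step of message passing
updates each node state `h_v` as follows (a sketch, not runnable code):

```
m_v  = sum over neighbors w of A(e_vw) * h_w    # edge network (message function)
h_v' = GRU(m_v, h_v)                            # update function
```

where `A(e_vw)` is a matrix obtained from a learned linear transformation of the bond
features `e_vw`.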
| """ | |
class EdgeNetwork(layers.Layer):
    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.bond_dim = input_shape[1][-1]
        self.kernel = self.add_weight(
            shape=(self.bond_dim, self.atom_dim * self.atom_dim),
            initializer="glorot_uniform",
            name="kernel",
        )
        self.bias = self.add_weight(
            shape=(self.atom_dim * self.atom_dim,),
            initializer="zeros",
            name="bias",
        )
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs

        # Apply linear transformation to bond features
        bond_features = tf.matmul(bond_features, self.kernel) + self.bias

        # Reshape for neighborhood aggregation later
        bond_features = tf.reshape(bond_features, (-1, self.atom_dim, self.atom_dim))

        # Obtain atom features of neighbors
        atom_features_neighbors = tf.gather(atom_features, pair_indices[:, 1])
        atom_features_neighbors = tf.expand_dims(atom_features_neighbors, axis=-1)

        # Apply neighborhood aggregation
        transformed_features = tf.matmul(bond_features, atom_features_neighbors)
        transformed_features = tf.squeeze(transformed_features, axis=-1)
        aggregated_features = tf.math.unsorted_segment_sum(
            transformed_features,
            pair_indices[:, 0],
            num_segments=tf.shape(atom_features)[0],
        )
        return aggregated_features
class MessagePassing(layers.Layer):
    def __init__(self, units, steps=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.steps = steps

    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.message_step = EdgeNetwork()
        self.pad_length = max(0, self.units - self.atom_dim)
        self.update_step = layers.GRUCell(self.atom_dim + self.pad_length)
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs

        # Pad atom features if number of desired units exceeds atom_features dim.
        # Alternatively, a dense layer could be used here.
        atom_features_updated = tf.pad(atom_features, [(0, 0), (0, self.pad_length)])

        # Perform a number of steps of message passing
        for i in range(self.steps):
            # Aggregate information from neighbors
            atom_features_aggregated = self.message_step(
                [atom_features_updated, bond_features, pair_indices]
            )

            # Update node state via a step of GRU
            atom_features_updated, _ = self.update_step(
                atom_features_aggregated, atom_features_updated
            )
        return atom_features_updated
| """ | |
| ### Readout | |
| When the message passing procedure ends, the k-step-aggregated node states are to be partitioned | |
| into subgraphs (corresponding to each molecule in the batch) and subsequently | |
| reduced to graph-level embeddings. In the | |
| [original paper](https://arxiv.org/abs/1704.01212), a | |
| [set-to-set layer](https://arxiv.org/abs/1511.06391) was used for this purpose. | |
| In this tutorial however, a transformer encoder + average pooling will be used. Specifically: | |
| * the k-step-aggregated node states will be partitioned into the subgraphs | |
| (corresponding to each molecule in the batch); | |
| * each subgraph will then be padded to match the subgraph with the greatest number of nodes, followed | |
| by a `tf.stack(...)`; | |
| * the (stacked padded) tensor, encoding subgraphs (each subgraph containing a set of node states), are | |
| masked to make sure the paddings don't interfere with training; | |
| * finally, the tensor is passed to the transformer followed by average pooling. | |
| """ | |
class PartitionPadding(layers.Layer):
    def __init__(self, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, inputs):
        atom_features, molecule_indicator = inputs

        # Obtain subgraphs
        atom_features_partitioned = tf.dynamic_partition(
            atom_features, molecule_indicator, self.batch_size
        )

        # Pad and stack subgraphs
        num_atoms = [tf.shape(f)[0] for f in atom_features_partitioned]
        max_num_atoms = tf.reduce_max(num_atoms)
        atom_features_stacked = tf.stack(
            [
                tf.pad(f, [(0, max_num_atoms - n), (0, 0)])
                for f, n in zip(atom_features_partitioned, num_atoms)
            ],
            axis=0,
        )

        # Remove empty subgraphs (usually for last batch in dataset)
        gather_indices = tf.where(tf.reduce_sum(atom_features_stacked, (1, 2)) != 0)
        gather_indices = tf.squeeze(gather_indices, axis=-1)
        return tf.gather(atom_features_stacked, gather_indices, axis=0)
class TransformerEncoderReadout(layers.Layer):
    def __init__(
        self, num_heads=8, embed_dim=64, dense_dim=512, batch_size=32, **kwargs
    ):
        super().__init__(**kwargs)

        self.partition_padding = PartitionPadding(batch_size)
        self.attention = layers.MultiHeadAttention(num_heads, embed_dim)
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.average_pooling = layers.GlobalAveragePooling1D()

    def call(self, inputs):
        x = self.partition_padding(inputs)
        padding_mask = tf.reduce_any(tf.not_equal(x, 0.0), axis=-1)
        padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]
        attention_output = self.attention(x, x, attention_mask=padding_mask)
        proj_input = self.layernorm_1(x + attention_output)
        proj_output = self.layernorm_2(proj_input + self.dense_proj(proj_input))
        return self.average_pooling(proj_output)
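"""
A quick smoke test of the readout (a minimal sketch with random node states; the
numbers below are made up): five node states of dimension 64 belonging to two
molecules should reduce to two graph-level embeddings of dimension 64.
"""

readout = TransformerEncoderReadout(batch_size=2)
node_states = tf.random.normal((5, 64))
node_to_molecule = tf.constant([0, 0, 0, 1, 1])
print(readout([node_states, node_to_molecule]).shape)  # (2, 64)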
| """ | |
| ### Message Passing Neural Network (MPNN) | |
| It is now time to complete the MPNN model. In addition to the message passing | |
| and readout, a two-layer classification network will be implemented to make | |
| predictions of BBBP. | |
| """ | |
def MPNNModel(
    atom_dim,
    bond_dim,
    batch_size=32,
    message_units=64,
    message_steps=4,
    num_attention_heads=8,
    dense_units=512,
):
    atom_features = layers.Input((atom_dim,), dtype="float32", name="atom_features")
    bond_features = layers.Input((bond_dim,), dtype="float32", name="bond_features")
    pair_indices = layers.Input((2,), dtype="int32", name="pair_indices")
    molecule_indicator = layers.Input((), dtype="int32", name="molecule_indicator")

    x = MessagePassing(message_units, message_steps)(
        [atom_features, bond_features, pair_indices]
    )

    x = TransformerEncoderReadout(
        num_attention_heads, message_units, dense_units, batch_size
    )([x, molecule_indicator])

    x = layers.Dense(dense_units, activation="relu")(x)
    x = layers.Dense(1, activation="sigmoid")(x)

    model = keras.Model(
        inputs=[atom_features, bond_features, pair_indices, molecule_indicator],
        outputs=[x],
    )
    return model
mpnn = MPNNModel(
    atom_dim=x_train[0][0][0].shape[0],
    bond_dim=x_train[1][0][0].shape[0],
)

mpnn.compile(
    loss=keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=5e-4),
    metrics=[keras.metrics.AUC(name="AUC")],
)

keras.utils.plot_model(mpnn, show_dtype=True, show_shapes=True)
| """ | |
| ### Training | |
| """ | |
| train_dataset = MPNNDataset(x_train, y_train) | |
| valid_dataset = MPNNDataset(x_valid, y_valid) | |
| test_dataset = MPNNDataset(x_test, y_test) | |
| history = mpnn.fit( | |
| train_dataset, | |
| validation_data=valid_dataset, | |
| epochs=40, | |
| verbose=2, | |
| class_weight={0: 2.0, 1: 0.5}, | |
| ) | |
plt.figure(figsize=(10, 6))
plt.plot(history.history["AUC"], label="train AUC")
plt.plot(history.history["val_AUC"], label="valid AUC")
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("AUC", fontsize=16)
plt.legend(fontsize=16)
| """ | |
| ### Predicting | |
| """ | |
| molecules = [molecule_from_smiles(df.smiles.values[index]) for index in test_index] | |
| y_true = [df.p_np.values[index] for index in test_index] | |
| y_pred = tf.squeeze(mpnn.predict(test_dataset), axis=1) | |
| legends = [f"y_true/y_pred = {y_true[i]}/{y_pred[i]:.2f}" for i in range(len(y_true))] | |
| MolsToGridImage(molecules, molsPerRow=4, legends=legends) | |
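"""
As a complementary check (a minimal sketch), the loss and AUC on the test set can also
be computed directly with the compiled metrics:
"""

test_loss, test_auc = mpnn.evaluate(test_dataset, verbose=0)
print(f"Test AUC: {test_auc:.3f}")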
| """ | |
| ## Conclusions | |
| In this tutorial, we demonstrated a message passing neural network (MPNN) to | |
| predict blood-brain barrier permeability (BBBP) for a number of different molecules. We | |
| first had to construct graphs from SMILES, then build a Keras model that could | |
| operate on these graphs, and finally train the model to make the predictions. | |
| Example available on HuggingFace | |
| | Trained Model | Demo | | |
| | :--: | :--: | | |
| | [](https://huggingface.co/keras-io/MPNN-for-molecular-property-prediction) | [](https://huggingface.co/spaces/keras-io/molecular-property-prediction) | | |
| """ | |