import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Union, Optional, Callable, Dict

from deepchem.models.losses import Loss, L2Loss
from deepchem.models.torch_models.torch_model import TorchModel
from deepchem.models.optimizers import Optimizer, LearningRateSchedule

from openpom.layers.pom_ffn import CustomPositionwiseFeedForward
from openpom.utils.loss import CustomMultiLabelLoss
from openpom.utils.optimizer import get_optimizer

try:
    import dgl
    from dgl import DGLGraph
    from dgl.nn.pytorch import Set2Set
    from openpom.layers.pom_mpnn_gnn import CustomMPNNGNN
except (ImportError, ModuleNotFoundError) as e:
    # Chain the original error so the actually-missing package is visible.
    raise ImportError('This module requires dgl and dgllife') from e


class MPNNPOM(nn.Module):
    """
    MPNN model that computes a principal odor map using
    multilabel-classification, based on the pre-print:
    "A Principal Odor Map Unifies Diverse Tasks in Human Olfactory
    Perception" [1]

    This model proceeds as follows:

    * Combine latest node representations and edge features in
      updating node representations, which involves multiple
      rounds of message passing.
    * For each graph, compute its representation by radius 0 combination
      to fold atom and bond embeddings together, followed by
      'set2set' or 'global_sum_pooling' readout.
    * Perform the final prediction using a feed-forward layer.

    References
    ----------
    .. [1] Brian K. Lee, Emily J. Mayhew, Benjamin Sanchez-Lengeling,
        Jennifer N. Wei, Wesley W. Qian, Kelsie Little, Matthew Andres,
        Britney B. Nguyen, Theresa Moloy, Jane K. Parker, Richard C. Gerkin,
        Joel D. Mainland, Alexander B. Wiltschko.
        "A Principal Odor Map Unifies Diverse Tasks in Human Olfactory
        Perception" (preprint).
    .. [2] Benjamin Sanchez-Lengeling, Jennifer N. Wei, Brian K. Lee,
        Richard C. Gerkin, Alán Aspuru-Guzik, Alexander B. Wiltschko.
        "Machine Learning for Scent: Learning Generalizable Perceptual
        Representations of Small Molecules".
    .. [3] Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley,
        Oriol Vinyals, George E. Dahl.
        "Neural Message Passing for Quantum Chemistry." ICML 2017.

    Notes
    -----
    This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci
    (https://github.com/awslabs/dgl-lifesci) to be installed.
    """

    def __init__(self,
                 n_tasks: int,
                 node_out_feats: int = 64,
                 edge_hidden_feats: int = 128,
                 edge_out_feats: int = 64,
                 num_step_message_passing: int = 3,
                 mpnn_residual: bool = True,
                 message_aggregator_type: str = 'sum',
                 mode: str = 'classification',
                 number_atom_features: int = 134,
                 number_bond_features: int = 6,
                 n_classes: int = 1,
                 nfeat_name: str = 'x',
                 efeat_name: str = 'edge_attr',
                 readout_type: str = 'set2set',
                 num_step_set2set: int = 6,
                 num_layer_set2set: int = 3,
                 ffn_hidden_list: Optional[List[int]] = None,
                 ffn_embeddings: Optional[int] = 256,
                 ffn_activation: str = 'relu',
                 ffn_dropout_p: float = 0.0,
                 ffn_dropout_at_input_no_act: bool = True):
        """
        Parameters
        ----------
        n_tasks: int
            Number of tasks.
        node_out_feats: int
            The length of the final node representation vectors
            before readout. Default to 64.
        edge_hidden_feats: int
            The length of the hidden edge representation vectors
            for the mpnn edge network. Default to 128.
        edge_out_feats: int
            The length of the final edge representation vectors
            before readout. Default to 64.
        num_step_message_passing: int
            The number of rounds of message passing. Default to 3.
        mpnn_residual: bool
            If true, adds a residual layer to the mpnn layer. Default to True.
        message_aggregator_type: str
            MPNN message aggregator type, 'sum', 'mean' or 'max'.
            Default to 'sum'.
        mode: str
            The model type, 'classification' or 'regression'.
            Default to 'classification'.
        number_atom_features: int
            The length of the initial atom feature vectors. Default to 134.
        number_bond_features: int
            The length of the initial bond feature vectors. Default to 6.
        n_classes: int
            The number of classes to predict per task
            (only used when ``mode`` is 'classification'). Default to 1.
        nfeat_name: str
            For an input graph ``g``, the model assumes that it stores node
            features in ``g.ndata[nfeat_name]`` and will retrieve input node
            features from that. Default to 'x'.
        efeat_name: str
            For an input graph ``g``, the model assumes that it stores edge
            features in ``g.edata[efeat_name]`` and will retrieve input edge
            features from that. Default to 'edge_attr'.
        readout_type: str
            The readout type, 'set2set' or 'global_sum_pooling'.
            Default to 'set2set'.
        num_step_set2set: int
            Number of steps in set2set readout.
            Used if readout_type == 'set2set'. Default to 6.
        num_layer_set2set: int
            Number of layers in set2set readout.
            Used if readout_type == 'set2set'. Default to 3.
        ffn_hidden_list: Optional[List[int]]
            List of sizes of hidden layers in the feed-forward network layer.
            ``None`` means ``[300]``. Default to None.
        ffn_embeddings: Optional[int]
            Size of the penultimate layer in the feed-forward network layer.
            This determines the Principal Odor Map dimension. If None, no
            extra embedding layer is appended to ``ffn_hidden_list``.
            Default to 256.
        ffn_activation: str
            Activation function to be used in the feed-forward network layer.
            Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU,
            'prelu' for PReLU, 'tanh' for TanH, 'selu' for SELU,
            and 'elu' for ELU.
        ffn_dropout_p: float
            Dropout probability for the feed-forward network layer.
            Default to 0.0
        ffn_dropout_at_input_no_act: bool
            If true, dropout is applied on the input tensor.
            For a single layer, it is not passed to an activation function.

        Raises
        ------
        ValueError
            If ``mode`` or ``readout_type`` is not one of the supported values.
        """
        if mode not in ['classification', 'regression']:
            raise ValueError(
                "mode must be either 'classification' or 'regression'")

        super(MPNNPOM, self).__init__()

        # Avoid a shared mutable default list.
        if ffn_hidden_list is None:
            ffn_hidden_list = [300]

        self.n_tasks: int = n_tasks
        self.mode: str = mode
        self.n_classes: int = n_classes
        self.nfeat_name: str = nfeat_name
        self.efeat_name: str = efeat_name
        self.readout_type: str = readout_type
        self.ffn_embeddings: Optional[int] = ffn_embeddings
        self.ffn_activation: str = ffn_activation
        self.ffn_dropout_p: float = ffn_dropout_p

        if mode == 'classification':
            self.ffn_output: int = n_tasks * n_classes
        else:
            self.ffn_output = n_tasks

        # NOTE: sub-modules are created in the same order as before so that
        # parameter initialisation under a fixed seed is unchanged.
        self.mpnn: nn.Module = CustomMPNNGNN(
            node_in_feats=number_atom_features,
            node_out_feats=node_out_feats,
            edge_in_feats=number_bond_features,
            edge_hidden_feats=edge_hidden_feats,
            num_step_message_passing=num_step_message_passing,
            residual=mpnn_residual,
            message_aggregator_type=message_aggregator_type)

        self.project_edge_feats: nn.Module = nn.Sequential(
            nn.Linear(number_bond_features, edge_out_feats), nn.ReLU())

        # Each node's readout input is the concatenation of a neighbour's
        # node embedding and the projected edge embedding (see _readout).
        if self.readout_type == 'set2set':
            self.readout_set2set: nn.Module = Set2Set(
                input_dim=node_out_feats + edge_out_feats,
                n_iters=num_step_set2set,
                n_layers=num_layer_set2set)
            # Set2Set output is twice the size of its input.
            ffn_input: int = 2 * (node_out_feats + edge_out_feats)
        elif self.readout_type == 'global_sum_pooling':
            ffn_input = node_out_feats + edge_out_feats
        else:
            raise ValueError("readout_type invalid")

        # The embedding layer (if requested) is appended as the last hidden
        # layer, i.e. the penultimate layer of the FFN.
        d_hidden_list: List[int] = list(ffn_hidden_list)
        if ffn_embeddings is not None:
            d_hidden_list.append(ffn_embeddings)

        self.ffn: nn.Module = CustomPositionwiseFeedForward(
            d_input=ffn_input,
            d_hidden_list=d_hidden_list,
            d_output=self.ffn_output,
            activation=ffn_activation,
            dropout_p=ffn_dropout_p,
            dropout_at_input_no_act=ffn_dropout_at_input_no_act)

    def _readout(self, g: DGLGraph, node_encodings: torch.Tensor,
                 edge_feats: torch.Tensor) -> torch.Tensor:
        """
        Execute the readout phase (compute molecule encodings from atom
        hidden states).

        The readout phase consists of a radius 0 combination that folds atom
        and bond embeddings together, followed by:
        - a reduce-sum across atoms if `self.readout_type == 'global_sum_pooling'`
        - set2set pooling if `self.readout_type == 'set2set'`

        Temporary features are written inside ``g.local_scope()``, so the
        caller's graph is left unmodified.

        Parameters
        ----------
        g: DGLGraph
            A DGLGraph for a batch of graphs. It stores the node features in
            ``dgl_graph.ndata[self.nfeat_name]`` and edge features in
            ``dgl_graph.edata[self.efeat_name]``.
        node_encodings: torch.Tensor
            Tensor containing node hidden states.
        edge_feats: torch.Tensor
            Tensor containing edge features.

        Returns
        -------
        batch_mol_hidden_states: torch.Tensor
            Tensor containing batchwise molecule encodings.
        """

        def message_func(edges) -> Dict:
            """
            Message function for DGLGraph.send_and_recv(): concatenates the
            source node embedding with the edge embedding (along dim 1).
            """
            src_msg: torch.Tensor = torch.cat(
                (edges.src['node_emb'], edges.data['edge_emb']), dim=1)
            return {'src_msg': src_msg}

        def reduce_func(nodes) -> Dict:
            """
            Reduce function for DGLGraph.send_and_recv(): sums the incoming
            messages of each node over its mailbox.
            """
            src_msg_sum: torch.Tensor = torch.sum(nodes.mailbox['src_msg'],
                                                  dim=1)
            return {'src_msg_sum': src_msg_sum}

        # local_scope discards every feature written below on exit, so the
        # input graph is not mutated and does not retain intermediate tensors.
        with g.local_scope():
            g.ndata['node_emb'] = node_encodings
            g.edata['edge_emb'] = self.project_edge_feats(edge_feats)

            # radius 0 combination to fold atom and bond embeddings together
            g.send_and_recv(g.edges(),
                            message_func=message_func,
                            reduce_func=reduce_func)

            if self.readout_type == 'set2set':
                batch_mol_hidden_states: torch.Tensor = self.readout_set2set(
                    g, g.ndata['src_msg_sum'])
            elif self.readout_type == 'global_sum_pooling':
                batch_mol_hidden_states = dgl.sum_nodes(g, 'src_msg_sum')
            else:
                raise ValueError("readout_type invalid")

        # batch_size x (node_out_feats + edge_out_feats), doubled for set2set
        return batch_mol_hidden_states

    def forward(
        self, g: DGLGraph
    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Forward pass for the MPNNPOM class. In classification mode it also
        returns the embeddings for the POM.

        Parameters
        ----------
        g: DGLGraph
            A DGLGraph for a batch of graphs. It stores the node features in
            ``dgl_graph.ndata[self.nfeat_name]`` and edge features in
            ``dgl_graph.edata[self.efeat_name]``.

        Returns
        -------
        Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
            * When self.mode = 'regression', the FFN output tensor, expected
              shape ``(dgl_graph.batch_size, self.n_tasks)``.
            * When self.mode = 'classification', a tuple
              ``(proba, logits, embeddings)``:
              - ``logits`` has shape ``(dgl_graph.batch_size, self.n_classes)``
                if self.n_tasks is 1, otherwise
                ``(dgl_graph.batch_size, self.n_tasks, self.n_classes)``;
              - ``proba`` is ``sigmoid(logits)``, with the last dimension
                squeezed when self.n_classes is 1;
              - ``embeddings`` is the first output of the feed-forward
                network (the POM embeddings).
        """
        node_feats: torch.Tensor = g.ndata[self.nfeat_name]
        edge_feats: torch.Tensor = g.edata[self.efeat_name]

        node_encodings: torch.Tensor = self.mpnn(g, node_feats, edge_feats)
        molecular_encodings: torch.Tensor = self._readout(
            g, node_encodings, edge_feats)
        if self.readout_type == 'global_sum_pooling':
            molecular_encodings = F.softmax(molecular_encodings, dim=1)

        embeddings: torch.Tensor
        out: torch.Tensor
        embeddings, out = self.ffn(molecular_encodings)

        if self.mode != 'classification':
            return out

        if self.n_tasks == 1:
            logits: torch.Tensor = out.view(-1, self.n_classes)
        else:
            logits = out.view(-1, self.n_tasks, self.n_classes)
        # Independent per-label probabilities (multilabel), not a softmax.
        proba: torch.Tensor = torch.sigmoid(logits)
        if self.n_classes == 1:
            proba = proba.squeeze(-1)
        return proba, logits, embeddings


class MPNNPOMModel(TorchModel):
    """
    MPNNPOMModel for obtaining a principal odor map using
    multilabel-classification, based on the pre-print:
    "A Principal Odor Map Unifies Diverse Tasks in Human Olfactory
    Perception" [1]

    * Combine latest node representations and edge features in
      updating node representations, which involves multiple
      rounds of message passing.
    * For each graph, compute its representation by radius 0 combination
      to fold atom and bond embeddings together, followed by
      'set2set' or 'global_sum_pooling' readout.
    * Perform the final prediction using a feed-forward layer.

    References
    ----------
    .. [1] Brian K. Lee, Emily J. Mayhew, Benjamin Sanchez-Lengeling,
        Jennifer N. Wei, Wesley W. Qian, Kelsie Little, Matthew Andres,
        Britney B. Nguyen, Theresa Moloy, Jane K. Parker, Richard C. Gerkin,
        Joel D. Mainland, Alexander B. Wiltschko.
        "A Principal Odor Map Unifies Diverse Tasks in Human Olfactory
        Perception" (preprint).
    .. [2] Benjamin Sanchez-Lengeling, Jennifer N. Wei, Brian K. Lee,
        Richard C. Gerkin, Alán Aspuru-Guzik, Alexander B. Wiltschko.
        "Machine Learning for Scent: Learning Generalizable Perceptual
        Representations of Small Molecules".
    .. [3] Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley,
        Oriol Vinyals, George E. Dahl.
        "Neural Message Passing for Quantum Chemistry." ICML 2017.

    Notes
    -----
    This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci
    (https://github.com/awslabs/dgl-lifesci) to be installed.

    The featurizer used with MPNNPOMModel must produce a DeepChem GraphData
    object which has both 'edge' and 'node' features.
    """

    def __init__(self,
                 n_tasks: int,
                 class_imbalance_ratio: Optional[List] = None,
                 loss_aggr_type: str = 'sum',
                 learning_rate: Union[float, LearningRateSchedule] = 0.001,
                 batch_size: int = 100,
                 node_out_feats: int = 64,
                 edge_hidden_feats: int = 128,
                 edge_out_feats: int = 64,
                 num_step_message_passing: int = 3,
                 mpnn_residual: bool = True,
                 message_aggregator_type: str = 'sum',
                 mode: str = 'regression',
                 number_atom_features: int = 134,
                 number_bond_features: int = 6,
                 n_classes: int = 1,
                 readout_type: str = 'set2set',
                 num_step_set2set: int = 6,
                 num_layer_set2set: int = 3,
                 ffn_hidden_list: Optional[List[int]] = None,
                 ffn_embeddings: Optional[int] = 256,
                 ffn_activation: str = 'relu',
                 ffn_dropout_p: float = 0.0,
                 ffn_dropout_at_input_no_act: bool = True,
                 weight_decay: float = 1e-5,
                 self_loop: bool = False,
                 optimizer_name: str = 'adam',
                 device_name: Optional[str] = None,
                 **kwargs):
        """
        Parameters
        ----------
        n_tasks: int
            Number of tasks.
        class_imbalance_ratio: Optional[List]
            List of imbalance ratios per task. Its length must equal n_tasks.
        loss_aggr_type: str
            Loss aggregation type; 'sum' or 'mean'. Default to 'sum'.
            Only applies to CustomMultiLabelLoss for classification.
        learning_rate: Union[float, LearningRateSchedule]
            Learning rate value or scheduler object. Default to 0.001.
        batch_size: int
            Batch size for training. Default to 100.
        node_out_feats: int
            The length of the final node representation vectors
            before readout. Default to 64.
        edge_hidden_feats: int
            The length of the hidden edge representation vectors
            for the mpnn edge network. Default to 128.
        edge_out_feats: int
            The length of the final edge representation vectors
            before readout. Default to 64.
        num_step_message_passing: int
            The number of rounds of message passing. Default to 3.
        mpnn_residual: bool
            If true, adds a residual layer to the mpnn layer. Default to True.
        message_aggregator_type: str
            MPNN message aggregator type, 'sum', 'mean' or 'max'.
            Default to 'sum'.
        mode: str
            The model type, 'classification' or 'regression'.
            Default to 'regression'.
        number_atom_features: int
            The length of the initial atom feature vectors. Default to 134.
        number_bond_features: int
            The length of the initial bond feature vectors. Default to 6.
        n_classes: int
            The number of classes to predict per task
            (only used when ``mode`` is 'classification'). Default to 1.
        readout_type: str
            The readout type, 'set2set' or 'global_sum_pooling'.
            Default to 'set2set'.
        num_step_set2set: int
            Number of steps in set2set readout.
            Used if readout_type == 'set2set'. Default to 6.
        num_layer_set2set: int
            Number of layers in set2set readout.
            Used if readout_type == 'set2set'. Default to 3.
        ffn_hidden_list: Optional[List[int]]
            List of sizes of hidden layers in the feed-forward network layer.
            ``None`` means ``[300]``. Default to None.
        ffn_embeddings: Optional[int]
            Size of the penultimate layer in the feed-forward network layer.
            This determines the Principal Odor Map dimension. Default to 256.
        ffn_activation: str
            Activation function to be used in the feed-forward network layer.
            Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU,
            'prelu' for PReLU, 'tanh' for TanH, 'selu' for SELU,
            and 'elu' for ELU.
        ffn_dropout_p: float
            Dropout probability for the feed-forward network layer.
            Default to 0.0
        ffn_dropout_at_input_no_act: bool
            If true, dropout is applied on the input tensor.
            For a single layer, it is not passed to an activation function.
        weight_decay: float
            Weight decay value for L1 and L2 regularization. Default to 1e-5.
        self_loop: bool
            Whether to add self loops for the nodes, i.e. edges from nodes
            to themselves. Generally, an MPNNPOMModel does not require self
            loops. Default to False.
        optimizer_name: str
            Name of optimizer to be used from
            [adam, adagrad, adamw, sparseadam, rmsprop, sgd, kfac].
            Default to 'adam'.
        device_name: Optional[str]
            The device on which to run computations. If None, a device is
            chosen automatically.
        kwargs
            This can include any keyword argument of TorchModel.

        Raises
        ------
        ValueError
            If ``class_imbalance_ratio`` is given and its length differs from
            ``n_tasks``, or if ``mode``/``readout_type`` is invalid.
        """
        model: nn.Module = MPNNPOM(
            n_tasks=n_tasks,
            node_out_feats=node_out_feats,
            edge_hidden_feats=edge_hidden_feats,
            edge_out_feats=edge_out_feats,
            num_step_message_passing=num_step_message_passing,
            mpnn_residual=mpnn_residual,
            message_aggregator_type=message_aggregator_type,
            mode=mode,
            number_atom_features=number_atom_features,
            number_bond_features=number_bond_features,
            n_classes=n_classes,
            readout_type=readout_type,
            num_step_set2set=num_step_set2set,
            num_layer_set2set=num_layer_set2set,
            ffn_hidden_list=ffn_hidden_list,
            ffn_embeddings=ffn_embeddings,
            ffn_activation=ffn_activation,
            ffn_dropout_p=ffn_dropout_p,
            ffn_dropout_at_input_no_act=ffn_dropout_at_input_no_act)

        if class_imbalance_ratio and (len(class_imbalance_ratio) != n_tasks):
            raise ValueError(
                "size of class_imbalance_ratio should be equal to n_tasks")

        if mode == 'regression':
            loss: Loss = L2Loss()
            output_types: List = ['prediction']
        else:
            loss = CustomMultiLabelLoss(
                class_imbalance_ratio=class_imbalance_ratio,
                loss_aggr_type=loss_aggr_type,
                device=device_name)
            # Matches MPNNPOM.forward's (proba, logits, embeddings) tuple.
            output_types = ['prediction', 'loss', 'embedding']

        optimizer: Optimizer = get_optimizer(optimizer_name)
        optimizer.learning_rate = learning_rate

        device: Optional[torch.device] = (torch.device(device_name)
                                          if device_name is not None else
                                          None)

        super(MPNNPOMModel, self).__init__(model,
                                           loss=loss,
                                           output_types=output_types,
                                           optimizer=optimizer,
                                           learning_rate=learning_rate,
                                           batch_size=batch_size,
                                           device=device,
                                           **kwargs)

        self.weight_decay: float = weight_decay
        self._self_loop: bool = self_loop
        self.regularization_loss: Callable = self._regularization_loss

    def _regularization_loss(self) -> torch.Tensor:
        """
        L1 and L2-norm losses for regularization.

        Bias parameters (any parameter whose name contains 'bias') are
        excluded. Note that the L2 term uses the (unsquared) 2-norm of each
        parameter, as computed by ``torch.norm(param, p=2)``.

        Returns
        -------
        torch.Tensor
            weight_decay * (sum of L1 norms) + weight_decay * (sum of L2 norms)
        """
        l1_regularization: torch.Tensor = torch.tensor(0., requires_grad=True)
        l2_regularization: torch.Tensor = torch.tensor(0., requires_grad=True)
        for name, param in self.model.named_parameters():
            if 'bias' not in name:
                l1_regularization = l1_regularization + torch.norm(param, p=1)
                l2_regularization = l2_regularization + torch.norm(param, p=2)
        l1_norm: torch.Tensor = self.weight_decay * l1_regularization
        l2_norm: torch.Tensor = self.weight_decay * l2_regularization
        return l1_norm + l2_norm

    def _prepare_batch(
        self, batch: Tuple[List, List, List]
    ) -> Tuple[DGLGraph, List[torch.Tensor], List[torch.Tensor]]:
        """Create batch data for MPNN.

        Parameters
        ----------
        batch: Tuple[List, List, List]
            The tuple is ``(inputs, labels, weights)``; ``inputs[0]`` holds
            the GraphData objects of the batch.

        Returns
        -------
        g: DGLGraph
            DGLGraph for a batch of graphs, moved to ``self.device``.
        labels: list of torch.Tensor or None
            The graph labels.
        weights: list of torch.Tensor or None
            The weights for each sample or sample/task pair converted to
            torch.Tensor.
        """
        inputs: List
        labels: List
        weights: List
        inputs, labels, weights = batch

        dgl_graphs: List[DGLGraph] = [
            graph.to_dgl_graph(self_loop=self._self_loop)
            for graph in inputs[0]
        ]
        g: DGLGraph = dgl.batch(dgl_graphs).to(self.device)

        # Delegate label/weight tensor conversion to TorchModel.
        _, labels, weights = super(MPNNPOMModel, self)._prepare_batch(
            ([], labels, weights))
        return g, labels, weights