Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from src.model.layers import TransformerEncoder | |
class Generator(nn.Module):
    """
    Generator network that uses a Transformer Encoder to process node and edge features.

    Input node and edge features are embedded by two separate small MLPs. The edge
    embeddings are symmetrized over the two vertex axes, and a Transformer Encoder
    models node/edge interactions. Linear readout heads then project the encoder
    outputs back to the node- and edge-feature sizes.
    """

    # Supported activation names mapped to the nn.Module class that implements them.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Generator.

        Args:
            act (str or nn.Module): Activation to use. Either one of "relu", "leaky",
                "sigmoid", "tanh", or an already-constructed nn.Module.
            vertexes (int): Number of vertexes in the graph.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality used for intermediate features.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads in the Transformer.
            mlp_ratio (int): Ratio for determining hidden layer size in MLP modules.

        Raises:
            ValueError: If ``act`` is neither a supported name nor an nn.Module.
        """
        super().__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve the activation up front. Previously an unknown name was silently kept as a
        # str and only failed later inside nn.Sequential with an unrelated TypeError.
        if isinstance(act, nn.Module):
            act = act
        elif isinstance(act, str) and act in self._ACTIVATIONS:
            act = self._ACTIVATIONS[act]()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        # Total flattened feature counts (kept as attributes; not used inside this class).
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Per-node and per-edge embedding MLPs: feature size -> 64 -> dim.
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout),
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout),
        )
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout,
        )
        # Readout heads project encoder outputs (last dim == dim) back to feature sizes.
        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        # Defined for callers; not applied in forward().
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z_e, z_n):
        """
        Forward pass of the Generator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            tuple: (node, edge, node_sample, edge_sample) where
                - node: node features after the transformer.
                - edge: edge features after the transformer.
                - node_sample: raw linear readout of ``node`` (no softmax applied here).
                - edge_sample: raw linear readout of ``edge`` (no softmax applied here).
        """
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge embeddings by averaging with the transpose over the two vertex axes.
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2
        node, edge = self.TransformerEncoder(node, edge)
        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)
        return node, edge, node_sample, edge_sample
class Discriminator(nn.Module):
    """
    Discriminator network that evaluates node and edge features.

    Node and edge features are embedded with small MLPs, and the edge embeddings are
    symmetrized over the two vertex axes. A Transformer Encoder then models
    dependencies between them. Finally, an MLP on the flattened node features
    predicts one score per sample.

    This class is used in DrugGEN model.
    """

    # Supported activation names mapped to the nn.Module class that implements them.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Discriminator.

        Args:
            act (str or nn.Module): Activation to use. Either one of "relu", "leaky",
                "sigmoid", "tanh", or an already-constructed nn.Module.
            vertexes (int): Number of vertexes.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality for intermediate representations.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads.
            mlp_ratio (int): MLP ratio for hidden layer dimensions.

        Raises:
            ValueError: If ``act`` is neither a supported name nor an nn.Module.
        """
        super().__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve the activation up front. Previously an unknown name was silently kept as a
        # str and only failed later inside nn.Sequential with an unrelated TypeError.
        if isinstance(act, nn.Module):
            act = act
        elif isinstance(act, str) and act in self._ACTIVATIONS:
            act = self._ACTIVATIONS[act]()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        # Total flattened feature counts (kept as attributes; not used inside this class).
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Per-node and per-edge embedding MLPs: feature size -> 64 -> dim.
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout),
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout),
        )
        # Transformer Encoder for modeling node and edge interactions.
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout,
        )

        # Flattened sizes: node_mlp expects vertexes * dim values per sample.
        self.node_features = vertexes * dim
        self.edge_features = vertexes * vertexes * dim

        # MLP mapping the flattened node features to a single score.
        self.node_mlp = nn.Sequential(
            nn.Linear(self.node_features, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1),
        )

    def forward(self, z_e, z_n):
        """
        Forward pass of the Discriminator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            torch.Tensor: Prediction scores of shape (batch, 1).
        """
        batch_size = z_n.shape[0]
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge embeddings by averaging with the transpose over the two vertex axes.
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2
        node, edge = self.TransformerEncoder(node, edge)
        # reshape (not view): the encoder output is not guaranteed to be contiguous.
        node = node.reshape(batch_size, -1)
        return self.node_mlp(node)
class simple_disc(nn.Module):
    """
    A simplified discriminator that maps a flat feature vector to a scalar score
    through a stack of fully connected layers.

    This class is used in NoTarget model.
    """

    # Supported activation names mapped to the nn.Module class that implements them.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }

    def __init__(self, act, m_dim, vertexes, b_dim):
        """
        Initializes the simple discriminator.

        Args:
            act (str or nn.Module): Activation to use. Either one of "relu", "leaky",
                "sigmoid", "tanh", or an already-constructed nn.Module.
            m_dim (int): Dimensionality for atom type features.
            vertexes (int): Number of vertexes.
            b_dim (int): Dimensionality for bond type features.

        Raises:
            ValueError: If ``act`` is neither a supported name nor an nn.Module.
        """
        super().__init__()
        if isinstance(act, nn.Module):
            act = act
        elif isinstance(act, str) and act in self._ACTIVATIONS:
            act = self._ACTIVATIONS[act]()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        # Input width: vertexes * m_dim + vertexes^2 * b_dim. This presumably corresponds to
        # flattened node and edge tensors concatenated by the caller -- verify against caller.
        features = vertexes * m_dim + vertexes * vertexes * b_dim

        self.predictor = nn.Sequential(
            nn.Linear(features, 256), act,
            nn.Linear(256, 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1),
        )

    def forward(self, x):
        """
        Forward pass of the simple discriminator.

        Args:
            x (torch.Tensor): Input features whose last dimension equals
                ``vertexes * m_dim + vertexes * vertexes * b_dim``.

        Returns:
            torch.Tensor: Prediction scores with last dimension 1.
        """
        return self.predictor(x)