Spaces:
Sleeping
Sleeping
| from torch import cat, nn | |
| import torch.nn.functional as F | |
| from torch.nn import Sequential, Linear, ReLU | |
| from torch_geometric.nn import GINConv, global_add_pool | |
class GIN(nn.Module):
    r"""Graph Isomorphism Network graph encoder.

    From `GraphDTA <https://doi.org/10.1093/bioinformatics/btaa921>`_ (Nguyen et al., 2020),
    based on `Graph Isomorphism Network <https://arxiv.org/abs/1810.00826>`_ (Xu et al., 2019).

    Five ``GINConv`` layers, each wrapping a two-layer MLP
    (``Linear -> ReLU -> Linear``), are applied in sequence. Each layer is
    followed by ReLU and then ``BatchNorm1d``. Node embeddings are then
    sum-pooled per graph and projected to ``out_channels`` with a ReLU and
    dropout.

    Args:
        num_features: Size of each input node feature vector.
        out_channels: Size of the per-graph output embedding.
        dropout: Dropout probability applied to the output embedding
            (only active while the module is in training mode).
        dim: Hidden width of every GIN layer. Defaults to 32, the value
            previously hard-coded here.
    """

    def __init__(
        self,
        num_features: int,
        out_channels: int,
        dropout: float,
        dim: int = 32,
    ):
        super().__init__()
        self.dropout = dropout
        # Not used by ``forward`` (which calls ``F.relu``); retained so code
        # that accesses ``model.relu`` keeps working.
        self.relu = nn.ReLU()

        # Layers are registered under explicit names (conv1..conv5, bn1..bn5,
        # fc1_xd) rather than in a ModuleList so that state_dict keys remain
        # identical to those of existing checkpoints. Construction order is
        # also unchanged, so seeded initialisation is reproduced exactly.
        self.conv1 = GINConv(self._mlp(num_features, dim))
        self.bn1 = nn.BatchNorm1d(dim)
        self.conv2 = GINConv(self._mlp(dim, dim))
        self.bn2 = nn.BatchNorm1d(dim)
        self.conv3 = GINConv(self._mlp(dim, dim))
        self.bn3 = nn.BatchNorm1d(dim)
        self.conv4 = GINConv(self._mlp(dim, dim))
        self.bn4 = nn.BatchNorm1d(dim)
        self.conv5 = GINConv(self._mlp(dim, dim))
        self.bn5 = nn.BatchNorm1d(dim)
        self.fc1_xd = Linear(dim, out_channels)

    @staticmethod
    def _mlp(in_dim: int, dim: int) -> Sequential:
        """Build the ``Linear -> ReLU -> Linear`` MLP wrapped by each GINConv."""
        return Sequential(Linear(in_dim, dim), ReLU(), Linear(dim, dim))

    def forward(self, data):
        """Encode a batch of graphs into one embedding per graph.

        Args:
            data: Object exposing ``x`` (node features), ``edge_index`` and
                ``batch`` (node-to-graph assignment). Presumably this is a
                ``torch_geometric.data.Batch``; confirm against callers.

        Returns:
            Tensor with one ``out_channels``-sized row per graph in the batch
            (rows are produced by ``global_add_pool``).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        layers = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in layers:
            # The activation is applied before batch normalisation, as in the
            # original layer-by-layer code.
            x = bn(F.relu(conv(x, edge_index)))
        # Sum node embeddings into one vector per graph.
        x = global_add_pool(x, batch)
        x = F.relu(self.fc1_xd(x))
        return F.dropout(x, p=self.dropout, training=self.training)