Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from .torch_nn import BasicConv, batched_index_select | |
| from .torch_edge import DenseDilatedKnnGraph, DilatedKnnGraph | |
| import torch.nn.functional as F | |
class MRConv2d(nn.Module):
    """Max-Relative graph convolution for dense inputs.

    Paper: https://arxiv.org/abs/1904.03751

    For every centre node, the differences between its gathered neighbours and
    the node itself are max-pooled over the last axis. The result is
    concatenated with the original features along dim 1 and passed through a
    shared ``BasicConv``.

    Args:
        in_channels: Expected size of dim 1 of ``x``. The conv input is
            ``2 * in_channels`` after the concatenation.
        out_channels: Output size handed to ``BasicConv``.
        act, norm, bias: Forwarded unchanged to ``BasicConv``.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super().__init__()
        # The conv consumes [x, max-relative features], hence twice the input width.
        self.nn = BasicConv([2 * in_channels, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        """Aggregate neighbours with a max over relative features, then apply ``self.nn``.

        Args:
            x: Node features. Presumably shaped (B, C, N, 1); TODO confirm
                against ``batched_index_select``.
            edge_index: Indexable pair. ``edge_index[0]`` gathers the neighbour
                features (x_j) and ``edge_index[1]`` the centre features (x_i).

        Returns:
            ``self.nn(cat([x, max(x_j - x_i, dim=-1, keepdim=True)], dim=1))``.
        """
        center = batched_index_select(x, edge_index[1])
        neighbor = batched_index_select(x, edge_index[0])
        # keepdim=True leaves a size-1 trailing axis so the result lines up with ``x``.
        rel_max = (neighbor - center).max(dim=-1, keepdim=True).values
        return self.nn(torch.cat((x, rel_max), dim=1))
class EdgeConv2d(nn.Module):
    """EdgeConv layer for dense inputs.

    For each (centre, neighbour) pair the edge feature ``[x_i, x_j - x_i]`` is
    formed by concatenation along dim 1. It is transformed by a shared
    ``BasicConv`` and then max-pooled over the last axis.

    Args:
        in_channels: Expected size of dim 1 of ``x``. Edge features have
            ``2 * in_channels`` entries along that dim.
        out_channels: Output size handed to ``BasicConv``.
        act, norm, bias: Forwarded unchanged to ``BasicConv``. Whether an
            activation or normalization is applied depends on these values.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super().__init__()
        # Edge features stack the centre feature and the relative offset.
        self.nn = BasicConv([2 * in_channels, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        """Compute per-edge features, transform them, and max-reduce over neighbours.

        Args:
            x: Node features. Presumably shaped (B, C, N, 1); TODO confirm.
            edge_index: ``edge_index[0]`` gathers neighbours (x_j) and
                ``edge_index[1]`` gathers centres (x_i).

        Returns:
            The max over the last axis (kept as size 1) of ``self.nn`` applied
            to the edge features.
        """
        center = batched_index_select(x, edge_index[1])
        neighbor = batched_index_select(x, edge_index[0])
        edge_feat = torch.cat((center, neighbor - center), dim=1)
        return self.nn(edge_feat).max(dim=-1, keepdim=True).values
class GraphConv2d(nn.Module):
    """Static graph convolution that wraps an edge or max-relative layer.

    The graph (``edge_index``) is supplied by the caller on every forward pass.

    Args:
        in_channels, out_channels: Forwarded to the selected layer.
        conv: ``'edge'`` selects ``EdgeConv2d``; ``'mr'`` selects ``MRConv2d``.
        act, norm, bias: Forwarded to the selected layer.

    Raises:
        NotImplementedError: If ``conv`` is neither ``'edge'`` nor ``'mr'``.
    """

    def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True):
        super().__init__()
        if conv not in ('edge', 'mr'):
            raise NotImplementedError(f'conv:{conv} is not supported')
        layer_cls = EdgeConv2d if conv == 'edge' else MRConv2d
        self.gconv = layer_cls(in_channels, out_channels, act, norm, bias)

    def forward(self, x, edge_index):
        """Run the wrapped graph layer on ``x`` using the given ``edge_index``."""
        return self.gconv(x, edge_index)
class DynConv2d(GraphConv2d):
    """Dynamic graph convolution that rebuilds the graph from ``x`` on every call.

    A dilated kNN graph module produces ``edge_index`` from the input. That
    ``edge_index`` is then consumed by the ``GraphConv2d`` layer selected by
    ``conv``.

    Args:
        kernel_size: Passed to the graph module (presumably the neighbour count
            k). Stored as ``self.k``.
        dilation: Passed to the graph module. Stored as ``self.d``.
        stochastic, epsilon: Passed to the graph module.
        knn: ``'matrix'`` selects ``DenseDilatedKnnGraph``; any other value
            selects ``DilatedKnnGraph``.
        in_channels, out_channels, conv, act, norm, bias: Forwarded to
            ``GraphConv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
                 norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
        super().__init__(in_channels, out_channels, conv, act, norm, bias)
        self.k = kernel_size
        self.d = dilation
        # NOTE(review): any value other than 'matrix' silently falls back to
        # DilatedKnnGraph, so a typo in ``knn`` is not reported.
        graph_builder = DenseDilatedKnnGraph if knn == 'matrix' else DilatedKnnGraph
        self.dilated_knn_graph = graph_builder(kernel_size, dilation, stochastic, epsilon)

    def forward(self, x):
        """Build ``edge_index`` from ``x``, then apply the static graph convolution."""
        edges = self.dilated_knn_graph(x)
        return super().forward(x, edges)
class PlainDynBlock2d(nn.Module):
    """Plain dynamic graph convolution block.

    Wraps a single ``DynConv2d`` whose input and output widths are both
    ``in_channels``. No skip connection is added.
    """

    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
        super().__init__()
        self.body = DynConv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            conv=conv,
            act=act,
            norm=norm,
            bias=bias,
            stochastic=stochastic,
            epsilon=epsilon,
            knn=knn,
        )

    def forward(self, x):
        """Return the output of the wrapped ``DynConv2d``."""
        return self.body(x)
class ResDynBlock2d(nn.Module):
    """Residual dynamic graph convolution block.

    Computes ``body(x) + x * res_scale``, where ``body`` is a width-preserving
    ``DynConv2d`` (``in_channels`` -> ``in_channels``).

    Args:
        res_scale: Multiplier applied to the skip connection.
        Remaining arguments are forwarded to ``DynConv2d``.
    """

    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, stochastic=False, epsilon=0.0, knn='matrix', res_scale=1):
        super().__init__()
        self.body = DynConv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            conv=conv,
            act=act,
            norm=norm,
            bias=bias,
            stochastic=stochastic,
            epsilon=epsilon,
            knn=knn,
        )
        self.res_scale = res_scale

    def forward(self, x):
        """Add the scaled input back onto the graph-convolution output."""
        # The body runs before the skip term is formed, matching the original order.
        out = self.body(x)
        return out + x * self.res_scale
class DenseDynBlock2d(nn.Module):
    """Dense dynamic graph convolution block.

    Runs a ``DynConv2d`` (``in_channels`` -> ``out_channels``). Its output is
    concatenated after the input along dim 1, DenseNet-style.

    Args:
        out_channels: Width produced by the inner ``DynConv2d``.
        Remaining arguments are forwarded to ``DynConv2d``.
    """

    def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge',
                 act='relu', norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
        super().__init__()
        self.body = DynConv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            conv=conv,
            act=act,
            norm=norm,
            bias=bias,
            stochastic=stochastic,
            epsilon=epsilon,
            knn=knn,
        )

    def forward(self, x):
        """Return ``cat((x, body(x)), dim=1)``."""
        return torch.cat((x, self.body(x)), dim=1)