Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class SurfaceClassifier(nn.Module):
    """Point-wise MLP made of 1x1 ``nn.Conv1d`` layers.

    Layer ``i`` maps ``filter_channels[i]`` channels to ``filter_channels[i + 1]``
    channels independently at every position along the last axis.  Every layer
    except the last is followed by ``F.leaky_relu``.

    Args:
        filter_channels: sequence of channel widths; ``len(filter_channels) - 1``
            conv layers are created.
        num_views: when > 1, ``forward`` reshapes the batch axis into groups of
            ``num_views`` consecutive entries and averages each group after layer
            ``len(self.filters) // 2``, dividing the batch size by ``num_views``.
        no_residual: if False, every layer after the first additionally receives
            the input ``feature`` concatenated on the channel axis, so its input
            width is ``filter_channels[i] + filter_channels[0]``.
        last_op: optional callable applied to the final output (skipped when None).
    """

    def __init__(self, filter_channels, num_views=1, no_residual=True, last_op=None):
        super().__init__()
        # Intentionally a plain list: the layers are registered below under the
        # names "conv0", "conv1", ... which fixes the state_dict keys.  Wrapping
        # them in nn.ModuleList as well would register every parameter twice.
        self.filters = []
        self.num_views = num_views
        self.no_residual = no_residual
        # Assigned before the convs so that, if last_op is an nn.Module, the
        # submodule registration order is the same as before.
        self.last_op = last_op

        for idx in range(len(filter_channels) - 1):
            in_channels = filter_channels[idx]
            if not no_residual and idx != 0:
                # Residual layers also see the original input channels.
                in_channels += filter_channels[0]
            conv = nn.Conv1d(in_channels, filter_channels[idx + 1], 1)
            self.filters.append(conv)
            self.add_module("conv%d" % idx, conv)

    def forward(self, feature):
        """Apply the conv layers to ``feature``.

        :param feature: tensor of shape [B x filter_channels[0] x N] (the layout
            ``nn.Conv1d`` expects).  When ``num_views > 1``, B must be a multiple
            of ``num_views``.
        :return: tensor of shape [B' x filter_channels[-1] x N], where
            B' = B // num_views if ``num_views > 1`` and at least one layer
            exists, otherwise B' = B.
        """
        y = feature
        residual = feature
        n_layers = len(self.filters)
        merge_at = n_layers // 2
        for i in range(n_layers):
            # Look the layer up by its registered name so that a submodule
            # replaced after construction (e.g. ``model.conv1 = ...``) is used.
            conv = self._modules['conv' + str(i)]
            if self.no_residual or i == 0:
                y = conv(y)
            else:
                y = conv(torch.cat([y, residual], 1))

            if i != n_layers - 1:
                y = F.leaky_relu(y)

            if self.num_views > 1 and i == merge_at:
                # Average over the num_views entries per group; the residual
                # input is averaged too so its batch size keeps matching y.
                y = y.view(-1, self.num_views, y.shape[1], y.shape[2]).mean(dim=1)
                residual = feature.view(
                    -1, self.num_views, feature.shape[1], feature.shape[2]
                ).mean(dim=1)

        if self.last_op is not None:
            y = self.last_op(y)
        return y