Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from torch.nn import Sequential as Seq, Linear as Lin, Conv2d | |
| ############################## | |
| # Basic layers | |
| ############################## | |
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    """
    Build an activation module from its (case-insensitive) name.

    :param act: one of 'relu', 'leakyrelu' or 'prelu'
    :param inplace: forwarded to ``nn.ReLU`` / ``nn.LeakyReLU``; not used for PReLU
    :param neg_slope: negative slope for LeakyReLU, initial slope value for PReLU
    :param n_prelu: number of learnable slopes (``num_parameters``) for PReLU
    :return: the constructed ``nn.Module``
    :raises NotImplementedError: if the name is not one of the supported ones
    """
    name = act.lower()
    if name == 'relu':
        return nn.ReLU(inplace)
    if name == 'leakyrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if name == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    # The message reports the lower-cased name, not the caller's spelling.
    raise NotImplementedError('activation layer [%s] is not found' % name)
def norm_layer(norm, nc):
    """
    Build a 2d normalization module from its (case-insensitive) name.

    :param norm: 'batch' (affine ``nn.BatchNorm2d``) or 'instance'
        (non-affine ``nn.InstanceNorm2d``)
    :param nc: number of channels (``num_features``) the layer normalizes
    :return: the constructed ``nn.Module``
    :raises NotImplementedError: if the name is not one of the supported ones
    """
    name = norm.lower()
    if name == 'batch':
        return nn.BatchNorm2d(nc, affine=True)
    if name == 'instance':
        return nn.InstanceNorm2d(nc, affine=False)
    # The message reports the lower-cased name, not the caller's spelling.
    raise NotImplementedError('normalization layer [%s] is not found' % name)
class MLP(Seq):
    """
    Stack of ``Linear`` layers, each optionally followed by an activation
    and a normalization layer.

    :param channels: sequence of layer widths; ``channels[i - 1] -> channels[i]``
        defines the i-th Linear layer
    :param act: activation name passed to ``act_layer``; ``None`` or 'none'
        (any case) disables it
    :param norm: normalization name passed to ``norm_layer``; ``None`` or
        'none' (any case) disables it
    :param bias: whether the Linear layers have a bias term
    """

    def __init__(self, channels, act='relu', norm=None, bias=True):
        layers = []
        for i in range(1, len(channels)):
            out_channels = channels[i]
            layers.append(Lin(channels[i - 1], out_channels, bias))
            if act is not None and act.lower() != 'none':
                layers.append(act_layer(act))
            if norm is not None and norm.lower() != 'none':
                # Size the norm for the layer it follows, not for the final
                # layer, so stacks with differing hidden widths are valid.
                # NOTE(review): norm_layer builds 2d norms (channels on dim 1)
                # while Linear acts on the last dim -- confirm the expected
                # input layout when norm is enabled.
                layers.append(norm_layer(norm, out_channels))
        super(MLP, self).__init__(*layers)
class BasicConv(Seq):
    """
    Stack of 1x1 ``Conv2d`` layers, each optionally followed by an
    activation, a normalization layer and 2d dropout.

    :param channels: sequence of channel counts; ``channels[i - 1] -> channels[i]``
        defines the i-th convolution
    :param act: activation name passed to ``act_layer``; ``None`` or 'none'
        (any case) disables it
    :param norm: normalization name passed to ``norm_layer``; ``None`` or
        'none' (any case) disables it
    :param bias: whether the convolutions have a bias term
    :param drop: dropout probability; ``nn.Dropout2d`` is added only if > 0
    """

    def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.):
        layers = []
        for i in range(1, len(channels)):
            out_channels = channels[i]
            layers.append(Conv2d(channels[i - 1], out_channels, 1, bias=bias))
            if act is not None and act.lower() != 'none':
                layers.append(act_layer(act))
            if norm is not None and norm.lower() != 'none':
                # Size the norm for the conv it follows, not for the final
                # conv, so stacks with differing hidden widths are valid.
                layers.append(norm_layer(norm, out_channels))
            if drop > 0:
                layers.append(nn.Dropout2d(drop))
        super(BasicConv, self).__init__(*layers)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initialize all sub-modules: Kaiming-normal conv weights, zero conv
        biases, and unit weight / zero bias for affine norm layers.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # Non-affine norms (e.g. InstanceNorm2d(affine=False)) have
                # weight/bias set to None; there is nothing to reset then.
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
def batched_index_select(inputs, index):
    """
    Gather per-vertex feature vectors for every index entry, batch-wise, so
    that ``out[b, c, n, j] == inputs[b, c, index[b, n, j], 0]``.

    :param inputs: torch.Size([batch_size, num_dims, num_vertices, 1])
    :param index: torch.Size([batch_size, num_vertices, k]); integer vertex
        ids in ``[0, num_vertices)``, need not be contiguous in memory
    :return: torch.Size([batch_size, num_dims, num_vertices, k])
    """
    batch_size, num_dims, num_vertices, _ = inputs.shape
    k = index.shape[2]

    # Row offset of each batch element in the flattened (batch * vertices)
    # table, created directly with the index's dtype/device.
    offsets = torch.arange(
        batch_size, dtype=index.dtype, device=index.device
    ).view(batch_size, 1) * num_vertices

    # (B, C, N, 1) -> (B, N, C, 1) -> (B * N, C): one row per vertex.
    flat_inputs = inputs.transpose(2, 1).reshape(-1, num_dims)

    # reshape (not view) so non-contiguous index tensors are accepted.
    flat_index = (index.reshape(batch_size, -1) + offsets).reshape(-1)

    gathered = torch.index_select(flat_inputs, 0, flat_index)
    # (B * N * k, C) -> (B, N * k, C) -> (B, C, N * k) -> (B, C, N, k)
    return (
        gathered.view(batch_size, -1, num_dims)
        .transpose(2, 1)
        .view(batch_size, num_dims, -1, k)
    )