import torch
import torch.nn as nn
import torch.nn.functional as F
|
class MLP(nn.Module):
    """Point-wise MLP built from 1x1 Conv1d layers.

    Each "fully connected" layer is a Conv1d with kernel size 1, so the same
    weights are applied independently to every point along the last axis.
    Selected layers can receive the raw input feature concatenated to their
    input (skip connections), and one intermediate activation is captured and
    returned alongside the final prediction.
    """

    def __init__(self,
                 filter_channels,
                 merge_layer=0,
                 res_layers=None,
                 norm='group',
                 last_op=None):
        """
        Args:
            filter_channels (list[int]): channel size of each layer; layer l
                maps filter_channels[l] -> filter_channels[l+1].
            merge_layer (int): index of the layer whose activation is captured
                and returned as `phi`; a non-positive value selects
                len(filter_channels) // 2.
            res_layers (list[int] | None): indices of layers that take the raw
                input feature concatenated to their input. Defaults to no
                skip connections.
            norm (str): 'group' -> GroupNorm(32, C), 'batch' -> BatchNorm1d;
                any other value disables normalization.
            last_op (nn.Module | None): optional op applied to the final
                output (e.g. nn.Sigmoid()).
        """
        super(MLP, self).__init__()

        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.merge_layer = merge_layer if merge_layer > 0 else len(filter_channels) // 2
        # BUGFIX: the default used to be a mutable argument (res_layers=[]),
        # which is shared across every instance constructed with the default;
        # use a None sentinel and create a fresh list per instance instead.
        self.res_layers = [] if res_layers is None else res_layers
        self.norm = norm
        self.last_op = last_op

        for l in range(len(filter_channels) - 1):
            in_channels = filter_channels[l]
            if l in self.res_layers:
                # Skip-connection layers also consume the raw input feature.
                in_channels += filter_channels[0]
            self.filters.append(
                nn.Conv1d(in_channels, filter_channels[l + 1], 1))
            # No normalization after the final layer.
            if l != len(filter_channels) - 2:
                if norm == 'group':
                    self.norms.append(nn.GroupNorm(32, filter_channels[l + 1]))
                elif norm == 'batch':
                    self.norms.append(nn.BatchNorm1d(filter_channels[l + 1]))

    def forward(self, feature):
        '''
        feature may include multiple view inputs
        args:
            feature: [B, C_in, N]
        return:
            (y, phi): [B, C_out, N] prediction, plus the intermediate
            activation captured at `merge_layer` (None if that index is
            never reached).
        '''
        y = feature
        tmpy = feature
        phi = None
        last = len(self.filters) - 1  # hoisted loop invariant
        for i, f in enumerate(self.filters):
            y = f(
                y if i not in self.res_layers
                else torch.cat([y, tmpy], 1)
            )
            if i != last:
                if self.norm in ('batch', 'group'):
                    y = F.leaky_relu(self.norms[i](y))
                else:
                    y = F.leaky_relu(y)
            if i == self.merge_layer:
                phi = y.clone()

        if self.last_op is not None:
            y = self.last_op(y)

        return y, phi
|
|