| |
|
|
| import torch |
| import torch.nn as nn |
| import pytorch_lightning as pl |
|
|
|
|
class MLP(pl.LightningModule):
    """Point-wise MLP built from kernel-size-1 ``nn.Conv1d`` layers.

    For ``filter_channels = [c0, c1, ..., ck]`` this stacks ``k`` convolutions,
    layer ``i`` mapping ``c_i -> c_{i+1}`` channels. Every layer except the
    last is followed by an optional normalization (see ``norm``) and an
    in-place LeakyReLU.

    Args:
        filter_channels: Sequence of channel widths. Element 0 is the input
            width and the last element is the output width.
        name: Stored on ``self.name``; not used inside this class.
        res_layers: Indices of layers whose input is the running activation
            concatenated (along dim 1) with the original input ``feature``.
            Such a layer's conv input width is therefore
            ``filter_channels[i] + filter_channels[0]``. Defaults to none.
        norm: ``'group'`` (GroupNorm), ``'batch'`` (BatchNorm1d),
            ``'instance'`` (InstanceNorm1d) or ``'weight'`` (weight
            normalization on the conv weight). Any other value (e.g. ``None``)
            disables normalization. The final layer is never normalized.
        last_op: Optional callable applied to the output of the final layer.
        num_groups: Number of groups for GroupNorm. It is only used when
            ``norm == 'group'``, and each normalized channel width must be
            divisible by it.
    """

    def __init__(self,
                 filter_channels,
                 name=None,
                 res_layers=None,
                 norm='group',
                 last_op=None,
                 num_groups=32):
        super().__init__()

        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        # Copy instead of aliasing: avoids a shared mutable default, and keeps
        # forward() consistent with the conv widths built below even if the
        # caller later mutates their list.
        self.res_layers = [] if res_layers is None else list(res_layers)
        self.norm = norm
        self.last_op = last_op
        self.name = name
        self.activate = nn.LeakyReLU(inplace=True)

        num_layers = len(filter_channels) - 1
        for idx in range(num_layers):
            in_channels = filter_channels[idx]
            out_channels = filter_channels[idx + 1]
            if idx in self.res_layers:
                # This layer also receives the original input (skip concat).
                in_channels = in_channels + filter_channels[0]
            self.filters.append(nn.Conv1d(in_channels, out_channels, 1))

            if idx == num_layers - 1:
                continue  # the output layer gets no normalization

            if norm == 'group':
                self.norms.append(nn.GroupNorm(num_groups, out_channels))
            elif norm == 'batch':
                self.norms.append(nn.BatchNorm1d(out_channels))
            elif norm == 'instance':
                self.norms.append(nn.InstanceNorm1d(out_channels))
            elif norm == 'weight':
                # Wraps the conv in place; no entry is added to self.norms.
                self.filters[idx] = nn.utils.weight_norm(self.filters[idx],
                                                         name='weight')

    def forward(self, feature):
        """Run the MLP point-wise over ``feature``.

        Args:
            feature: Tensor expected as ``[B, C_in, N]`` with
                ``C_in == filter_channels[0]``. It is also the tensor
                concatenated into every layer listed in ``res_layers``.

        Returns:
            Tensor ``[B, C_out, N]`` with ``C_out == filter_channels[-1]``,
            passed through ``last_op`` if one was given.
        """
        y = feature
        last_idx = len(self.filters) - 1
        use_norm_module = self.norm in ('batch', 'group', 'instance')

        for i, conv in enumerate(self.filters):
            if i in self.res_layers:
                # torch.cat allocates a new tensor, so the in-place activation
                # below never touches the caller's `feature`.
                y = torch.cat([y, feature], 1)
            y = conv(y)

            if i != last_idx:
                if use_norm_module:
                    y = self.norms[i](y)
                y = self.activate(y)

        if self.last_op is not None:
            y = self.last_op(y)

        return y
|
|