| import torch.nn as nn |
|
|
# Public API of this module: only the SharedMLP layer is exported.
__all__ = ["SharedMLP"]
|
|
|
|
class SharedMLP(nn.Module):
    """Stack of pointwise (kernel size 1) convolutions, each followed by
    instance normalization and an in-place ReLU.

    Because every convolution has kernel size 1, the same weights are applied
    independently at every position of the input, i.e. an MLP shared across
    positions.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Output channel count of a single layer, or a list/tuple
            giving the output channel count of each successive layer.
        dim: Dimensionality of the convolution/normalization pair:
            ``1`` selects ``Conv1d``/``InstanceNorm1d``,
            ``2`` selects ``Conv2d``/``InstanceNorm2d``.
        device: Device on which the layers are created.

    Raises:
        ValueError: If ``dim`` is neither 1 nor 2.
    """

    def __init__(self, in_channels, out_channels, dim=1, device='cuda'):
        super().__init__()

        if dim == 1:
            conv, norm = nn.Conv1d, nn.InstanceNorm1d
        elif dim == 2:
            # The norm must match the conv's dimensionality: InstanceNorm1d
            # rejects the 4-D tensors that Conv2d produces.
            conv, norm = nn.Conv2d, nn.InstanceNorm2d
        else:
            raise ValueError(f'dim must be 1 or 2, got {dim!r}')

        if not isinstance(out_channels, (list, tuple)):
            out_channels = [out_channels]

        layers = []
        for oc in out_channels:
            layers.extend([
                conv(in_channels, oc, 1, device=device),
                norm(oc, device=device),
                nn.ReLU(inplace=True),
            ])
            in_channels = oc  # the next layer consumes this layer's output
        self.layers = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the layer stack.

        Args:
            inputs: A tensor, or a list/tuple whose first element is the
                tensor to transform; any remaining elements are passed
                through unchanged.

        Returns:
            The transformed tensor or, for list/tuple input, a tuple of the
            transformed first element followed by the remaining elements.
        """
        if isinstance(inputs, (list, tuple)):
            return (self.layers(inputs[0]), *inputs[1:])
        return self.layers(inputs)
|
|