| | import torch.nn as nn |
| |
|
__all__ = ['SharedMLP']


class SharedMLP(nn.Module):
    """Stack of pointwise (kernel size 1) conv -> InstanceNorm -> ReLU layers.

    Because every convolution has kernel size 1, the same weights are applied
    independently at every position along the non-channel axes, i.e. a
    per-position MLP whose weights are shared across positions.
    """

    def __init__(self, in_channels, out_channels, dim=1, device='cuda'):
        """Build the layer stack.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: An int, or a list/tuple of ints giving the output
                channels of each successive conv/norm/ReLU stage. An empty
                list/tuple yields an empty ``nn.Sequential``, which returns
                its input unchanged.
            dim: 1 to use ``nn.Conv1d`` + ``nn.InstanceNorm1d``; 2 to use
                ``nn.Conv2d`` + ``nn.InstanceNorm2d``.
            device: Device passed to every conv and norm layer at
                construction time.

        Raises:
            ValueError: If ``dim`` is not 1 or 2.
        """
        super().__init__()

        if dim == 1:
            conv, norm = nn.Conv1d, nn.InstanceNorm1d
        elif dim == 2:
            # The norm must match the conv's dimensionality: InstanceNorm1d
            # rejects the 4D tensors that Conv2d produces.
            conv, norm = nn.Conv2d, nn.InstanceNorm2d
        else:
            raise ValueError(f"dim must be 1 or 2, got {dim!r}")

        if not isinstance(out_channels, (list, tuple)):
            out_channels = [out_channels]

        layers = []
        for oc in out_channels:
            layers.extend([
                conv(in_channels, oc, 1, device=device),
                norm(oc, device=device),
                nn.ReLU(inplace=True),
            ])
            # The next stage consumes this stage's output channels.
            in_channels = oc
        self.layers = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the layer stack.

        If ``inputs`` is a list or tuple, only its first element is passed
        through the layers. The remaining elements are returned unchanged,
        together with the result, as a tuple. Otherwise the layers are
        applied to ``inputs`` directly and the resulting tensor is returned.
        """
        if isinstance(inputs, (list, tuple)):
            return (self.layers(inputs[0]), *inputs[1:])
        return self.layers(inputs)
| |
|