| import torch |
| import torch.nn as nn |
| class MultitaskHead(nn.Module): |
| def __init__(self, input_channels, num_class, head_size): |
| super(MultitaskHead, self).__init__() |
|
|
| m = int(input_channels / 4) |
| heads = [] |
| for output_channels in sum(head_size, []): |
| heads.append( |
| nn.Sequential( |
| nn.Conv2d(input_channels, m, kernel_size=3, padding=1), |
| nn.ReLU(inplace=True), |
| nn.Conv2d(m, output_channels, kernel_size=1), |
| ) |
| ) |
| self.heads = nn.ModuleList(heads) |
| assert num_class == sum(sum(head_size, [])) |
|
|
| def forward(self, x): |
| |
| return torch.cat([head(x) for head in self.heads], dim=1) |
|
|
|
|
| class AngleDistanceHead(nn.Module): |
| def __init__(self, input_channels, num_class, head_size): |
| super(AngleDistanceHead, self).__init__() |
|
|
| m = int(input_channels/4) |
|
|
| heads = [] |
| for output_channels in sum(head_size, []): |
| if output_channels != 2: |
| heads.append( |
| nn.Sequential( |
| nn.Conv2d(input_channels, m, kernel_size=3, padding=1), |
| nn.ReLU(inplace=True), |
| nn.Conv2d(m, output_channels, kernel_size=1), |
| ) |
| ) |
| else: |
| heads.append( |
| nn.Sequential( |
| nn.Conv2d(input_channels, m, kernel_size=3, padding=1), |
| nn.ReLU(inplace=True), |
| CosineSineLayer(m) |
| ) |
| ) |
| self.heads = nn.ModuleList(heads) |
| assert num_class == sum(sum(head_size, [])) |
| def forward(self, x): |
| return torch.cat([head(x) for head in self.heads], dim=1) |