File size: 870 Bytes
5b9bb29 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | import torch
import torch.nn as nn
class MBDModule(nn.Module):
    """Multi-Branch Dilated Convolution Module.

    Projects the input to ``out_channels`` with a 1x1 convolution, runs the
    result through parallel 3x3 convolutions with different dilation rates,
    concatenates the branch outputs along the channel axis, and fuses them
    back to ``out_channels`` with another 1x1 convolution.

    Each 3x3 branch uses ``padding == dilation``; with the default stride of
    1 this keeps the spatial size unchanged, so all branch outputs line up
    for concatenation.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the pointwise
            projection, by every dilated branch, and by the fusion layer.
        dilations: Dilation rate of each parallel 3x3 branch. Defaults to
            ``(1, 2, 4)``, the rates the module originally hard-coded.

    Raises:
        ValueError: If ``dilations`` is empty.

    Shape:
        Input: ``(N, in_channels, H, W)`` or unbatched ``(in_channels, H, W)``.
        Output: ``(N, out_channels, H, W)`` or ``(out_channels, H, W)``.
    """

    def __init__(self, in_channels, out_channels, dilations=(1, 2, 4)):
        super().__init__()
        dilations = tuple(dilations)
        if not dilations:
            # torch.cat over zero branches would fail later in forward().
            raise ValueError("dilations must contain at least one rate")
        self.dilations = dilations
        # Submodule names and creation order match the original so that
        # state_dict keys and seeded initialization stay the same.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # padding=d with kernel_size=3 and dilation=d preserves H and W.
        self.dilated_convs = nn.ModuleList([
            nn.Conv2d(out_channels, out_channels, kernel_size=3,
                      padding=d, dilation=d)
            for d in dilations
        ])
        # Concatenated branch outputs carry out_channels per branch.
        self.fusion = nn.Conv2d(
            out_channels * len(dilations), out_channels, kernel_size=1
        )

    def forward(self, x):
        """Apply projection, parallel dilated branches, and 1x1 fusion.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)`` or
                ``(in_channels, H, W)``.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size
            (and batch presence) as ``x``.
        """
        x = self.pointwise(x)
        # dim=-3 is the channel axis for both batched (N, C, H, W) and
        # unbatched (C, H, W) inputs; dim=1 would be H in the latter case.
        x = torch.cat([conv(x) for conv in self.dilated_convs], dim=-3)
        return self.fusion(x)