import torch
import torch.nn as nn


class MBDModule(nn.Module):
    """Multi-Branch Dilated Convolution Module.

    The input is first projected to ``out_channels`` with a 1x1 convolution.
    The projected features are then passed through several parallel 3x3
    convolutions, one per dilation rate; the branch outputs are concatenated
    along the channel dimension and reduced back to ``out_channels`` with a
    second 1x1 convolution.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by every stage of the
            module (projection, each branch, and the fused output).
        dilations: Dilation rate of each parallel branch. Defaults to
            ``(1, 2, 4)``.

    Raises:
        ValueError: If ``dilations`` is empty.
    """

    def __init__(self, in_channels: int, out_channels: int, dilations=(1, 2, 4)):
        super().__init__()
        dilations = tuple(dilations)
        if not dilations:
            raise ValueError("dilations must contain at least one rate")

        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # For a 3x3 kernel, padding equal to the dilation rate keeps the
        # spatial size unchanged, so every branch output can be concatenated.
        self.dilated_convs = nn.ModuleList(
            [
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=rate,
                    dilation=rate,
                )
                for rate in dilations
            ]
        )
        # The fusion layer sees one ``out_channels`` slice per branch.
        self.fusion = nn.Conv2d(
            out_channels * len(dilations), out_channels, kernel_size=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply projection, parallel dilated branches, and 1x1 fusion.

        Branch outputs are concatenated on ``dim=1``, so ``x`` is expected to
        be a batched tensor whose dimension 1 is the channel dimension.

        Args:
            x: Input feature map with ``in_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size
            as ``x``.
        """
        projected = self.pointwise(x)
        branches = [conv(projected) for conv in self.dilated_convs]
        return self.fusion(torch.cat(branches, dim=1))