File size: 870 Bytes
5b9bb29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn

class MBDModule(nn.Module):
    """Multi-Branch Dilated Convolution Module.

    Projects the input to ``out_channels`` with a 1x1 convolution, runs the
    projection through parallel 3x3 convolutions that differ only in their
    dilation rate, concatenates the branch outputs along the channel axis and
    fuses them back to ``out_channels`` with a second 1x1 convolution.

    Every 3x3 branch uses ``padding == dilation`` with the default stride of 1,
    so each branch preserves the spatial size of its input and the branch
    outputs can be concatenated directly.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the projection, by each
            dilated branch, and by the final fusion convolution.
        dilations: Dilation rate of each parallel 3x3 branch (one branch per
            entry). Defaults to ``(1, 2, 4)``, which reproduces the original
            fixed three-branch layout.

    Raises:
        ValueError: If ``dilations`` is empty.
    """

    def __init__(self, in_channels, out_channels, dilations=(1, 2, 4)):
        super().__init__()

        dilations = tuple(dilations)
        if not dilations:
            raise ValueError("dilations must contain at least one rate")
        self.dilations = dilations

        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # padding=d together with dilation=d keeps H and W unchanged for a
        # 3x3 kernel, so all branches produce identically-sized maps.
        self.dilated_convs = nn.ModuleList([
            nn.Conv2d(out_channels, out_channels, kernel_size=3,
                      padding=d, dilation=d)
            for d in dilations
        ])

        # The fusion input width follows the number of branches rather than
        # a hard-coded 3, so it stays consistent with ``dilations``.
        self.fusion = nn.Conv2d(out_channels * len(dilations), out_channels,
                                kernel_size=1)

    def forward(self, x):
        """Apply projection, parallel dilated branches and fusion.

        Args:
            x: Input accepted by ``nn.Conv2d``, either batched ``(N, C, H, W)``
                or unbatched ``(C, H, W)``, with ``C == in_channels``.

        Returns:
            Tensor with ``out_channels`` channels and the same batch layout and
            spatial size as ``x``.
        """
        x = self.pointwise(x)
        branches = [conv(x) for conv in self.dilated_convs]
        # dim=-3 is the channel axis for both batched (N, C, H, W) and
        # unbatched (C, H, W) tensors; dim=1 would concatenate along H for
        # unbatched input and break the fusion convolution.
        x = torch.cat(branches, dim=-3)
        return self.fusion(x)