import torch import torch.nn as nn import torch.nn.functional as F from resnet import Resnet18 # Ensure that the Resnet18 class is correctly defined in this module class ConvBNReLU(nn.Module): def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): super(ConvBNReLU, self).__init__() self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) self.bn = nn.BatchNorm2d(out_chan) self.init_weight() def forward(self, x): x = self.conv(x) x = F.relu(self.bn(x)) return x def init_weight(self): nn.init.kaiming_normal_(self.conv.weight, a=1) if self.conv.bias is not None: nn.init.constant_(self.conv.bias, 0) class BiSeNetOutput(nn.Module): def __init__(self, in_chan, mid_chan, n_classes): super(BiSeNetOutput, self).__init__() self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) self.init_weight() def forward(self, x): x = self.conv(x) x = self.conv_out(x) return x def init_weight(self): nn.init.kaiming_normal_(self.conv_out.weight, a=1) if self.conv_out.bias is not None: nn.init.constant_(self.conv_out.bias, 0) def get_params(self): wd_params = [self.conv_out.weight] nowd_params = [] if self.conv_out.bias is not None: nowd_params.append(self.conv_out.bias) return wd_params, nowd_params class AttentionRefinementModule(nn.Module): def __init__(self, in_chan, out_chan): super(AttentionRefinementModule, self).__init__() self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) self.bn_atten = nn.BatchNorm2d(out_chan) self.sigmoid_atten = nn.Sigmoid() self.init_weight() def forward(self, x): feat = self.conv(x) atten = F.avg_pool2d(feat, feat.size()[2:]) atten = self.conv_atten(atten) atten = self.bn_atten(atten) atten = self.sigmoid_atten(atten) out = torch.mul(feat, atten) return out def init_weight(self): nn.init.kaiming_normal_(self.conv_atten.weight, a=1) if self.conv_atten.bias is not None: nn.init.constant_(self.conv_atten.bias, 0) class ContextPath(nn.Module): def __init__(self): super(ContextPath, self).__init__() self.resnet = Resnet18() self.arm16 = AttentionRefinementModule(256, 128) self.arm32 = AttentionRefinementModule(512, 128) self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) def forward(self, x): H0, W0 = x.size()[2:] feat8, feat16, feat32 = self.resnet(x) H8, W8 = feat8.size()[2:] H16, W16 = feat16.size()[2:] H32, W32 = feat32.size()[2:] avg = F.avg_pool2d(feat32, feat32.size()[2:]) avg = self.conv_avg(avg) avg_up = F.interpolate(avg, (H32, W32), mode='nearest') feat32_arm = self.arm32(feat32) feat32_sum = feat32_arm + avg_up feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') feat32_up = self.conv_head32(feat32_up) feat16_arm = self.arm16(feat16) feat16_sum = feat16_arm + feat32_up feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') feat16_up = self.conv_head16(feat16_up) return feat8, feat16_up, feat32_up def init_weight(self): for ly in self.children(): if isinstance(ly, nn.Conv2d): nn.init.kaiming_normal_(ly.weight, a=1) if ly.bias is not None: nn.init.constant_(ly.bias, 0) class FeatureFusionModule(nn.Module): def __init__(self, in_chan, out_chan): super(FeatureFusionModule, self).__init__() self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) self.relu = nn.ReLU(inplace=True) self.sigmoid = nn.Sigmoid() self.init_weight() def forward(self, fsp, fcp): fcat = torch.cat([fsp, fcp], dim=1) feat = self.convblk(fcat) atten = F.avg_pool2d(feat, feat.size()[2:]) atten = self.conv1(atten) atten = self.relu(atten) atten = self.conv2(atten) atten = self.sigmoid(atten) feat_atten = torch.mul(feat, atten) feat_out = feat_atten + feat return feat_out def init_weight(self): nn.init.kaiming_normal_(self.conv1.weight, a=1) if self.conv1.bias is not None: nn.init.constant_(self.conv1.bias, 0) nn.init.kaiming_normal_(self.conv2.weight, a=1) if self.conv2.bias is not None: nn.init.constant_(self.conv2.bias, 0) class BiSeNet(nn.Module): def __init__(self, n_classes): super(BiSeNet, self).__init__() self.cp = ContextPath() self.ffm = FeatureFusionModule(256, 256) self.conv_out = BiSeNetOutput(256, 256, n_classes) self.conv_out16 = BiSeNetOutput(128, 64, n_classes) self.conv_out32 = BiSeNetOutput(128, 64, n_classes) def forward(self, x): H, W = x.size()[2:] feat_res8, feat_cp8, feat_cp16 = self.cp(x) feat_sp = feat_res8 # Using res3b1 feature as spatial path feature feat_fuse = self.ffm(feat_sp, feat_cp8) feat_out = self.conv_out(feat_fuse) feat_out16 = self.conv_out16(feat_cp8) feat_out32 = self.conv_out32(feat_cp16) feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) return feat_out, feat_out16, feat_out32 def init_weight(self): for ly in self.children(): if isinstance(ly, nn.Conv2d): nn.init.kaiming_normal_(ly.weight, a=1) if ly.bias is not None: nn.init.constant_(ly.bias, 0) def get_params(self): wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] for name, child in