| """ |
| ISDNet decoder heads: ASPP, ISDHead, RefineASPPHead |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from mmcv.cnn import ConvModule |
|
|
| from .modules import ShallowNet, Lap_Pyramid_Conv |
| from ..utils import batch_mm_loop |
|
|
|
|
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling module.

    Holds one ``ConvModule`` per dilation rate: a 1x1 conv when the
    rate is 1, otherwise a 3x3 dilated conv whose padding equals the
    rate (so spatial size is preserved).
    """

    def __init__(self, dilations, in_ch, ch, conv_cfg, norm_cfg, act_cfg):
        branches = []
        for rate in dilations:
            is_pointwise = rate == 1
            branches.append(
                ConvModule(
                    in_ch,
                    ch,
                    1 if is_pointwise else 3,
                    dilation=rate,
                    padding=0 if is_pointwise else rate,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                )
            )
        super().__init__(branches)

    def forward(self, x):
        """Apply every dilated branch to ``x``; return the list of outputs."""
        outs = []
        for branch in self:
            outs.append(branch(x))
        return outs
|
|
|
|
class SegmentationHead(nn.Module):
    """Simple segmentation head: a 3x3 conv followed by a 1x1 classifier."""

    def __init__(self, conv_cfg, norm_cfg, act_cfg, in_ch, mid_ch, n_classes, **kw):
        super().__init__()
        # Extra keyword arguments are accepted for config compatibility
        # but are intentionally ignored.
        self.conv = ConvModule(
            in_ch,
            mid_ch,
            3,
            1,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )
        self.out = nn.Conv2d(mid_ch, n_classes, 1, bias=True)

    def forward(self, x):
        """Return per-class logits for feature map ``x``."""
        feat = self.conv(x)
        return self.out(feat)
|
|
|
|
class SRDecoder(nn.Module):
    """Super-resolution decoder used for the feature-alignment loss.

    Upsamples the deep feature map in three stages (factors given by
    ``up_lists``) and predicts a 3-channel reconstruction through a
    small segmentation head.

    Args:
        conv_cfg, norm_cfg, act_cfg: mmcv ``ConvModule`` layer configs.
        ch: input channel width (default 128).
        up_lists: per-stage upsampling factors. Default is now the
            immutable tuple ``(2, 2, 2)`` — the original used a mutable
            list default, a classic Python pitfall; callers passing a
            list still work identically.
    """

    def __init__(self, conv_cfg, norm_cfg, act_cfg, ch=128, up_lists=(2, 2, 2)):
        super().__init__()
        # nn.Upsample defaults to 'nearest' mode here, as in the original.
        self.up1 = nn.Upsample(scale_factor=up_lists[0])
        self.conv1 = ConvModule(ch, ch // 2, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up2 = nn.Upsample(scale_factor=up_lists[1])
        self.conv2 = ConvModule(ch // 2, ch // 2, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up3 = nn.Upsample(scale_factor=up_lists[2])
        self.conv3 = ConvModule(ch // 2, ch, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # 3 output channels: reconstructs an RGB-like high-frequency target.
        self.conv_sr = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, 3)

    def forward(self, x, fa=False):
        """Return the SR prediction; with ``fa=True`` also return the
        pre-head feature map (used by the feature-alignment loss)."""
        feats = self.conv3(self.up3(self.conv2(self.up2(self.conv1(self.up1(x))))))
        outs = self.conv_sr(feats)
        if fa:
            return feats, outs
        return outs
|
|
|
|
class ChannelAtt(nn.Module):
    """Channel attention module.

    Produces a refined feature map plus a per-channel descriptor
    obtained by global average pooling followed by a 1x1 conv.
    """

    def __init__(self, in_ch, out_ch, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.conv = ConvModule(
            in_ch, out_ch, 3, 1, 1,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # No activation on the attention branch (act_cfg=None).
        self.conv1x1 = ConvModule(
            out_ch, out_ch, 1, 1, 0,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, x):
        """Return ``(feat, att)``; ``att`` has shape (B, C, 1, 1)."""
        feat = self.conv(x)
        pooled = feat.mean(dim=(2, 3), keepdim=True)  # global average pool
        att = self.conv1x1(pooled)
        return feat, att
|
|
|
|
class RelationAwareFusion(nn.Module):
    """
    Relation-aware fusion module.

    Fuses shallow (spatial) and deep (context) features using
    cross-attention mechanism.
    """

    def __init__(self, ch, conv_cfg, norm_cfg, act_cfg, ext=2, r=16):
        super().__init__()
        self.r = r
        # Learnable, zero-initialized gates on the cross-relation terms.
        self.g1 = nn.Parameter(torch.zeros(1))
        self.g2 = nn.Parameter(torch.zeros(1))
        self.sp_mlp = nn.Sequential(
            nn.Linear(ch * 2, ch),
            nn.ReLU(),
            nn.Linear(ch, ch),
        )
        self.sp_att = ChannelAtt(ch * ext, ch, conv_cfg, norm_cfg, act_cfg)
        self.co_mlp = nn.Sequential(
            nn.Linear(ch * 2, ch),
            nn.ReLU(),
            nn.Linear(ch, ch),
        )
        self.co_att = ChannelAtt(ch, ch, conv_cfg, norm_cfg, act_cfg)
        self.co_head = ConvModule(ch, ch, 3, 1, 1,
                                  conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.smooth = ConvModule(ch, ch, 3, 1, 1,
                                 conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, sp_feat, co_feat):
        """Return ``(s_f, c_f, fused)`` — shallow feature, upsampled/
        recalibrated context feature, and their smoothed sum."""
        s_f, s_a = self.sp_att(sp_feat)
        c_f, c_a = self.co_att(co_feat)
        b, c = s_a.shape[:2]
        sub = c // self.r

        # Cross-affinity between the two channel descriptors, folded into
        # a (b, r*r) vector.  NOTE(review): r*r must equal ch*2 for the
        # MLPs below (true for the defaults r=16, ch=128) — confirm for
        # other configs.
        lhs = s_a.view(b, self.r, sub)
        rhs = c_a.view(b, self.r, sub).permute(0, 2, 1)
        aff = batch_mm_loop(lhs, rhs).view(b, -1)

        # Gated sigmoid recalibration of each branch's attention vector.
        sp_bias = F.relu(self.sp_mlp(aff))[..., None, None]
        co_bias = F.relu(self.co_mlp(aff))[..., None, None]
        re_s = torch.sigmoid(s_a + self.g1 * sp_bias)
        re_c = torch.sigmoid(c_a + self.g2 * co_bias)

        # Bring the context branch to the shallow branch's resolution.
        c_up = F.interpolate(c_f * re_c, s_f.shape[2:],
                             mode='bilinear', align_corners=False)
        c_f = self.co_head(c_up)
        fused = self.smooth(s_f * re_s + c_f)
        return s_f, c_f, fused
|
|
|
|
class Reducer(nn.Module):
    """Channel reducer: 1x1 conv (no bias) + SyncBN + ReLU."""

    def __init__(self, in_ch=512, reduce=128):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, reduce, kernel_size=1, bias=False)
        self.bn = nn.SyncBatchNorm(reduce)

    def forward(self, x):
        """Project ``x`` from ``in_ch`` to ``reduce`` channels."""
        y = self.conv(x)
        y = self.bn(y)
        return F.relu(y)
|
|
|
|
class ISDHead(nn.Module):
    """
    ISD decoder head.

    Combines shallow STDC features (computed from the high-frequency
    bands of a Laplacian pyramid of the input) with deep backbone
    features using relation-aware fusion at two scales.  During
    training it additionally returns two auxiliary predictions, an
    SR reconstruction loss and a feature-alignment loss.

    Args:
        in_ch, down_ratio, prev_ch: accepted for config compatibility;
            not referenced in this class body.
        ch: working channel width of the fusion modules.
        num_classes: number of segmentation classes.
        conv_cfg, norm_cfg, act_cfg: mmcv layer configs.
        dropout: dropout probability before the classifiers (0 disables).
        reduce: if True, reduce the deep feature with a ``Reducer``
            (512 -> 128 channels).
        stdc_pretrain: checkpoint path for the shallow STDC net.
    """

    def __init__(self, in_ch, ch, num_classes, down_ratio, prev_ch,
                 conv_cfg=None, norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU'),
                 dropout=0.1, reduce=False, stdc_pretrain=''):
        super().__init__()
        self.ch = ch
        self.fuse8 = RelationAwareFusion(ch, conv_cfg, norm_cfg, act_cfg, ext=2)
        self.fuse16 = RelationAwareFusion(ch, conv_cfg, norm_cfg, act_cfg, ext=4)
        self.sr_dec = SRDecoder(conv_cfg, norm_cfg, act_cfg, ch, [4, 2, 2])
        # 6 input channels: two concatenated 3-channel pyramid bands.
        self.stdc = ShallowNet(in_channels=6, pretrain_model=stdc_pretrain)
        self.lap = Lap_Pyramid_Conv(num_high=2)
        self.seg_aux16 = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.seg_aux8 = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.seg = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.reduce = Reducer() if reduce else None
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, inputs, prev_output, train_flag=True):
        """Run the head.

        Returns the main logits in inference mode; in training mode a
        5-tuple ``(out, aux16, aux8, recon_loss_dict, fa_loss_dict)``.
        """
        # High-frequency input: top two Laplacian bands concatenated (6 ch).
        pyr = self.lap.pyramid_decom(inputs)
        pyr1_up = F.interpolate(pyr[1], pyr[0].shape[2:], mode='bilinear', align_corners=False)
        high_in = torch.cat([pyr[0], pyr1_up], dim=1)

        # Shallow features (presumably strides 8 and 16 of the STDC net
        # — TODO confirm against ShallowNet).
        s8, s16 = self.stdc(high_in)

        # Deep context feature from the deep-path head, optionally reduced.
        deep = self.reduce(prev_output[0]) if self.reduce else prev_output[0]

        # Fuse deep context into shallow features, coarse to fine.
        _, a16, f16 = self.fuse16(s16, deep)
        _, a8, f8 = self.fuse8(s8, f16)

        out = self.seg(self.drop(f8) if self.drop else f8)

        if train_flag:
            feats, sr_out = self.sr_dec(deep, True)
            # Reconstruction target: sum of the two high-frequency bands.
            target = pyr[0] + pyr1_up
            if sr_out.shape[2:] != target.shape[2:]:
                sr_out = F.interpolate(sr_out, target.shape[2:], mode='bilinear', align_corners=False)
            # BUGFIX: the auxiliary heads were cross-wired — seg_aux16
            # received the stride-8 fusion output and seg_aux8 the
            # stride-16 one.  Each head now supervises the scale it is
            # named after (shapes are identical, so callers are unaffected).
            return (out,
                    self.seg_aux16(a16),
                    self.seg_aux8(a8),
                    {'recon_losses': F.mse_loss(sr_out, target) * 0.1},
                    {'fa_loss': self._fa(deep, feats)})
        return out

    def _fa(self, seg_f, sr_f, eps=1e-6):
        """Feature alignment loss.

        L1 distance between the spatial self-affinity matrices of the
        (channel-normalized) segmentation and SR feature maps.  The SR
        side is detached, so gradients shape only the segmentation path.
        """
        if seg_f.shape[2:] != sr_f.shape[2:]:
            sr_f = F.interpolate(sr_f, seg_f.shape[2:], mode='bilinear', align_corners=False)
        sf = torch.flatten(seg_f, 2)
        srf = torch.flatten(sr_f, 2)
        # Normalize each channel's spatial vector; eps avoids div-by-zero.
        sf = sf / (sf.norm(p=2, dim=2, keepdim=True) + eps)
        srf = srf / (srf.norm(p=2, dim=2, keepdim=True) + eps)

        sf_t = sf.permute(0, 2, 1)
        srf_t = srf.permute(0, 2, 1)
        return F.l1_loss(batch_mm_loop(sf_t, sf), batch_mm_loop(srf_t, srf).detach())
|
|
|
|
class RefineASPPHead(nn.Module):
    """
    ASPP-based decoder head for deep path.

    Processes low-resolution backbone features with
    atrous spatial pyramid pooling.
    """

    def __init__(self, in_ch, ch, num_classes, dilations=(1, 12, 24, 36),
                 conv_cfg=None, norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU'),
                 dropout=0.1, in_index=-1):
        super().__init__()
        self.in_index = in_index
        # Image-level branch: global pool then 1x1 projection.
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(in_ch, ch, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg),
        )
        self.aspp = ASPPModule(dilations, in_ch, ch, conv_cfg, norm_cfg, act_cfg)
        # One branch per dilation plus the pooled branch feed the bottleneck.
        self.bottle = ConvModule(
            (len(dilations) + 1) * ch, ch, 3, padding=1,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg,
        )
        self.seg = nn.Conv2d(ch, num_classes, 1)
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, inputs):
        """Return ``(logits, [bottleneck_feature])``."""
        if isinstance(inputs, (list, tuple)):
            x = inputs[self.in_index]
        else:
            x = inputs
        pooled = F.interpolate(self.pool(x), x.shape[2:],
                               mode='bilinear', align_corners=False)
        branches = [pooled]
        branches.extend(self.aspp(x))
        feat = self.bottle(torch.cat(branches, dim=1))
        logits = self.seg(self.drop(feat)) if self.drop else self.seg(feat)
        return logits, [feat]
|
|