# Uploaded via huggingface_hub by Antoine1091 (commit 49d2955, verified)
"""
ISDNet decoder heads: ASPP, ISDHead, RefineASPPHead
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from .modules import ShallowNet, Lap_Pyramid_Conv
from ..utils import batch_mm_loop
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling module.

    Builds one conv branch per dilation rate: a 1x1 conv when the rate
    is 1, otherwise a 3x3 dilated conv whose padding equals the rate
    (so spatial size is preserved).
    """

    def __init__(self, dilations, in_ch, ch, conv_cfg, norm_cfg, act_cfg):
        branches = []
        for rate in dilations:
            plain = rate == 1
            branches.append(
                ConvModule(
                    in_ch,
                    ch,
                    1 if plain else 3,
                    dilation=rate,
                    padding=0 if plain else rate,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                )
            )
        super().__init__(branches)

    def forward(self, x):
        """Run every branch on ``x``; return the list of branch outputs."""
        outs = []
        for branch in self:
            outs.append(branch(x))
        return outs
class SegmentationHead(nn.Module):
    """Segmentation head: one 3x3 ConvModule followed by a 1x1 classifier."""

    def __init__(self, conv_cfg, norm_cfg, act_cfg, in_ch, mid_ch, n_classes, **kw):
        super().__init__()
        # 3x3 conv (stride 1, padding 1) keeps spatial size
        self.conv = ConvModule(
            in_ch, mid_ch, 3, 1, 1,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # plain 1x1 conv producing per-class logits
        self.out = nn.Conv2d(mid_ch, n_classes, 1, bias=True)

    def forward(self, x):
        hidden = self.conv(x)
        return self.out(hidden)
class SRDecoder(nn.Module):
    """Super-resolution decoder used for the feature-alignment loss.

    Upsamples the deep feature three times by the factors in
    ``up_lists`` (conv after each upsample), then projects to a
    3-channel reconstruction via a small segmentation head.

    Args:
        conv_cfg, norm_cfg, act_cfg: mmcv ConvModule config dicts.
        ch: channel width of the input feature (default 128).
        up_lists: three per-stage upsampling factors.
            BUGFIX: the default was a mutable list (``[2, 2, 2]``), a
            classic shared-mutable-default pitfall; a tuple is
            index-compatible and immutable.
    """
    def __init__(self, conv_cfg, norm_cfg, act_cfg, ch=128, up_lists=(2, 2, 2)):
        super().__init__()
        self.up1 = nn.Upsample(scale_factor=up_lists[0])
        self.conv1 = ConvModule(ch, ch // 2, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up2 = nn.Upsample(scale_factor=up_lists[1])
        self.conv2 = ConvModule(ch // 2, ch // 2, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up3 = nn.Upsample(scale_factor=up_lists[2])
        self.conv3 = ConvModule(ch // 2, ch, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # 3 output channels: reconstructs an RGB-like target
        self.conv_sr = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, 3)

    def forward(self, x, fa=False):
        """Return the SR reconstruction; with ``fa=True`` also return the
        pre-head features (used by the feature-alignment loss)."""
        feats = self.conv3(self.up3(self.conv2(self.up2(self.conv1(self.up1(x))))))
        if fa:
            return feats, self.conv_sr(feats)
        return self.conv_sr(feats)
class ChannelAtt(nn.Module):
    """Channel attention: returns a feature map and its channel descriptor.

    The descriptor is the spatial mean of the feature passed through a
    1x1 conv (no activation), shape ``(B, out_ch, 1, 1)``.
    """

    def __init__(self, in_ch, out_ch, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.conv = ConvModule(in_ch, out_ch, 3, 1, 1,
                               conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # no activation on the attention projection
        self.conv1x1 = ConvModule(out_ch, out_ch, 1, 1, 0,
                                  conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, x):
        feat = self.conv(x)
        pooled = feat.mean(dim=(2, 3), keepdim=True)
        att = self.conv1x1(pooled)
        return feat, att
class RelationAwareFusion(nn.Module):
    """
    Relation-aware fusion module.

    Fuses shallow (spatial) and deep (context) features: both streams
    produce channel descriptors, a grouped affinity between them drives
    MLP-generated attention biases, and the reweighted streams are
    summed and smoothed.
    """

    def __init__(self, ch, conv_cfg, norm_cfg, act_cfg, ext=2, r=16):
        super().__init__()
        self.r = r
        # learnable gates, start at zero so fusion begins as identity
        self.g1 = nn.Parameter(torch.zeros(1))
        self.g2 = nn.Parameter(torch.zeros(1))
        self.sp_mlp = nn.Sequential(
            nn.Linear(ch * 2, ch),
            nn.ReLU(),
            nn.Linear(ch, ch),
        )
        self.sp_att = ChannelAtt(ch * ext, ch, conv_cfg, norm_cfg, act_cfg)
        self.co_mlp = nn.Sequential(
            nn.Linear(ch * 2, ch),
            nn.ReLU(),
            nn.Linear(ch, ch),
        )
        self.co_att = ChannelAtt(ch, ch, conv_cfg, norm_cfg, act_cfg)
        self.co_head = ConvModule(ch, ch, 3, 1, 1,
                                  conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.smooth = ConvModule(ch, ch, 3, 1, 1,
                                 conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, sp_feat, co_feat):
        s_f, s_a = self.sp_att(sp_feat)
        c_f, c_a = self.co_att(co_feat)
        batch, channels = s_a.shape[0], s_a.shape[1]
        groups = self.r
        # Grouped affinity between the two channel descriptors; the
        # loop-based batch mm avoids CUBLAS strided-batched issues.
        lhs = s_a.reshape(batch, groups, channels // groups)
        rhs = c_a.reshape(batch, groups, channels // groups).transpose(1, 2)
        affinity = batch_mm_loop(lhs, rhs).reshape(batch, -1)
        sp_bias = F.relu(self.sp_mlp(affinity))[..., None, None]
        co_bias = F.relu(self.co_mlp(affinity))[..., None, None]
        re_s = torch.sigmoid(s_a + self.g1 * sp_bias)
        re_c = torch.sigmoid(c_a + self.g2 * co_bias)
        # bring the reweighted context stream up to the spatial resolution
        c_up = F.interpolate(c_f * re_c, s_f.shape[2:],
                             mode='bilinear', align_corners=False)
        c_f = self.co_head(c_up)
        fused = self.smooth(s_f * re_s + c_f)
        return s_f, c_f, fused
class Reducer(nn.Module):
    """Channel reducer: 1x1 conv -> SyncBatchNorm -> ReLU."""

    def __init__(self, in_ch=512, reduce=128):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, reduce, kernel_size=1, bias=False)
        self.bn = nn.SyncBatchNorm(reduce)

    def forward(self, x):
        reduced = self.conv(x)
        normed = self.bn(reduced)
        return F.relu(normed)
class ISDHead(nn.Module):
    """
    ISD decoder head.

    Decomposes the input with a Laplacian pyramid, extracts shallow
    STDC features from the high-frequency residuals, and fuses them
    with deep backbone features via relation-aware fusion at the 1/16
    and 1/8 scales. During training it additionally returns two
    auxiliary predictions, a super-resolution reconstruction loss and a
    feature-alignment loss.
    """
    def __init__(self, in_ch, ch, num_classes, down_ratio, prev_ch,
                 conv_cfg=None, norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU'),
                 dropout=0.1, reduce=False, stdc_pretrain=''):
        super().__init__()
        self.ch = ch
        self.fuse8 = RelationAwareFusion(ch, conv_cfg, norm_cfg, act_cfg, ext=2)
        self.fuse16 = RelationAwareFusion(ch, conv_cfg, norm_cfg, act_cfg, ext=4)
        self.sr_dec = SRDecoder(conv_cfg, norm_cfg, act_cfg, ch, [4, 2, 2])
        # 6 input channels: concat of the two high-frequency residuals
        self.stdc = ShallowNet(in_channels=6, pretrain_model=stdc_pretrain)
        self.lap = Lap_Pyramid_Conv(num_high=2)
        self.seg_aux16 = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.seg_aux8 = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.seg = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.reduce = Reducer() if reduce else None
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, inputs, prev_output, train_flag=True):
        """Run the head.

        Args:
            inputs: full-resolution input image tensor.
            prev_output: deep-path features; element 0 is consumed.
            train_flag: when True, also return aux outputs and losses.

        Returns:
            Inference: the main segmentation logits. Training: a tuple
            ``(out, aux16, aux8, {'recon_losses': ...}, {'fa_loss': ...})``.
        """
        # Laplacian pyramid decomposition
        pyr = self.lap.pyramid_decom(inputs)
        pyr1_up = F.interpolate(pyr[1], pyr[0].shape[2:], mode='bilinear', align_corners=False)
        high_in = torch.cat([pyr[0], pyr1_up], dim=1)
        # Shallow features
        s8, s16 = self.stdc(high_in)
        # Deep features
        deep = self.reduce(prev_output[0]) if self.reduce else prev_output[0]
        # Multi-scale fusion
        _, a16, f16 = self.fuse16(s16, deep)
        _, a8, f8 = self.fuse8(s8, f16)
        # Segmentation output
        out = self.seg(self.drop(f8) if self.drop else f8)
        if train_flag:
            feats, sr_out = self.sr_dec(deep, True)
            # reconstruction target: sum of the high-frequency residuals
            target = pyr[0] + pyr1_up
            if sr_out.shape[2:] != target.shape[2:]:
                sr_out = F.interpolate(sr_out, target.shape[2:], mode='bilinear', align_corners=False)
            # BUGFIX: the aux heads previously received swapped inputs
            # (seg_aux16 got a8, the 1/8-scale fusion, and vice versa);
            # each aux head now supervises the feature of its own scale.
            return (out,
                    self.seg_aux16(a16),
                    self.seg_aux8(a8),
                    {'recon_losses': F.mse_loss(sr_out, target) * 0.1},
                    {'fa_loss': self._fa(deep, feats)})
        return out

    def _fa(self, seg_f, sr_f, eps=1e-6):
        """Feature alignment loss: L1 between the channel self-similarity
        (Gram) matrices of the deep features and the SR features."""
        if seg_f.shape[2:] != sr_f.shape[2:]:
            sr_f = F.interpolate(sr_f, seg_f.shape[2:], mode='bilinear', align_corners=False)
        sf = torch.flatten(seg_f, 2)
        srf = torch.flatten(sr_f, 2)
        # L2-normalize along the spatial axis before the similarity product
        sf = sf / (sf.norm(p=2, dim=2, keepdim=True) + eps)
        srf = srf / (srf.norm(p=2, dim=2, keepdim=True) + eps)
        # Use loop-based batch mm for CUBLAS compatibility
        sf_t = sf.permute(0, 2, 1)
        srf_t = srf.permute(0, 2, 1)
        # SR-side similarity is detached: it serves as the target
        return F.l1_loss(batch_mm_loop(sf_t, sf), batch_mm_loop(srf_t, srf).detach())
def _fa(self, seg_f, sr_f, eps=1e-6):
"""Feature alignment loss."""
if seg_f.shape[2:] != sr_f.shape[2:]:
sr_f = F.interpolate(sr_f, seg_f.shape[2:], mode='bilinear', align_corners=False)
sf = torch.flatten(seg_f, 2)
srf = torch.flatten(sr_f, 2)
sf = sf / (sf.norm(p=2, dim=2, keepdim=True) + eps)
srf = srf / (srf.norm(p=2, dim=2, keepdim=True) + eps)
# Use loop-based batch mm for CUBLAS compatibility
sf_t = sf.permute(0, 2, 1)
srf_t = srf.permute(0, 2, 1)
return F.l1_loss(batch_mm_loop(sf_t, sf), batch_mm_loop(srf_t, srf).detach())
class RefineASPPHead(nn.Module):
    """
    ASPP-based decoder head for the deep path.

    Combines an image-level (global-pooled) branch with the dilated
    ASPP branches, bottlenecks the concatenation, and predicts
    per-pixel class logits. Returns ``(logits, [bottleneck_feature])``.
    """

    def __init__(self, in_ch, ch, num_classes, dilations=(1, 12, 24, 36),
                 conv_cfg=None, norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU'),
                 dropout=0.1, in_index=-1):
        super().__init__()
        self.in_index = in_index
        # image-level context: global average pool + 1x1 projection
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(in_ch, ch, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg),
        )
        self.aspp = ASPPModule(dilations, in_ch, ch, conv_cfg, norm_cfg, act_cfg)
        n_branches = len(dilations) + 1  # ASPP branches + pooled branch
        self.bottle = ConvModule(
            n_branches * ch, ch, 3, padding=1,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg,
        )
        self.seg = nn.Conv2d(ch, num_classes, 1)
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, inputs):
        if isinstance(inputs, (list, tuple)):
            x = inputs[self.in_index]
        else:
            x = inputs
        pooled = self.pool(x)
        branches = [F.interpolate(pooled, x.shape[2:], mode='bilinear', align_corners=False)]
        branches += self.aspp(x)
        feat = self.bottle(torch.cat(branches, dim=1))
        head_in = self.drop(feat) if self.drop else feat
        return self.seg(head_in), [feat]