|
|
import torch |
|
|
from torch import nn |
|
|
from addict import Dict |
|
|
|
|
|
from rscd.models.decoderheads.pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder4ScalesFASeg |
|
|
from rscd.models.decoderheads.transformer_decoder import MultiScaleMaskedTransformerDecoder_OurDH_v4,MultiScaleMaskedTransformerDecoder_OurDH_v5 |
|
|
|
|
|
from torch.nn import functional as F |
|
|
|
|
|
class MaskFormerHead(nn.Module):
    """Segmentation head pairing a pixel decoder with a transformer predictor.

    The pixel decoder is an ``MSDeformAttnPixelDecoder4ScalesFASeg`` and the
    predictor is a ``MultiScaleMaskedTransformerDecoder_OurDH_v5``; both are
    built with fixed hyper-parameters, only ``num_classes``, ``num_queries``
    and ``dec_layers`` are configurable.

    Args:
        input_shape: Backbone feature description forwarded unchanged to the
            pixel decoder.
        num_classes: Number of classes handed to the predictor.
        num_queries: Number of queries handed to the predictor.
        dec_layers: Configured decoder depth; the predictor receives
            ``dec_layers - 1``.
    """

    def __init__(self, input_shape, num_classes=1, num_queries=10, dec_layers=10):
        super().__init__()
        # Stored first: predictor_init() reads these attributes.
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.dec_layers = dec_layers

        self.pixel_decoder = self.pixel_decoder_init(input_shape)
        self.predictor = self.predictor_init()

    def pixel_decoder_init(self, input_shape):
        """Create the pixel decoder with this head's fixed settings.

        Arguments are passed positionally; the trailing comments give the
        role of each value.
        """
        return MSDeformAttnPixelDecoder4ScalesFASeg(
            input_shape,
            0,                         # transformer_dropout
            8,                         # transformer_nheads
            1024,                      # transformer_dim_feedforward
            4,                         # transformer_enc_layers
            256,                       # conv_dim
            256,                       # mask_dim
            ["res3", "res4", "res5"],  # transformer_in_features
            4,                         # common_stride
        )

    def predictor_init(self):
        """Create the transformer predictor from the stored configuration.

        Arguments are passed positionally; the trailing comments give the
        role of each value.
        """
        # NOTE(review): one layer fewer than configured is passed on; the
        # reason is not visible here -- confirm against the decoder.
        effective_layers = self.dec_layers - 1
        return MultiScaleMaskedTransformerDecoder_OurDH_v5(
            256,               # in_channels
            self.num_classes,  # num_classes
            True,              # mask_classification
            256,               # hidden_dim
            self.num_queries,  # num_queries
            8,                 # nheads
            1024,              # dim_feedforward
            effective_layers,  # dec_layers
            False,             # pre_norm
            256,               # mask_dim
            False,             # enforce_input_project
        )

    def forward(self, features, mask=None):
        """Run the pixel decoder, then the predictor, and return its output.

        Args:
            features: Backbone features passed to
                ``pixel_decoder.forward_features``.
            mask: Optional value forwarded as the predictor's third argument.

        Returns:
            Whatever the predictor returns.
        """
        # forward_features yields four values; the second one (the
        # transformer encoder features) is not used by the predictor.
        mask_features, _, multi_scale_features, pos_list_2d = (
            self.pixel_decoder.forward_features(features)
        )
        return self.predictor(multi_scale_features, mask_features, mask, pos_list_2d)
|
|
|
|
|
def dsconv_3x3(in_channel, out_channel):
    """Build a depthwise-separable 3x3 convolution block.

    The returned ``nn.Sequential`` applies, in order:

    0. a 3x3 depthwise convolution (``groups=in_channel``, stride 1,
       padding 1, so the spatial size is preserved),
    1. a 1x1 pointwise convolution from ``in_channel`` to ``out_channel``,
    2. batch normalisation over ``out_channel`` channels,
    3. an in-place ReLU.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    depthwise = nn.Conv2d(
        in_channel,
        in_channel,
        kernel_size=3,
        stride=1,
        padding=1,
        groups=in_channel,
    )
    pointwise = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=1,
        stride=1,
        padding=0,
        groups=1,
    )
    normalise = nn.BatchNorm2d(out_channel)
    activate = nn.ReLU(inplace=True)
    return nn.Sequential(depthwise, pointwise, normalise, activate)
|
|
|
|
|
class SaELayer(nn.Module):
    """Channel gate built from four parallel squeeze branches.

    The input is globally average-pooled to one value per channel. Four
    independent bias-free ``Linear`` + ``ReLU`` branches (``fc1`` .. ``fc4``)
    each reduce it to ``in_channel // reduction`` features. The branch outputs
    are concatenated and mapped back to ``in_channel`` values in ``(0, 1)`` by
    a bias-free ``Linear`` + ``Sigmoid`` (``fc``).

    Note that ``forward`` returns the gate itself, broadcast to the input
    shape; it does not multiply it into the input.

    Args:
        in_channel: Number of input channels. Must be at least ``reduction``
            and divisible by it.
        reduction: Channel reduction factor used by each branch.
    """

    def __init__(self, in_channel, reduction=32):
        super(SaELayer, self).__init__()
        assert in_channel >= reduction and in_channel % reduction == 0, 'invalid in_channel in SaElayer'

        self.reduction = reduction
        self.cardinality = 4
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        squeezed = in_channel // self.reduction
        # Register the branches as fc1, fc2, fc3, fc4 (in this order) so the
        # parameter names and initialisation order stay stable.
        for index in range(1, self.cardinality + 1):
            branch = nn.Sequential(
                nn.Linear(in_channel, squeezed, bias=False),
                nn.ReLU(inplace=True),
            )
            setattr(self, f"fc{index}", branch)

        self.fc = nn.Sequential(
            nn.Linear(squeezed * self.cardinality, in_channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Compute the per-channel gate for ``x``.

        Args:
            x: 4-D tensor; the first two dimensions are used as batch and
                channel.

        Returns:
            The gate, viewed as ``(batch, channel, 1, 1)`` and expanded to
            the shape of ``x``.
        """
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)

        branches = (self.fc1, self.fc2, self.fc3, self.fc4)
        gathered = torch.cat([branch(pooled) for branch in branches], dim=1)

        gate = self.fc(gathered).view(batch, channels, 1, 1)
        return gate.expand_as(x)
|
|
|
|
|
class TFF(nn.Module):
    """Fuse two same-shaped feature maps using their difference.

    The element-wise difference of the two inputs drives a per-channel gate
    (``SaELayer``) and, after being concatenated with each input, a
    single-channel spatial gate per input. Both inputs are multiplied by
    their spatial gate and the shared channel gate, concatenated along the
    channel axis and projected to ``out_channel`` channels.

    Args:
        in_channel: Channel count of each input feature map.
        out_channel: Channel count of the fused output.
    """

    def __init__(self, in_channel, out_channel):
        super(TFF, self).__init__()
        paired = in_channel * 2  # channel count after concatenating two maps

        self.catconvA = dsconv_3x3(paired, in_channel)
        self.catconvB = dsconv_3x3(paired, in_channel)
        self.catconv = dsconv_3x3(paired, out_channel)

        # 1x1 convolutions producing one spatial gate channel per input.
        self.convA = nn.Conv2d(in_channel, 1, 1)
        self.convB = nn.Conv2d(in_channel, 1, 1)
        self.sigmoid = nn.Sigmoid()

        self.senetv2 = SaELayer(in_channel)

    def forward(self, xA, xB):
        """Return the fused feature map for inputs ``xA`` and ``xB``."""
        difference = xA - xB
        channel_gate = self.senetv2(difference)

        mixed_a = self.catconvA(torch.cat([difference, xA], dim=1))
        mixed_b = self.catconvB(torch.cat([difference, xB], dim=1))

        gate_a = self.sigmoid(self.convA(mixed_a))
        gate_b = self.sigmoid(self.convB(mixed_b))

        gated_a = gate_a * xA * channel_gate
        gated_b = gate_b * xB * channel_gate

        return self.catconv(torch.cat([gated_a, gated_b], dim=1))
|
|
|
|
|
class MaskFormerModel_sea_ourDH(nn.Module):
    """Two-input model: per-level ``TFF`` fusion followed by a ``MaskFormerHead``.

    Four feature levels from two inputs are fused pairwise with ``TFF``
    modules, exposed to the head under the keys ``res2`` .. ``res5``, and the
    head's mask predictions are turned into per-class maps normalised to
    ``[0, 1]``.

    Args:
        channels: Channel counts of the backbone levels. Level ``i`` is
            registered as ``res{i+2}`` with stride ``2 ** (i + 2)``. The
            first four entries are used to build ``tff1`` .. ``tff4``.
        num_classes: Forwarded to ``MaskFormerHead``.
        num_queries: Forwarded to ``MaskFormerHead``.
        dec_layers: Forwarded to ``MaskFormerHead``.
    """

    def __init__(self, channels, num_classes=1, num_queries=10, dec_layers=14):
        super().__init__()
        self.channels = channels

        # Describe each backbone level for the pixel decoder.
        self.backbone_feature_shape = dict()
        for i, channel in enumerate(self.channels):
            self.backbone_feature_shape[f'res{i+2}'] = Dict(
                {'channel': channel, 'stride': 2 ** (i + 2)}
            )

        # One fusion module per level; each keeps the level's channel count.
        self.tff1 = TFF(self.channels[0], self.channels[0])
        self.tff2 = TFF(self.channels[1], self.channels[1])
        self.tff3 = TFF(self.channels[2], self.channels[2])
        self.tff4 = TFF(self.channels[3], self.channels[3])

        self.sem_seg_head = MaskFormerHead(
            self.backbone_feature_shape, num_classes, num_queries, dec_layers
        )

    def semantic_inference(self, mask_cls, mask_pred):
        """Combine class scores and mask logits into normalised class maps.

        ``mask_cls`` is soft-maxed over its last dimension and index 0 of
        that dimension is dropped; ``mask_pred`` goes through a sigmoid. The
        two are contracted over the query dimension
        (``"bqc,bqhw->bchw"``), detached from the autograd graph, and each
        ``(batch, class)`` map is min-max normalised independently.

        A map whose values are all equal has zero range; it is returned as
        all zeros instead of NaN (which a plain ``0 / 0`` would produce).

        Args:
            mask_cls: Rank-3 tensor indexed ``(batch, query, class)``.
            mask_pred: Rank-4 tensor indexed ``(batch, query, height, width)``.

        Returns:
            Detached tensor indexed ``(batch, class, height, width)`` with
            each non-constant map spanning exactly ``[0, 1]``.
        """
        mask_cls = F.softmax(mask_cls, dim=-1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred).detach()

        # Per-map extrema, kept broadcastable against semseg.
        minval = semseg.amin(dim=(-2, -1), keepdim=True)
        maxval = semseg.amax(dim=(-2, -1), keepdim=True)
        value_range = maxval - minval

        # Only degenerate (zero-range) maps get a substitute denominator, so
        # every other map is normalised exactly as (x - min) / (max - min).
        safe_range = torch.where(
            value_range > 0, value_range, torch.ones_like(value_range)
        )
        return (semseg - minval) / safe_range

    def forward(self, inputs):
        """Fuse both feature pyramids and predict normalised masks.

        Args:
            inputs: Pair ``(featuresA, featuresB)``; each is indexable with
                ``0`` .. ``3`` to obtain the feature map of that level.

        Returns:
            ``[pred_masks, outputs]`` where ``outputs`` is the raw result of
            the head (it must provide ``"pred_logits"`` and ``"pred_masks"``)
            and ``pred_masks`` is the result of ``semantic_inference`` on the
            4x bilinearly upsampled mask predictions.
        """
        featuresA, featuresB = inputs

        fusers = (self.tff1, self.tff2, self.tff3, self.tff4)
        features = {
            f'res{level + 2}': fuse(featuresA[level], featuresB[level])
            for level, fuse in enumerate(fusers)
        }

        outputs = self.sem_seg_head(features)

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]

        # NOTE(review): the factor 4 presumably undoes the stride of the
        # finest level (res2, stride 4) -- confirm against the head's output.
        mask_pred_results = F.interpolate(
            mask_pred_results,
            scale_factor=(4, 4),
            mode="bilinear",
            align_corners=False,
        )
        pred_masks = self.semantic_inference(mask_cls_results, mask_pred_results)

        return [pred_masks, outputs]
|
|
|
|
|
|