import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from einops import rearrange

from rscd.models.decoderheads.help_func import Transformer, TransformerDecoder, TwoLayerConv2d


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method:
                             normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and
    kaiming might work better for some applications. Feel free to try yourself.

    Raises:
        NotImplementedError: if ``init_type`` is not one of the names above and
            the network contains at least one Conv*/Linear* layer with a weight.
    """
    def init_func(m):  # define the initialization function
        # Layers are matched by class *name* so that any class whose name
        # contains 'Conv' / 'Linear' and that exposes a `weight` is covered.
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # With affine=False both weight and bias are None, so guard each access.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function


class Diff_map(torch.nn.Module):
    """Change-detection head: absolute feature difference + upsampling + classifier.

    ``forward`` takes a pair of feature maps, computes ``|f1 - f2|``, upsamples
    the result and feeds it to a ``TwoLayerConv2d`` classifier.
    """

    def __init__(self, input_nc, output_nc, output_sigmoid=False, if_upsample_2x=True):
        """Build the upsampling layers and the classifier.

        Parameters:
            input_nc (int)        -- channel count passed to the classifier as ``in_channels``
            output_nc (int)       -- channel count passed to the classifier as ``out_channels``
            output_sigmoid (bool) -- apply a sigmoid to the classifier output
            if_upsample_2x (bool) -- when False, an extra 2x upsampling is applied
                                     before the 4x upsampling in ``forward``
        """
        super().__init__()
        self.upsamplex2 = nn.Upsample(scale_factor=2)
        self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear')

        self.classifier = TwoLayerConv2d(in_channels=input_nc, out_channels=output_nc)

        self.if_upsample_2x = if_upsample_2x
        self.output_sigmoid = output_sigmoid
        self.sigmoid = nn.Sigmoid()

    def _upsample_and_classify(self, x):
        """Shared tail of both heads: upsample, classify, optional sigmoid."""
        # NOTE(review): the extra 2x step runs when `if_upsample_2x` is False.
        # This reads inverted relative to the flag name; presumably the flag
        # means "the features were already upsampled 2x upstream" -- confirm.
        if not self.if_upsample_2x:
            x = self.upsamplex2(x)
        x = self.upsamplex4(x)
        x = self.classifier(x)
        if self.output_sigmoid:
            x = self.sigmoid(x)
        return x

    def forward(self, x12):
        """Return the classifier output for the pair ``x12 = (x1, x2)``."""
        x = torch.abs(x12[0] - x12[1])
        return self._upsample_and_classify(x)


class BASE_Transformer(Diff_map):
    """Bitemporal transformer head: tokenizer + transformer encoder/decoder
    + bitemporal feature differencing + a small CNN classifier.

    The feature/token dimension is fixed to 32 channels inside this class
    (tokenizer conv, positional embeddings and both transformers use it).
    """

    def __init__(self, input_nc, output_nc, with_pos,
                 token_len=4, token_trans=True,
                 enc_depth=1, dec_depth=1,
                 dim_head=64, decoder_dim_head=64,
                 tokenizer=True, if_upsample_2x=True,
                 pool_mode='max', pool_size=2,
                 decoder_softmax=True, with_decoder_pos=None,
                 with_decoder=True):
        """Build the tokenizer, positional embeddings and transformers.

        Parameters:
            input_nc, output_nc   -- forwarded to ``Diff_map`` (classifier channels)
            with_pos              -- 'learned' adds a learned embedding to the encoder
                                     tokens; a falsy value disables it
            token_len (int)       -- number of semantic tokens per image (tokenizer mode)
            token_trans (bool)    -- run the transformer encoder on the tokens
            enc_depth, dec_depth  -- depth of the encoder / decoder
            dim_head, decoder_dim_head -- per-head dimension of encoder / decoder
            tokenizer (bool)      -- True: attention-based tokenizer; False: tokens are
                                     obtained by pooling to ``pool_size x pool_size``
            if_upsample_2x (bool) -- forwarded to ``Diff_map``
            pool_mode (str)       -- 'max' | 'ave'; any other value skips pooling
            pool_size (int)       -- pooled grid side length (pooling mode only)
            decoder_softmax       -- forwarded to ``TransformerDecoder`` as ``softmax``
            with_decoder_pos      -- 'learned' adds a learned 2-D embedding to the
                                     decoder input; None disables it
            with_decoder (bool)   -- True: transformer decoder; False: simple additive
                                     decoder
        """
        super().__init__(input_nc, output_nc, if_upsample_2x=if_upsample_2x)
        dim = 32
        mlp_dim = 2 * dim

        self.token_len = token_len
        # 1x1 conv producing one spatial attention map per semantic token.
        self.conv_a = nn.Conv2d(dim, self.token_len, kernel_size=1,
                                padding=0, bias=False)
        self.tokenizer = tokenizer
        if not self.tokenizer:
            # if not use tokenizer, then downsample the feature map into a certain size
            self.pooling_size = pool_size
            self.pool_mode = pool_mode
            self.token_len = self.pooling_size * self.pooling_size

        self.token_trans = token_trans
        self.with_decoder = with_decoder

        self.with_pos = with_pos
        if with_pos == 'learned':
            # Covers the concatenated tokens of both images, hence token_len * 2.
            self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len * 2, dim))

        decoder_pos_size = 256 // 4
        self.with_decoder_pos = with_decoder_pos
        if self.with_decoder_pos == 'learned':
            self.pos_embedding_decoder = nn.Parameter(
                torch.randn(1, dim, decoder_pos_size, decoder_pos_size))

        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.dim_head = dim_head
        self.decoder_dim_head = decoder_dim_head
        self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=8,
                                       dim_head=self.dim_head,
                                       mlp_dim=mlp_dim, dropout=0)
        self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth,
                                                      heads=8, dim_head=self.decoder_dim_head,
                                                      mlp_dim=mlp_dim, dropout=0,
                                                      softmax=decoder_softmax)

    def _forward_semantic_tokens(self, x):
        """Pool a (b, c, h, w) feature map into (b, token_len, c) tokens.

        Each token is the spatial-softmax-weighted sum of the features.
        """
        b, c, h, w = x.shape
        spatial_attention = self.conv_a(x)
        # reshape (not view) so non-contiguous / channels_last tensors also work.
        spatial_attention = spatial_attention.reshape(b, self.token_len, -1)
        spatial_attention = torch.softmax(spatial_attention, dim=-1)
        x = x.reshape(b, c, -1)
        tokens = torch.einsum('bln,bcn->blc', spatial_attention, x)
        return tokens

    def _forward_reshape_tokens(self, x):
        """Pool the feature map to a fixed grid and flatten it into tokens.

        'max' / 'ave' select adaptive max / average pooling; any other
        ``pool_mode`` leaves the feature map unpooled.
        """
        if self.pool_mode == 'max':
            x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size])
        elif self.pool_mode == 'ave':
            x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size])
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        return tokens

    def _forward_transformer(self, x):
        """Run the transformer encoder, adding the positional embedding if enabled.

        Raises:
            NotImplementedError: if ``with_pos`` is truthy but not 'learned'
                (no embedding exists for any other mode).
        """
        if self.with_pos:
            if self.with_pos != 'learned':
                raise NotImplementedError(
                    'positional embedding [%s] is not implemented' % self.with_pos)
            # Out-of-place add: do not mutate the caller's tensor (forward()
            # keeps the raw tokens in self.tokens_).
            x = x + self.pos_embedding
        x = self.transformer(x)
        return x

    def _forward_transformer_decoder(self, x, m):
        """Refine the (b, c, h, w) feature map ``x`` with the tokens ``m``.

        Raises:
            NotImplementedError: if ``with_decoder_pos`` is 'fix' (no fixed
                embedding is created by this class).
        """
        b, c, h, w = x.shape
        if self.with_decoder_pos == 'learned':
            pos = self.pos_embedding_decoder
            if pos.shape[-2:] != (h, w):
                # The parameter is created on a 64x64 grid; resize it so other
                # feature-map sizes do not fail on the addition.
                pos = F.interpolate(pos, size=(h, w), mode='bilinear',
                                    align_corners=False)
            x = x + pos
        elif self.with_decoder_pos == 'fix':
            raise NotImplementedError(
                'decoder positional embedding [%s] is not implemented'
                % self.with_decoder_pos)
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.transformer_decoder(x, m)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h)
        return x

    def _forward_simple_decoder(self, x, m):
        """Add the sum of all tokens (over the token axis) to every pixel of ``x``."""
        # m: (b, l, c) -> (b, c, 1, 1), broadcast over the spatial dims of x.
        return x + m.sum(dim=1)[:, :, None, None]

    def forward(self, x12):
        """Return the change prediction for the feature pair ``x12 = (x1, x2)``.

        Side effects: when ``token_trans`` is set, stores the concatenated raw
        tokens in ``self.tokens_`` and the encoded tokens in ``self.tokens``.
        """
        x1, x2 = x12[0], x12[1]
        # forward tokenizer
        if self.tokenizer:
            token1 = self._forward_semantic_tokens(x1)
            token2 = self._forward_semantic_tokens(x2)
        else:
            token1 = self._forward_reshape_tokens(x1)
            token2 = self._forward_reshape_tokens(x2)
        # forward transformer encoder (both images share one token sequence)
        if self.token_trans:
            self.tokens_ = torch.cat([token1, token2], dim=1)
            self.tokens = self._forward_transformer(self.tokens_)
            token1, token2 = self.tokens.chunk(2, dim=1)
        # forward transformer decoder
        if self.with_decoder:
            x1 = self._forward_transformer_decoder(x1, token1)
            x2 = self._forward_transformer_decoder(x2, token2)
        else:
            x1 = self._forward_simple_decoder(x1, token1)
            x2 = self._forward_simple_decoder(x2, token2)
        # feature differencing, then upsampling + small cnn
        x = torch.abs(x1 - x2)
        return self._upsample_and_classify(x)


def base_resnet18(cfg):
    """Build and initialize a ``Diff_map`` head.

    Reads ``input_nc``, ``output_nc``, ``output_sigmoid``, ``init_type`` and
    ``init_gain`` from ``cfg`` by attribute access.
    """
    net = Diff_map(input_nc=cfg.input_nc, output_nc=cfg.output_nc,
                   output_sigmoid=cfg.output_sigmoid)
    init_weights(net, cfg.init_type, init_gain=cfg.init_gain)
    return net


def base_transformer_pos_s4(cfg):
    """Build and initialize a ``BASE_Transformer`` with default depths and head dims.

    Reads ``input_nc``, ``output_nc``, ``token_len``, ``with_pos``,
    ``init_type`` and ``init_gain`` from ``cfg``.
    """
    net = BASE_Transformer(input_nc=cfg.input_nc, output_nc=cfg.output_nc,
                           token_len=cfg.token_len, with_pos=cfg.with_pos)
    init_weights(net, cfg.init_type, init_gain=cfg.init_gain)
    return net


def base_transformer_pos_s4_dd8(cfg):
    """Build and initialize a ``BASE_Transformer`` with configurable depths.

    In addition to the fields read by ``base_transformer_pos_s4``, reads
    ``enc_depth`` and ``dec_depth`` from ``cfg``.
    """
    net = BASE_Transformer(input_nc=cfg.input_nc, output_nc=cfg.output_nc,
                           token_len=cfg.token_len, with_pos=cfg.with_pos,
                           enc_depth=cfg.enc_depth, dec_depth=cfg.dec_depth)
    init_weights(net, cfg.init_type, init_gain=cfg.init_gain)
    return net


def base_transformer_pos_s4_dd8_dedim8(cfg):
    """Build and initialize a ``BASE_Transformer`` with configurable depths and head dims.

    In addition to the fields read by ``base_transformer_pos_s4_dd8``, reads
    ``dim_head`` and ``decoder_dim_head`` from ``cfg``.
    """
    net = BASE_Transformer(input_nc=cfg.input_nc, output_nc=cfg.output_nc,
                           token_len=cfg.token_len, with_pos=cfg.with_pos,
                           enc_depth=cfg.enc_depth, dec_depth=cfg.dec_depth,
                           dim_head=cfg.dim_head, decoder_dim_head=cfg.decoder_dim_head)
    init_weights(net, cfg.init_type, init_gain=cfg.init_gain)
    return net


if __name__ == "__main__":
    # Smoke test: run one forward pass on random bitemporal features.
    x1 = torch.randn(4, 32, 128, 128)
    x2 = torch.randn(4, 32, 128, 128)
    cfg = dict(
        type='base_transformer_pos_s4_dd8_dedim8',
        input_nc=32,
        output_nc=2,
        token_len=4,
        with_pos='learned',
        enc_depth=1,
        dec_depth=8,
        dim_head=8,
        decoder_dim_head=8,

        init_type='normal',
        init_gain=0.02,
    )
    from munch import DefaultMunch
    cfg = DefaultMunch.fromDict(cfg)
    model = base_transformer_pos_s4_dd8_dedim8(cfg)
    outs = model([x1, x2])
    print('BIT_head', outs)
    print(outs.shape)