|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.nn import init
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
|
from rscd.models.decoderheads.help_func import Transformer, TransformerDecoder, TwoLayerConv2d
|
|
|
|
|
|
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):
        name = m.__class__.__name__
        has_weight = hasattr(m, 'weight')
        if has_weight and ('Conv' in name or 'Linear' in name):
            # weight initialization for conv/linear layers, selected by init_type
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            # biases (when present) are always zeroed
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            # BatchNorm affine params: weight ~ N(1, init_gain), bias = 0
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
|
|
|
|
|
|
class Diff_map(torch.nn.Module):
    """Bitemporal differencing head: |x1 - x2| -> upsample -> small CNN classifier."""

    def __init__(self, input_nc, output_nc,
                 output_sigmoid=False, if_upsample_2x=True):
        """
        Build the upsampling layers and the two-layer conv classifier that
        map a bitemporal feature difference to the output map.
        """
        super(Diff_map, self).__init__()
        # nearest-neighbour x2 and bilinear x4 upsampling stages
        self.upsamplex2 = nn.Upsample(scale_factor=2)
        self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.classifier = TwoLayerConv2d(in_channels=input_nc, out_channels=output_nc)
        self.if_upsample_2x = if_upsample_2x
        self.output_sigmoid = output_sigmoid
        self.sigmoid = nn.Sigmoid()

    def forward(self, x12):
        """Take a pair of feature maps [x1, x2]; return the classified difference map."""
        diff = torch.abs(x12[0] - x12[1])
        # NOTE(review): the extra x2 upsample fires when if_upsample_2x is
        # False — that reads inverted relative to the flag name, but the same
        # pattern appears elsewhere in this file; confirm against callers.
        if not self.if_upsample_2x:
            diff = self.upsamplex2(diff)
        diff = self.upsamplex4(diff)
        out = self.classifier(diff)
        if self.output_sigmoid:
            out = self.sigmoid(out)
        return out
|
|
|
|
|
|
|
|
|
class BASE_Transformer(Diff_map):
    """
    Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN
    """
    def __init__(self, input_nc, output_nc, with_pos,
                 token_len=4, token_trans=True,
                 enc_depth=1, dec_depth=1,
                 dim_head=64, decoder_dim_head=64,
                 tokenizer=True, if_upsample_2x=True,
                 pool_mode='max', pool_size=2,
                 decoder_softmax=True, with_decoder_pos=None,
                 with_decoder=True):
        # input_nc / output_nc / if_upsample_2x are consumed by the Diff_map
        # base (upsampling stages + two-layer conv classifier on the
        # bitemporal difference).
        super(BASE_Transformer, self).__init__(input_nc, output_nc,
                                               if_upsample_2x=if_upsample_2x,
                                               )
        # number of semantic tokens extracted per image
        self.token_len = token_len
        # 1x1 conv producing one spatial-attention map per token.
        # NOTE(review): in-channels hard-coded to 32 — assumes the incoming
        # features have 32 channels regardless of input_nc; confirm callers.
        self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1,
                                padding=0, bias=False)
        self.tokenizer = tokenizer
        if not self.tokenizer:
            # Fallback tokenization: adaptive pooling to a pool_size x
            # pool_size grid, one token per cell (overrides token_len).
            self.pooling_size = pool_size
            self.pool_mode = pool_mode
            self.token_len = self.pooling_size * self.pooling_size

        self.token_trans = token_trans
        self.with_decoder = with_decoder
        # embedding width of the transformer (matches the 32-channel features)
        dim = 32
        mlp_dim = 2*dim

        self.with_pos = with_pos
        # NOTE(review): pos_embedding only exists when with_pos == 'learned',
        # but _forward_transformer adds it for any truthy with_pos — other
        # values would raise AttributeError; confirm intended values.
        if with_pos == 'learned':
            # one learned embedding per token, for the concatenated 2*token_len tokens
            self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, 32))
        decoder_pos_size = 256//4
        self.with_decoder_pos = with_decoder_pos
        if self.with_decoder_pos == 'learned':
            # learned 2-D positional map added to decoder inputs (assumes
            # 64x64 feature maps — TODO confirm input resolution)
            self.pos_embedding_decoder =nn.Parameter(torch.randn(1, 32,
                                                                 decoder_pos_size,
                                                                 decoder_pos_size))
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.dim_head = dim_head
        self.decoder_dim_head = decoder_dim_head
        # token-space encoder (self-attention over the 2*token_len tokens)
        self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=8,
                                       dim_head=self.dim_head,
                                       mlp_dim=mlp_dim, dropout=0)
        # pixel-to-token cross-attention decoder
        self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth,
                                                      heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0,
                                                      softmax=decoder_softmax)

    def _forward_semantic_tokens(self, x):
        """Extract token_len semantic tokens from feature map x of shape (b, c, h, w).

        Each token is a spatial-attention-weighted average of x: conv_a yields
        token_len attention maps, softmax-normalized over the h*w positions,
        then contracted against the flattened features. Returns (b, token_len, c).
        """
        b, c, h, w = x.shape
        spatial_attention = self.conv_a(x)
        spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous()
        # softmax over the spatial dimension, so each token's weights sum to 1
        spatial_attention = torch.softmax(spatial_attention, dim=-1)
        x = x.view([b, c, -1]).contiguous()
        # (b, l, n) x (b, c, n) -> (b, l, c): weighted pooling per token
        tokens = torch.einsum('bln,bcn->blc', spatial_attention, x)

        return tokens

    def _forward_reshape_tokens(self, x):
        """Tokenize x by adaptive pooling to a pooling_size grid.

        Returns (b, pooling_size**2, c); with an unrecognized pool_mode the
        input is passed through unpooled and flattened to (b, h*w, c).
        """
        if self.pool_mode == 'max':
            x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size])
        elif self.pool_mode == 'ave':
            x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size])
        else:
            x = x
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        return tokens

    def _forward_transformer(self, x):
        """Run the token self-attention encoder, adding positional embeddings first.

        NOTE(review): the += mutates x in place, so the caller's tensor (and
        self.tokens_ in forward) is modified as a side effect — confirm this
        is intended before restructuring.
        """
        if self.with_pos:
            x += self.pos_embedding
        x = self.transformer(x)
        return x

    def _forward_transformer_decoder(self, x, m):
        """Refine feature map x (b, c, h, w) by cross-attending to tokens m (b, l, c)."""
        b, c, h, w = x.shape
        # NOTE(review): the 'fix' branch reads pos_embedding_decoder, which
        # __init__ only creates when with_decoder_pos == 'learned' — 'fix'
        # would raise AttributeError here; confirm supported values.
        if self.with_decoder_pos == 'fix':
            x = x + self.pos_embedding_decoder
        elif self.with_decoder_pos == 'learned':
            x = x + self.pos_embedding_decoder
        # flatten pixels to a sequence for the decoder, then restore (h, w)
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.transformer_decoder(x, m)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h)
        return x

    def _forward_simple_decoder(self, x, m):
        """Decoder-free fallback: add the summed tokens m to every pixel of x."""
        b, c, h, w = x.shape
        b, l, c = m.shape
        # broadcast tokens over all spatial positions, then sum over tokens
        m = m.expand([h,w,b,l,c])
        m = rearrange(m, 'h w b l c -> l b c h w')
        m = m.sum(0)
        x = x + m
        return x

    def forward(self, x12):
        """Full BIT pipeline on a bitemporal feature pair x12 = [x1, x2].

        Tokenize each image, optionally refine the concatenated tokens with
        the transformer encoder, project the tokens back onto each feature
        map, then classify the absolute difference (inherited Diff_map tail).
        """
        x1, x2 = x12[0], x12[1]

        if self.tokenizer:
            token1 = self._forward_semantic_tokens(x1)
            token2 = self._forward_semantic_tokens(x2)
        else:
            token1 = self._forward_reshape_tokens(x1)
            token2 = self._forward_reshape_tokens(x2)

        if self.token_trans:
            # joint self-attention over both images' tokens, then split back
            self.tokens_ = torch.cat([token1, token2], dim=1)
            self.tokens = self._forward_transformer(self.tokens_)
            token1, token2 = self.tokens.chunk(2, dim=1)

        if self.with_decoder:
            x1 = self._forward_transformer_decoder(x1, token1)
            x2 = self._forward_transformer_decoder(x2, token2)
        else:
            x1 = self._forward_simple_decoder(x1, token1)
            x2 = self._forward_simple_decoder(x2, token2)

        # bitemporal feature differencing
        x = torch.abs(x1 - x2)

        # NOTE(review): same seemingly-inverted flag as Diff_map.forward —
        # extra x2 upsample only when if_upsample_2x is False; confirm.
        if not self.if_upsample_2x:
            x = self.upsamplex2(x)
        x = self.upsamplex4(x)

        x = self.classifier(x)
        if self.output_sigmoid:
            x = self.sigmoid(x)

        return x
|
|
|
|
|
|
|
|
|
def base_resnet18(cfg):
    """Build and weight-initialize a plain differencing head (no transformer) from cfg."""
    model = Diff_map(
        input_nc=cfg.input_nc,
        output_nc=cfg.output_nc,
        output_sigmoid=cfg.output_sigmoid,
    )
    init_weights(model, cfg.init_type, init_gain=cfg.init_gain)
    return model
|
|
|
|
|
|
def base_transformer_pos_s4(cfg):
    """Build and weight-initialize a BIT head with default encoder/decoder depths."""
    model = BASE_Transformer(
        input_nc=cfg.input_nc,
        output_nc=cfg.output_nc,
        token_len=cfg.token_len,
        with_pos=cfg.with_pos,
    )
    init_weights(model, cfg.init_type, init_gain=cfg.init_gain)
    return model
|
|
|
|
|
|
def base_transformer_pos_s4_dd8(cfg):
    """Build and weight-initialize a BIT head with configurable encoder/decoder depths."""
    model = BASE_Transformer(
        input_nc=cfg.input_nc,
        output_nc=cfg.output_nc,
        token_len=cfg.token_len,
        with_pos=cfg.with_pos,
        enc_depth=cfg.enc_depth,
        dec_depth=cfg.dec_depth,
    )
    init_weights(model, cfg.init_type, init_gain=cfg.init_gain)
    return model
|
|
|
|
|
|
def base_transformer_pos_s4_dd8_dedim8(cfg):
    """Build and weight-initialize a BIT head with configurable depths and head dims."""
    model = BASE_Transformer(
        input_nc=cfg.input_nc,
        output_nc=cfg.output_nc,
        token_len=cfg.token_len,
        with_pos=cfg.with_pos,
        enc_depth=cfg.enc_depth,
        dec_depth=cfg.dec_depth,
        dim_head=cfg.dim_head,
        decoder_dim_head=cfg.decoder_dim_head,
    )
    init_weights(model, cfg.init_type, init_gain=cfg.init_gain)
    return model
|
|
|
|
|
|
if __name__ == "__main__":
    # smoke test: run the full head on a random bitemporal feature pair
    from munch import DefaultMunch

    feat_a = torch.randn(4, 32, 128, 128)
    feat_b = torch.randn(4, 32, 128, 128)

    cfg = DefaultMunch.fromDict({
        'type': 'base_transformer_pos_s4_dd8_dedim8',
        'input_nc': 32,
        'output_nc': 2,
        'token_len': 4,
        'with_pos': 'learned',
        'enc_depth': 1,
        'dec_depth': 8,
        'dim_head': 8,
        'decoder_dim_head': 8,
        'init_type': 'normal',
        'init_gain': 0.02,
    })

    model = base_transformer_pos_s4_dd8_dedim8(cfg)
    outs = model([feat_a, feat_b])
    print('BIT_head', outs)
    print(outs.shape)