|
|
"""
|
|
|
modules.py - This file stores the rather boring network blocks.
|
|
|
|
|
|
x - usually means features that depend only on the image
|
|
|
g - usually means features that also depend on the mask.
|
|
|
They might have an extra "group" or "num_objects" dimension, giving the layout
|
|
|
batch_size * num_objects * num_channels * H * W
|
|
|
|
|
|
The trailing number of a variable usually denotes its stride (e.g. f16, g8).
|
|
|
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from XMem2.model.group_modules import *
|
|
|
from XMem2.model import resnet
|
|
|
from XMem2.model.cbam import CBAM
|
|
|
|
|
|
|
|
|
class FeatureFusionBlock(nn.Module):
    """Fuse per-image features ``x`` into per-object features ``g``.

    ``x`` is handed to every object by ``MainToGroupDistributor``. Because
    ``block1`` takes ``x_in_dim + g_in_dim`` channels, the distributor
    presumably concatenates along channels. The result then goes through a
    grouped residual block. A CBAM attention branch is added back as a
    residual, and a second grouped residual block maps to ``g_out_dim``
    channels.
    """

    def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim):
        super().__init__()

        # Attribute names are state_dict keys and the construction order fixes
        # the RNG draws used for default initialisation -- keep both as-is.
        self.distributor = MainToGroupDistributor()
        self.block1 = GroupResBlock(x_in_dim + g_in_dim, g_mid_dim)
        self.attention = CBAM(g_mid_dim)
        self.block2 = GroupResBlock(g_mid_dim, g_out_dim)

    def forward(self, x, g):
        """Return fused group features of shape (batch, objects, g_out_dim, H, W).

        Args:
            x: image-level features shared by all objects.
            g: per-object features; the first two dims are (batch, objects).
        """
        n_batch, n_obj = g.shape[:2]

        fused = self.block1(self.distributor(x, g))

        # CBAM operates on ordinary 4-D batches: fold the object axis into the
        # batch axis, run attention, then unfold it again.
        folded = fused.flatten(start_dim=0, end_dim=1)
        attended = self.attention(folded)
        attended = attended.view(n_batch, n_obj, *attended.shape[1:])

        return self.block2(fused + attended)
|
|
|
|
|
|
|
|
|
class HiddenUpdater(nn.Module):
    """GRU-like update of the hidden state from three decoder feature maps.

    ``g`` holds three grouped feature maps. By naming they are at strides 16,
    8 and 4. The second and third are downsampled by 1/2 and 1/4, and each is
    projected to ``mid_dim`` channels with a 1x1 conv; the three projections
    are summed. The sum is concatenated with the hidden state, and a single
    3x3 conv produces the forget gate, update gate and candidate value.
    """

    def __init__(self, g_dims, mid_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # One 1x1 projection per input scale (construction order preserved).
        self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1)
        self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1)
        self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1)

        # Outputs 3 * hidden_dim channels: [forget | update | candidate].
        self.transform = GConv2D(
            mid_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1
        )
        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g, h):
        """Return the new hidden state (same shape as ``h``).

        Args:
            g: indexable of three grouped feature maps (stride 16, 8, 4).
            h: current hidden state; concatenated on dim 2 (the channel axis).
        """
        # Accumulate left-to-right, matching ((a + b) + c).
        merged = self.g16_conv(g[0])
        merged = merged + self.g8_conv(downsample_groups(g[1], ratio=1 / 2))
        merged = merged + self.g4_conv(downsample_groups(g[2], ratio=1 / 4))

        values = self.transform(torch.cat([merged, h], 2))

        d = self.hidden_dim
        forget_gate = torch.sigmoid(values[:, :, :d])
        update_gate = torch.sigmoid(values[:, :, d : 2 * d])
        new_value = torch.tanh(values[:, :, 2 * d :])

        return forget_gate * h * (1 - update_gate) + update_gate * new_value
|
|
|
|
|
|
|
|
|
class HiddenReinforcer(nn.Module):
    """GRU-like refresh of the hidden state from a single grouped feature map.

    ``g`` and ``h`` are concatenated on dim 2 (the channel axis). A 3x3
    grouped conv then yields the forget gate, update gate and candidate
    value, each ``hidden_dim`` channels wide.
    """

    def __init__(self, g_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Outputs 3 * hidden_dim channels: [forget | update | candidate].
        self.transform = GConv2D(
            g_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1
        )
        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g, h):
        """Return the reinforced hidden state (same shape as ``h``)."""
        values = self.transform(torch.cat([g, h], 2))

        d = self.hidden_dim
        forget_gate = torch.sigmoid(values[:, :, :d])
        update_gate = torch.sigmoid(values[:, :, d : 2 * d])
        new_value = torch.tanh(values[:, :, 2 * d :])

        return forget_gate * h * (1 - update_gate) + update_gate * new_value
|
|
|
|
|
|
|
|
|
class ValueEncoder(nn.Module):
    """Encode object masks together with the image into per-object value features.

    The backbone is the stem plus layers 1-3 of a ResNet-18 built with 1 extra
    input channel (single-object mode) or 2 (mask + "others"). Its output is
    fused with stride-16 image features by a ``FeatureFusionBlock``. The
    hidden state can optionally be refreshed by a ``HiddenReinforcer``.
    """

    def __init__(self, value_dim, hidden_dim, single_object=False, pretrained=True):
        super().__init__()

        self.single_object = single_object
        extra_channels = 1 if single_object else 2
        network = resnet.resnet18(pretrained=pretrained, extra_dim=extra_channels)

        # Borrow the trunk's modules. Attribute names are state_dict keys and
        # registration order is preserved, so keep both unchanged.
        self.conv1 = network.conv1
        self.bn1 = network.bn1
        self.relu = network.relu
        self.maxpool = network.maxpool

        self.layer1 = network.layer1
        self.layer2 = network.layer2
        self.layer3 = network.layer3

        self.distributor = MainToGroupDistributor()
        self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim)
        self.hidden_reinforce = (
            HiddenReinforcer(value_dim, hidden_dim) if hidden_dim > 0 else None
        )

    def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True):
        """Return ``(value_features, hidden_state)``.

        Args:
            image: input image, distributed to every object.
            image_feat_f16: stride-16 image features used by the fuser.
            h: hidden state; returned unchanged unless it is reinforced.
            masks: per-object masks. Stacked with ``others`` on dim 2, so
                presumably (batch, objects, H, W) -- TODO confirm.
            others: companion masks; ignored in single-object mode.
            is_deep_update: when False, skip the hidden-state refresh.
        """
        if self.single_object:
            g_1 = masks.unsqueeze(2)
        else:
            g_1 = torch.stack([masks, others], 2)
        g_2 = self.distributor(image, g_1)

        batch_size, num_objects = g_2.shape[:2]
        g = g_2.flatten(start_dim=0, end_dim=1)

        # Max-pooling runs before ReLU here. That is equivalent to the usual
        # order because ReLU is monotonic; the original order is kept anyway.
        for stage in (
            self.conv1,
            self.bn1,
            self.maxpool,
            self.relu,
            self.layer1,
            self.layer2,
            self.layer3,
        ):
            g = stage(g)

        g = g.view(batch_size, num_objects, *g.shape[1:])
        g = self.fuser(image_feat_f16, g)

        if not is_deep_update or self.hidden_reinforce is None:
            return g, h
        return g, self.hidden_reinforce(g, h)
|
|
|
|
|
|
|
|
|
class KeyEncoder(nn.Module):
    """ResNet-50 stem plus layers 1-3, returning three feature maps.

    The outputs are named by stride: ``f16``, ``f8``, ``f4``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        network = resnet.resnet50(pretrained=pretrained)

        # Attribute names are state_dict keys -- note layer1 is stored as
        # ``res2`` (presumably for checkpoint compatibility); do not rename.
        self.conv1 = network.conv1
        self.bn1 = network.bn1
        self.relu = network.relu
        self.maxpool = network.maxpool

        self.res2 = network.layer1
        self.layer2 = network.layer2
        self.layer3 = network.layer3

    def forward(self, f):
        """Return ``(f16, f8, f4)`` computed from the image batch ``f``."""
        stem = self.maxpool(self.relu(self.bn1(self.conv1(f))))
        f4 = self.res2(stem)
        f8 = self.layer2(f4)
        f16 = self.layer3(f8)
        return f16, f8, f4
|
|
|
|
|
|
|
|
|
class UpsampleBlock(nn.Module):
    """Upsample grouped features and merge in a skip connection.

    The skip features are projected to ``g_up_dim`` channels by a 3x3 conv.
    The grouped features are upsampled by ``scale_factor``. The two are
    combined by ``MainToGroupDistributor(method='add')`` and refined by a
    grouped residual block with ``g_out_dim`` outputs.
    """

    def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2):
        super().__init__()
        # Construction order preserved (state_dict order / RNG init draws).
        self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1)
        self.distributor = MainToGroupDistributor(method='add')
        self.out_conv = GroupResBlock(g_up_dim, g_out_dim)
        self.scale_factor = scale_factor

    def forward(self, skip_f, up_g):
        """Return refined grouped features at the skip connection's resolution."""
        projected_skip = self.skip_conv(skip_f)
        upsampled = upsample_groups(up_g, ratio=self.scale_factor)
        merged = self.distributor(projected_skip, upsampled)
        return self.out_conv(merged)
|
|
|
|
|
|
|
|
|
class KeyProjection(nn.Module):
    """Project features into a key, plus optional shrinkage and selection maps.

    * key: 3x3 conv to ``keydim`` channels, orthogonally initialised, zero bias.
    * shrinkage: ``d_proj(x) ** 2 + 1`` (1 channel, always >= 1).
    * selection: ``sigmoid(e_proj(x))`` (``keydim`` channels, in [0, 1]).
    """

    def __init__(self, in_dim, keydim):
        super().__init__()

        # Creation order fixes the RNG draws for default init -- keep it.
        self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
        self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1)  # shrinkage
        self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)  # selection

        nn.init.orthogonal_(self.key_proj.weight.data)
        nn.init.zeros_(self.key_proj.bias.data)

    def forward(self, x, need_s, need_e):
        """Return ``(key, shrinkage, selection)``.

        ``shrinkage`` is None unless ``need_s`` is truthy, and ``selection`` is
        None unless ``need_e`` is truthy.
        """
        shrinkage = None
        if need_s:
            shrinkage = self.d_proj(x) ** 2 + 1

        selection = None
        if need_e:
            selection = torch.sigmoid(self.e_proj(x))

        key = self.key_proj(x)
        return key, shrinkage, selection
|
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
    """Decode memory readout (and optionally the hidden state) into mask logits.

    Stride-16 fusion is followed by two upsampling stages (16 -> 8 -> 4) and a
    1-channel prediction conv. If a hidden updater exists, it can also produce
    a new hidden state from the three decoder levels plus the logits.
    """

    def __init__(self, val_dim, hidden_dim):
        super().__init__()

        # Construction order preserved (state_dict order / RNG init draws).
        self.fuser = FeatureFusionBlock(1024, val_dim + hidden_dim, 512, 512)
        # The "+ 1" on the stride-4 input is the logit channel appended in forward().
        self.hidden_update = (
            HiddenUpdater([512, 256, 256 + 1], 256, hidden_dim)
            if hidden_dim > 0
            else None
        )

        self.up_16_8 = UpsampleBlock(512, 512, 256)
        self.up_8_4 = UpsampleBlock(256, 256, 256)

        self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1)

    def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True):
        """Return ``(hidden_state, logits)``.

        ``logits`` has shape (batch, objects, H', W'), where H', W' are 4x the
        stride-4 decoder resolution. ``hidden_state`` is None when ``h_out``
        is False or when the decoder was built with ``hidden_dim == 0``.
        """
        batch_size, num_objects = memory_readout.shape[:2]
        has_hidden = self.hidden_update is not None

        if has_hidden:
            fuser_input = torch.cat([memory_readout, hidden_state], 2)
        else:
            fuser_input = memory_readout

        g16 = self.fuser(f16, fuser_input)
        g8 = self.up_16_8(f8, g16)
        g4 = self.up_8_4(f4, g8)
        logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1)))

        new_hidden = None
        if h_out and has_hidden:
            # Re-group the logits and append them as one extra channel of g4.
            logit_channel = logits.view(batch_size, num_objects, 1, *logits.shape[-2:])
            g4_with_logits = torch.cat([g4, logit_channel], 2)
            new_hidden = self.hidden_update([g16, g8, g4_with_logits], hidden_state)

        upscaled = F.interpolate(
            logits, scale_factor=4, mode='bilinear', align_corners=False
        )
        return new_hidden, upscaled.view(batch_size, num_objects, *upscaled.shape[-2:])
|
|
|
|