"""
Group-specific modules
They handle features that also depend on the mask.
Features are typically of shape
    batch_size * num_objects * num_channels * H * W

All of them are permutation equivariant w.r.t. the num_objects dimension
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def interpolate_groups(g, ratio, mode, align_corners):
    """Spatially resize a 4-D or 5-D feature tensor with ``F.interpolate``.

    Args:
        g: Tensor with 4 dims (passed to ``F.interpolate`` as-is) or 5 dims
            (the first two dims are merged before interpolation and split
            again afterwards).
        ratio: Value forwarded as ``scale_factor``.
        mode: Interpolation mode forwarded to ``F.interpolate``.
        align_corners: Forwarded to ``F.interpolate``.

    Returns:
        The resized tensor. A tensor whose rank is neither 4 nor 5 is
        returned unchanged.
    """
    rank = g.dim()
    if rank == 4:
        return F.interpolate(
            g, scale_factor=ratio, mode=mode, align_corners=align_corners
        )
    if rank == 5:
        batch_size, num_objects = g.shape[:2]
        # Fold the object dimension into the batch so that the 2-D
        # interpolation applies to every (batch, object) pair independently.
        resized = F.interpolate(
            g.flatten(start_dim=0, end_dim=1),
            scale_factor=ratio,
            mode=mode,
            align_corners=align_corners,
        )
        return resized.view(batch_size, num_objects, *resized.shape[1:])
    # Other ranks are deliberately passed through untouched (kept for
    # backward compatibility with existing callers).
    return g


def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False):
    """Resize ``g`` via ``interpolate_groups``; defaults to 2x bilinear."""
    return interpolate_groups(g, ratio, mode, align_corners)


def downsample_groups(g, ratio=1/2, mode='area', align_corners=None):
    """Resize ``g`` via ``interpolate_groups``; defaults to 0.5x 'area'."""
    return interpolate_groups(g, ratio, mode, align_corners)


class GConv2D(nn.Conv2d):
    """``nn.Conv2d`` applied to a tensor with two leading dimensions.

    The first two dimensions of the input are merged for the convolution and
    restored on the output, so each (batch, object) slice is convolved with
    the same weights.
    """

    def forward(self, g):
        batch_size, num_objects = g.shape[:2]
        out = super().forward(g.flatten(start_dim=0, end_dim=1))
        return out.view(batch_size, num_objects, *out.shape[1:])


class GroupResBlock(nn.Module):
    """Pre-activation residual block built from two 3x3 ``GConv2D`` layers.

    When ``in_dim != out_dim`` the skip connection goes through an extra 3x3
    ``GConv2D`` (stored as ``self.downsample``) so both branches have
    ``out_dim`` channels; otherwise the input is added back unchanged.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        # NOTE: creation order and attribute names are kept stable so that
        # parameter initialisation and state_dict keys do not change.
        if in_dim == out_dim:
            self.downsample = None
        else:
            self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)

        self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, g):
        residual = self.conv1(F.relu(g))
        residual = self.conv2(F.relu(residual))

        skip = g if self.downsample is None else self.downsample(g)
        return residual + skip


class MainToGroupDistributor(nn.Module):
    """Broadcast a per-image feature ``x`` to every object of a group ``g``.

    Args:
        x_transform: Optional callable applied to ``x`` before distribution.
        method: ``'cat'`` concatenates along dim 2, ``'add'`` sums
            element-wise.
        reverse_order: Only used by ``'cat'``. If ``True`` the result is
            ``[g, x]`` along dim 2, otherwise ``[x, g]``.
    """

    def __init__(self, x_transform=None, method='cat', reverse_order=False):
        super().__init__()

        self.x_transform = x_transform
        self.method = method
        self.reverse_order = reverse_order

    def forward(self, x, g):
        """Combine ``x`` with ``g`` after expanding ``x`` over the objects.

        Args:
            x: Tensor that gains an object dimension at position 1
                (e.g. ``[8, 3, 224, 224]``).
            g: Tensor whose dim 1 is the number of objects
                (e.g. ``[8, 2, 2, 224, 224]`` for 'cat' or
                ``[8, 2, 512, 16, 16]`` for 'add').

        Returns:
            For 'cat', the concatenation along dim 2; for 'add', the
            element-wise sum of the expanded ``x`` and ``g``.

        Raises:
            NotImplementedError: If ``self.method`` is not 'cat' or 'add'.
        """
        num_objects = g.shape[1]

        if self.x_transform is not None:
            x = self.x_transform(x)

        if self.method not in ('cat', 'add'):
            raise NotImplementedError(
                f"Unsupported distribution method: {self.method!r} "
                "(expected 'cat' or 'add')"
            )

        # expand() creates a view: x is shared across objects, not copied.
        # e.g. x [8, 3, 224, 224] -> [8, num_objects, 3, 224, 224]
        x_per_object = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)

        if self.method == 'add':
            return x_per_object + g

        if self.reverse_order:
            return torch.cat([g, x_per_object], 2)
        return torch.cat([x_per_object, g], 2)