| | """ |
| | big_modules.py - This file stores higher-level network blocks. |
| | |
| | x - usually denotes features that are shared between objects. |
| | g - usually denotes features that are not shared between objects |
| | with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W). |
| | |
| | The trailing number of a variable usually denotes the stride |
| | """ |
| |
|
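# Shape cheat-sheet for the conventions above:
#   x: (batch_size, num_channels, H, W)               - shared across objects
#   g: (batch_size, num_objects, num_channels, H, W)  - one slice per object
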
from typing import Iterable, Optional, Tuple

from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F

from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
from matanyone.model.utils import resnet
from matanyone.model.modules import (SensoryDeepUpdater, SensoryUpdater_fullscale,
                                     DecoderFeatureProcessor, MaskUpsampleBlock)

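# Predicts a single-channel uncertainty map from the previous-frame features,
# the current-frame features, the downsampled previous mask, and the
# memory-value difference between the two frames (mem_val_diff).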
class UncertPred(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim * 2 + 1 + model_cfg.value_dim,
                                    64, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor,
                last_mask: torch.Tensor, mem_val_diff: torch.Tensor) -> torch.Tensor:
        # bring the mask down to the feature resolution before concatenation
        last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area')
        x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1)
        x = self.conv1x1_v2(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv3x3(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3x3_out(x)
        return x

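    # This override keeps the module permanently in eval mode: calling
    # .train() on a parent network will not unfreeze the BatchNorm statistics.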
    def train(self, mode=True):
        self.training = False
        for module in self.children():
            module.train(False)
        return self

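# Extracts multi-scale image features (strides 16/8/4/2/1) from the first
# stages of a ResNet backbone; kept in eval mode (see train() below).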
class PixelEncoder(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
        is_pretrained_resnet = getattr(model_cfg, "pretrained_resnet", True)
        if self.is_resnet:
            if model_cfg.pixel_encoder.type == 'resnet18':
                network = resnet.resnet18(pretrained=is_pretrained_resnet)
            elif model_cfg.pixel_encoder.type == 'resnet50':
                network = resnet.resnet50(pretrained=is_pretrained_resnet)
            else:
                raise NotImplementedError
            self.conv1 = network.conv1
            self.bn1 = network.bn1
            self.relu = network.relu
            self.maxpool = network.maxpool

            # layer1 is exposed as `res2`; it produces the stride-4 features
            self.res2 = network.layer1
            self.layer2 = network.layer2
            self.layer3 = network.layer3
        else:
            raise NotImplementedError

    def forward(self, x: torch.Tensor, seq_length=None
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        f1 = x                 # stride 1: the input image itself
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        f2 = x                 # stride 2
        x = self.maxpool(x)
        f4 = self.res2(x)      # stride 4
        f8 = self.layer2(f4)   # stride 8
        f16 = self.layer3(f8)  # stride 16

        return f16, f8, f4, f2, f1

    # see UncertPred.train: this encoder stays frozen in eval mode
    def train(self, mode=True):
        self.training = False
        for module in self.children():
            module.train(False)
        return self

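# Projects stride-16 image features into the key space used by memory
# attention; optionally also produces the shrinkage and selection terms.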
class KeyProjection(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        in_dim = model_cfg.pixel_encoder.ms_dims[0]
        mid_dim = model_cfg.pixel_dim
        key_dim = model_cfg.key_dim

        self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
        self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
        # shrinkage
        self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
        # selection
        self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)

        nn.init.orthogonal_(self.key_proj.weight.data)
        nn.init.zeros_(self.key_proj.bias.data)

    def forward(self, x: torch.Tensor, *, need_s: bool,
                need_e: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        x = self.pix_feat_proj(x)
        shrinkage = self.d_proj(x)**2 + 1 if need_s else None  # always >= 1
        selection = torch.sigmoid(self.e_proj(x)) if need_e else None  # in (0, 1)

        return self.key_proj(x), shrinkage, selection

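# Encodes the masks (together with the image and pixel features) into memory
# values, and optionally deep-updates the sensory memory. Objects can be
# processed in chunks to bound peak memory usage.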
class MaskEncoder(nn.Module):
    def __init__(self, model_cfg: DictConfig, single_object=False):
        super().__init__()
        pixel_dim = model_cfg.pixel_dim
        value_dim = model_cfg.value_dim
        sensory_dim = model_cfg.sensory_dim
        final_dim = model_cfg.mask_encoder.final_dim

        self.single_object = single_object
        # the target mask alone, or the target mask plus a channel
        # aggregating all other objects
        extra_dim = 1 if single_object else 2

        is_pretrained_resnet = getattr(model_cfg, "pretrained_resnet", True)
        if model_cfg.mask_encoder.type == 'resnet18':
            network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
        elif model_cfg.mask_encoder.type == 'resnet50':
            network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
        else:
            raise NotImplementedError
        self.conv1 = network.conv1
        self.bn1 = network.bn1
        self.relu = network.relu
        self.maxpool = network.maxpool

        self.layer1 = network.layer1
        self.layer2 = network.layer2
        self.layer3 = network.layer3

        self.distributor = MainToGroupDistributor()
        self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)

        self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)

    def forward(self,
                image: torch.Tensor,
                pix_feat: torch.Tensor,
                sensory: torch.Tensor,
                masks: torch.Tensor,
                others: torch.Tensor,
                *,
                deep_update: bool = True,
                chunk_size: int = -1) -> Tuple[torch.Tensor, torch.Tensor]:
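        # chunk_size < 1 (or >= num_objects) processes all objects at once;
        # smaller values trade speed for lower peak memory.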
        if self.single_object:
            g = masks.unsqueeze(2)
        else:
            g = torch.stack([masks, others], dim=2)

        # broadcast the frame-level image to every object group
        g = self.distributor(image, g)

        batch_size, num_objects = g.shape[:2]
        if chunk_size < 1 or chunk_size >= num_objects:
            chunk_size = num_objects
            fast_path = True
            new_sensory = sensory
        else:
            if deep_update:
                new_sensory = torch.empty_like(sensory)
            else:
                new_sensory = sensory
            fast_path = False

        all_g = []
        for i in range(0, num_objects, chunk_size):
            if fast_path:
                g_chunk = g
            else:
                g_chunk = g[:, i:i + chunk_size]
            actual_chunk_size = g_chunk.shape[1]
            g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)

            g_chunk = self.conv1(g_chunk)
            g_chunk = self.bn1(g_chunk)
            # max-pool before ReLU: the two commute, so this is equivalent to
            # the usual conv -> bn -> relu -> maxpool ordering
            g_chunk = self.maxpool(g_chunk)
            g_chunk = self.relu(g_chunk)

            g_chunk = self.layer1(g_chunk)
            g_chunk = self.layer2(g_chunk)
            g_chunk = self.layer3(g_chunk)

            g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
            g_chunk = self.fuser(pix_feat, g_chunk)
            all_g.append(g_chunk)
            if deep_update:
                if fast_path:
                    new_sensory = self.sensory_update(g_chunk, sensory)
                else:
                    new_sensory[:, i:i + chunk_size] = self.sensory_update(
                        g_chunk, sensory[:, i:i + chunk_size])
        g = torch.cat(all_g, dim=1)

        return g, new_sensory

    # see UncertPred.train: this encoder stays frozen in eval mode
    def train(self, mode=True):
        self.training = False
        for module in self.children():
            module.train(False)
        return self

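# Fuses the stride-16 pixel features with the pixel-memory readout and a
# compressed sensory readout, producing the decoder input features.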
class PixelFeatureFuser(nn.Module):
    def __init__(self, model_cfg: DictConfig, single_object=False):
        super().__init__()
        value_dim = model_cfg.value_dim
        sensory_dim = model_cfg.sensory_dim
        pixel_dim = model_cfg.pixel_dim
        embed_dim = model_cfg.embed_dim
        self.single_object = single_object

        self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
        if self.single_object:
            self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
        else:
            self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)

    def forward(self,
                pix_feat: torch.Tensor,
                pixel_memory: torch.Tensor,
                sensory_memory: torch.Tensor,
                last_mask: torch.Tensor,
                last_others: torch.Tensor,
                *,
                chunk_size: int = -1) -> torch.Tensor:
        batch_size, num_objects = pixel_memory.shape[:2]

        if self.single_object:
            last_mask = last_mask.unsqueeze(2)
        else:
            last_mask = torch.stack([last_mask, last_others], dim=2)

        if chunk_size < 1:
            chunk_size = num_objects

        # chunk-by-chunk processing over the object dimension
        all_p16 = []
        for i in range(0, num_objects, chunk_size):
            sensory_readout = self.sensory_compress(
                torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2))
            p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
            p16 = self.fuser(pix_feat, p16)
            all_p16.append(p16)
        p16 = torch.cat(all_p16, dim=1)

        return p16

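# Decodes the fused features into per-object logits through a chain of
# upsampling blocks (stride 16 -> 8 -> 4 -> 2 -> 1), optionally updating the
# sensory memory from the multi-scale decoder features.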
class MaskDecoder(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        embed_dim = model_cfg.embed_dim
        sensory_dim = model_cfg.sensory_dim
        ms_image_dims = model_cfg.pixel_encoder.ms_dims
        up_dims = model_cfg.mask_decoder.up_dims

        assert embed_dim == up_dims[0]

        # the sensory update consumes all five scales; the full-resolution
        # features carry one extra channel for the predicted logits
        self.sensory_update = SensoryUpdater_fullscale(
            [up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1],
            sensory_dim, sensory_dim)

        self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
        self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
        self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])
        self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3])
        self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4])

        # separate prediction heads for the segmentation and matting passes
        self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
        self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)

    def forward(self,
                ms_image_feat: Iterable[torch.Tensor],
                memory_readout: torch.Tensor,
                sensory: torch.Tensor,
                *,
                chunk_size: int = -1,
                update_sensory: bool = True,
                seg_pass: bool = False,
                last_mask: Optional[torch.Tensor] = None,
                sigmoid_residual: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
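        # as in MaskEncoder.forward, chunk_size < 1 (or >= num_objects)
        # decodes all objects in one pass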
        batch_size, num_objects = memory_readout.shape[:2]
        f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:])
        if chunk_size < 1 or chunk_size >= num_objects:
            chunk_size = num_objects
            fast_path = True
            new_sensory = sensory
        else:
            if update_sensory:
                new_sensory = torch.empty_like(sensory)
            else:
                new_sensory = sensory
            fast_path = False

        all_logits = []
        for i in range(0, num_objects, chunk_size):
            if fast_path:
                p16 = memory_readout
            else:
                p16 = memory_readout[:, i:i + chunk_size]
            actual_chunk_size = p16.shape[1]

            p8 = self.up_16_8(p16, f8)
            p4 = self.up_8_4(p8, f4)
            p2 = self.up_4_2(p4, f2)
            p1 = self.up_2_1(p2, f1)
            with torch.cuda.amp.autocast(enabled=False):
                if seg_pass:
                    if last_mask is not None:
                        # predict a residual on top of the previous mask
                        res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
                        if sigmoid_residual:
                            # limit the residual to (-1, 1)
                            res = (torch.sigmoid(res) - 0.5) * 2
                        logits = last_mask + res
                    else:
                        logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
                else:
                    if last_mask is not None:
                        res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
                        if sigmoid_residual:
                            res = (torch.sigmoid(res) - 0.5) * 2
                        logits = last_mask + res
                    else:
                        logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))

            if update_sensory:
                p1 = torch.cat(
                    [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
                if fast_path:
                    new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory)
                else:
                    new_sensory[:, i:i + chunk_size] = self.sensory_update(
                        [p16, p8, p4, p2, p1], sensory[:, i:i + chunk_size])
            all_logits.append(logits)
        # note: concatenating flattened (batch * chunk) logits along dim 0 only
        # recovers the (batch, num_objects) layout below when batch_size == 1
        # or when a single chunk covers all objects
        logits = torch.cat(all_logits, dim=0)
        logits = logits.view(batch_size, num_objects, *logits.shape[-2:])

        return new_sensory, logits
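
if __name__ == '__main__':
    # Minimal shape smoke test for the modules above. This block is an
    # illustrative sketch, not part of the released MatAnyone code: every
    # dimension in the config below is a hypothetical placeholder chosen to
    # be self-consistent, not the shipped configuration.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'pixel_dim': 256,
        'key_dim': 64,
        'value_dim': 256,
        'sensory_dim': 256,
        'embed_dim': 256,
        'pretrained_resnet': False,
        # ms_dims must match the channel counts of (f16, f8, f4, f2, f1)
        # returned by PixelEncoder; these are the ResNet-50 stage widths.
        'pixel_encoder': {'type': 'resnet50', 'ms_dims': [1024, 512, 256, 64, 3]},
        'mask_encoder': {'type': 'resnet18', 'final_dim': 256},
        'mask_decoder': {'up_dims': [256, 128, 128, 64, 16]},
    })

    image = torch.zeros(1, 3, 64, 64)
    encoder = PixelEncoder(cfg)
    f16, f8, f4, f2, f1 = encoder(image)
    # strides 16/2/1 for a 64x64 input
    assert f16.shape[-2:] == (4, 4) and f2.shape[-2:] == (32, 32)
    assert f1.shape == image.shape

    key, shrinkage, selection = KeyProjection(cfg)(f16, need_s=True, need_e=True)
    assert key.shape == (1, cfg.key_dim, 4, 4)
    assert shrinkage.shape == (1, 1, 4, 4) and selection.shape == key.shape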