# -*- coding: utf-8 -*-
#
# Developed by Haozhe Xie
#
# References:
# - https://github.com/shawnxu1318/MVCNN-Multi-View-Convolutional-Neural-Networks/blob/master/mvcnn.py

import torch


class Encoder(torch.nn.Module):
    """Multi-view image encoder on top of a frozen, pretrained VGG16-BN trunk.

    Every view is passed independently through the first 27 modules of
    torchvision's VGG16-BN ``features`` stack (512 channels at 1/8 of the
    input resolution), followed by three trainable Conv/BatchNorm/ELU stages.

    For 224x224 inputs the per-view feature map evolves as::

        VGG trunk : [B, 512, 28, 28]
        layer1    : [B, 512, 26, 26]   (3x3 conv, no padding)
        layer2    : [B, 512, 8, 8]     (3x3 conv -> 24x24, then 3x3 max-pool)
        layer3    : [B, 256, 8, 8]     (1x1 conv)

    Args:
        cfg: Configuration object. It is stored on ``self.cfg`` but not read
            by this class.
    """

    def __init__(self, cfg):
        super(Encoder, self).__init__()
        self.cfg = cfg

        # Imported here rather than at module level so this module can be
        # imported (e.g. by tests or tooling) without torchvision installed;
        # torchvision is only needed to build the pretrained trunk.
        import torchvision.models

        # Layer Definition
        # NOTE: ``pretrained=True`` downloads ImageNet weights on first use;
        # newer torchvision releases deprecate it in favour of ``weights=...``.
        vgg16_bn = torchvision.models.vgg16_bn(pretrained=True)
        self.vgg = torch.nn.Sequential(*list(vgg16_bn.features.children()))[:27]
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
            torch.nn.MaxPool2d(kernel_size=3),
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 256, kernel_size=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ELU(),
        )

        # Freeze the pretrained trunk so only layer1-3 receive gradients.
        # NOTE: requires_grad=False does not stop the trunk's BatchNorm layers
        # from updating their running statistics while in train() mode.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, rendering_images):
        """Encode every view of a batch of multi-view renderings.

        Args:
            rendering_images: Tensor of shape
                ``[batch_size, n_views, img_c, img_h, img_w]``.

        Returns:
            Contiguous tensor of shape ``[batch_size, n_views, C, H, W]``
            holding the per-view features (``[.., 256, 8, 8]`` for 224x224
            inputs), with views in the same order as the input.
        """
        image_features = []
        # Each view gets its own pass; folding the views into the batch
        # dimension would change BatchNorm batch statistics in train mode.
        for img in torch.unbind(rendering_images, dim=1):
            # img: [batch_size, img_c, img_h, img_w]
            features = self.vgg(img)
            features = self.layer1(features)
            features = self.layer2(features)
            features = self.layer3(features)
            image_features.append(features)

        # Re-insert the view axis at dim 1: [batch_size, n_views, C, H, W].
        return torch.stack(image_features, dim=1)


class DummyCfg:
    """Minimal stand-in for the project configuration object."""

    class NETWORK:
        TCONV_USE_BIAS = False


if __name__ == "__main__":
    cfg = DummyCfg()

    # Instantiate the encoder. eval() keeps this demo from overwriting the
    # pretrained BatchNorm running statistics with random-noise statistics.
    encoder = Encoder(cfg)
    encoder.eval()

    # Simulate input: shape [batch_size, n_views, img_c, img_h, img_w]
    batch_size = 64
    n_views = 5
    img_c, img_h, img_w = 3, 224, 224
    dummy_input = torch.randn(batch_size, n_views, img_c, img_h, img_w)

    # Run the encoder; no_grad() avoids storing activations for backprop.
    print(dummy_input.shape)
    with torch.no_grad():
        image_features = encoder(dummy_input)
    print("image_features shape:", image_features.shape)
    # Expected: torch.Size([64, 5, 256, 8, 8])