# (Removed stray page-scrape residue: "Spaces: / Sleeping / Sleeping")
# -*- coding: utf-8 -*-
#
# Developed by Haozhe Xie <cshzxie@gmail.com>
#
# References:
# - https://github.com/shawnxu1318/MVCNN-Multi-View-Convolutional-Neural-Networks/blob/master/mvcnn.py

import torch
import torchvision.models
class Encoder(torch.nn.Module):
    """Multi-view image encoder.

    Each view is passed independently through a frozen VGG16-BN backbone
    (its first 27 feature layers) and then through three trainable
    convolutional stages; the per-view feature maps are re-stacked along a
    views dimension.
    """

    def __init__(self, cfg):
        super(Encoder, self).__init__()
        self.cfg = cfg

        # Backbone: first 27 feature layers of an ImageNet-pretrained VGG16-BN.
        backbone = torchvision.models.vgg16_bn(pretrained=True)
        self.vgg = torch.nn.Sequential(*list(backbone.features.children()))[:27]

        # Trainable head on top of the backbone features.
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
            torch.nn.MaxPool2d(kernel_size=3)
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 256, kernel_size=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ELU()
        )

        # Freeze the backbone: `self.vgg` wraps the very same module objects,
        # so disabling grads on `backbone` disables them for `self.vgg` too.
        for weight in backbone.parameters():
            weight.requires_grad = False

    def forward(self, rendering_images):
        """Encode a batch of multi-view images.

        Args:
            rendering_images: tensor of shape
                [batch_size, n_views, img_c, img_h, img_w].

        Returns:
            Tensor of shape [batch_size, n_views, 256, H', W']
            (H' = W' = 8 for 224x224 inputs).
        """
        # Move views to the front and split into one tensor per view.
        views_first = rendering_images.permute(1, 0, 2, 3, 4).contiguous()
        per_view_features = []
        for view in torch.split(views_first, 1, dim=0):
            x = self.vgg(view.squeeze(dim=0))  # [B, 512, 28, 28] for 224x224 input
            x = self.layer1(x)                 # [B, 512, 26, 26]
            x = self.layer2(x)                 # [B, 512, 8, 8] (conv -> 24, maxpool/3 -> 8)
            x = self.layer3(x)                 # [B, 256, 8, 8]
            per_view_features.append(x)
        # Re-stack views and restore batch-first layout.
        stacked = torch.stack(per_view_features)
        return stacked.permute(1, 0, 2, 3, 4).contiguous()
class DummyCfg:
    """Minimal stand-in for the project's configuration object."""

    class NETWORK:
        # Config flag carried for API compatibility; Encoder itself does not
        # read it (presumably consumed by a decoder — confirm against caller).
        TCONV_USE_BIAS = False


# Configuration instance passed to the network constructors below.
cfg = DummyCfg()
# Instantiate the encoder (the original comments wrongly called it a "decoder").
encoder = Encoder(cfg)

# Simulate input: shape [batch_size, n_views, img_c, img_h, img_w]
batch_size = 64
n_views = 5
img_c, img_h, img_w = 3, 224, 224
dummy_input = torch.randn(batch_size, n_views, img_c, img_h, img_w)

# Run the encoder and report the output shape.
print(dummy_input.shape)
image_features = encoder(dummy_input)
# For 224x224 inputs the encoder yields [batch, views, 256, 8, 8] — the
# original comment's "[64, 5, 9, 32, 32, 32]" was incorrect.
print("image_features shape:", image_features.shape)  # Expected: [64, 5, 256, 8, 8]