Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .BasePIFuNet import BasePIFuNet | |
| from .SurfaceClassifier import SurfaceClassifier | |
| from .DepthNormalizer import DepthNormalizer | |
| from .ConvFilters import * | |
| from ..net_util import init_net | |
class ConvPIFuNet(BasePIFuNet):
    '''
    Conv PIFu network: the standard three-phase network.

    The image filter is a purely convolutional multi-layer network. During the
    feature-extraction phase, every feature map in the pyramid is sampled at
    each point's projected location and the samples are concatenated.

    It does the following:
    1. Compute the image feature pyramid and store it in self.im_feat_list.
    2. Project the query points with the calibration, sample every feature map
       at the projected locations, and concatenate the samples together with
       the normalized depth feature.
    3. Classify the concatenated per-point features with the surface classifier.
    '''

    def __init__(self,
                 opt,
                 projection_mode='orthogonal',
                 error_term=nn.MSELoss(),
                 ):
        '''
        :param opt: option namespace. This class reads ``num_views``,
            ``mlp_dim``, ``no_residual`` and ``netIMF`` (plus ``enc_dim`` when
            ``netIMF == 'multiconv'``), and also passes ``opt`` to
            DepthNormalizer.
        :param projection_mode: forwarded to BasePIFuNet.
        :param error_term: loss module forwarded to BasePIFuNet.
        '''
        # NOTE: the default ``nn.MSELoss()`` is evaluated once, when the class
        # body runs, so every instance built without an explicit error_term
        # shares that same loss-module object.
        super(ConvPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)

        self.name = 'convpifu'

        self.opt = opt
        self.num_views = self.opt.num_views

        self.image_filter = self.define_imagefilter(opt)

        self.surface_classifier = SurfaceClassifier(
            filter_channels=self.opt.mlp_dim,
            num_views=self.opt.num_views,
            no_residual=self.opt.no_residual,
            last_op=nn.Sigmoid())

        self.normalizer = DepthNormalizer(opt)

        # This is a list of [B x Feat_i x H x W] features, filled by filter().
        self.im_feat_list = []

        init_net(self)

    def define_imagefilter(self, opt):
        '''
        Build the image-filter backbone selected by ``opt.netIMF``.

        :param opt: option namespace. ``opt.netIMF`` selects the backbone:
            'multiconv', any name containing 'resnet', or 'vgg16'.
            ``opt.enc_dim`` is passed to MultiConv.
        :return: the constructed image-filter module.
        :raises NotImplementedError: if ``opt.netIMF`` is not recognized.
        '''
        if opt.netIMF == 'multiconv':
            return MultiConv(opt.enc_dim)
        if 'resnet' in opt.netIMF:
            # The full name (e.g. a specific resnet variant) is forwarded as-is.
            return ResNet(model=opt.netIMF)
        if opt.netIMF == 'vgg16':
            return Vgg16()
        # Report the option that was actually dispatched on above.
        raise NotImplementedError('model name [%s] is not recognized' % opt.netIMF)

    def filter(self, images):
        '''
        Run the image filter on the input images and keep its output in
        self.im_feat_list for later use by query().

        The output is presumably a list of [B, Feat_i, H, W] feature maps;
        query() iterates over it.

        :param images: [B, C, H, W] input images
        '''
        self.im_feat_list = self.image_filter(images)

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, compute the network prediction for each point.

        Image features must already have been computed by filter(). Nothing
        is returned: the predictions are stored in self.preds, and the labels,
        if given, are stored in self.labels.

        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling
        '''
        if labels is not None:
            self.labels = labels

        # Project to image space: channels 0-1 are the pixel coordinates used
        # for sampling, and channel 2 is the depth.
        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]
        z = xyz[:, 2:3, :]

        z_feat = self.normalizer(z)

        # This is a list of [B, Feat_i, N] features, one per pyramid level.
        point_local_feat_list = [self.index(im_feat, xy) for im_feat in self.im_feat_list]
        point_local_feat_list.append(z_feat)
        # [B, Feat_all, N]
        point_local_feat = torch.cat(point_local_feat_list, 1)

        # Per the original docs these are [B, Res, N] predictions; the exact
        # shape is determined by SurfaceClassifier — confirm there.
        self.preds = self.surface_classifier(point_local_feat)