import torch
import torch.nn as nn
import torch.nn.functional as F

from .BasePIFuNet import BasePIFuNet


class VhullPIFuNet(BasePIFuNet):
    '''
    Vhull Piximp network is a minimal network demonstrating how the template works;
    it also helps with debugging the training/test schemes.
    It does the following:
        1. Computes the masks of the input images and stores them under self.im_feat.
        2. Projects the query points with the calibration matrices and indexes into the masks.
        3. Returns whether the points fall inside the intersection of all masks.
    '''

    def __init__(self,
                 num_views,
                 projection_mode='orthogonal',
                 error_term=nn.MSELoss(),
                 ):
        super(VhullPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)
        self.name = 'vhull'

        self.num_views = num_views

        self.im_feat = None

    def filter(self, images):
        '''
        Filter the input images and
        store all intermediate features.
        :param images: [B, C, H, W] input images
        '''
        if images.shape[1] > 3:
            # RGBA input: take the alpha channel as the silhouette mask.
            self.im_feat = images[:, 3:4, :, :]
        else:
            # Otherwise fall back to the first channel.
            self.im_feat = images[:, 0:1, :, :]

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        Stores all intermediate features.
        query() may behave differently during training and testing.
        :param points: [B, 3, N] world-space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: optional [B, 2, 3] image-space coordinate transforms
        :param labels: optional [B, Res, N] ground-truth labels
        :return: [B, Res, N] predictions for each point
        '''
        if labels is not None:
            self.labels = labels

        # Project world-space points into each view and keep the xy image coordinates.
        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]

        # Sample the mask at each projected point, then fold the per-view
        # samples of the same object into the channel dimension.
        point_local_feat = self.index(self.im_feat, xy)
        local_shape = point_local_feat.shape
        point_feat = point_local_feat.view(
            local_shape[0] // self.num_views,
            local_shape[1] * self.num_views,
            -1)
        # The product over views is nonzero only where a point lies inside
        # every mask, i.e. inside the visual hull.
        pred = torch.prod(point_feat, dim=1)

        self.preds = pred.unsqueeze(1)
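

if __name__ == '__main__':
    # Quick smoke-test sketch for the visual-hull query above. This block is
    # illustrative, not part of the original module: it assumes BasePIFuNet
    # wires up the orthogonal projection and bilinear index helpers used by
    # query(), and that the file is run as a module (e.g. `python -m ...`)
    # so the relative import resolves. Shapes and sizes are arbitrary.
    num_views = 4
    net = VhullPIFuNet(num_views=num_views)

    # One RGBA image per view; filter() picks the alpha channel as the mask.
    images = torch.rand(num_views, 4, 512, 512)
    # 1000 query points shared across views, kept inside grid_sample's
    # normalized [-1, 1] coordinate range.
    points = torch.rand(num_views, 3, 1000) * 2 - 1
    # Identity [B, 3, 4] calibrations: no rotation, no translation.
    calibs = torch.eye(4)[:3, :].unsqueeze(0).repeat(num_views, 1, 1)

    net.filter(images)
    net.query(points, calibs)
    # The four views collapse into a single batch entry: [1, 1, 1000].
    print(net.preds.shape)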