# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..geometry import index, orthogonal, perspective
class BasePIFuNet(nn.Module):
    '''
    Base class for PIFu-style pixel-aligned implicit-function networks.

    Subclasses are expected to override filter(), query() and (optionally)
    calc_normal(); the base implementations are intentional no-ops so that
    forward() defines the shared filter -> query -> get_preds pipeline.
    '''

    def __init__(self,
                 projection_mode='orthogonal',
                 criteria=None,
                 ):
        '''
        args:
            projection_mode: 'orthogonal' / 'perspective' — selects the
                projection function applied to 3d points
            criteria: dict mapping loss name -> loss module
                (defaults to {'occ': nn.MSELoss()})
        '''
        super(BasePIFuNet, self).__init__()
        self.name = 'base'
        # Avoid a mutable default argument: build the default dict per
        # instance instead of sharing one across all constructions.
        self.criteria = {'occ': nn.MSELoss()} if criteria is None else criteria

        self.index = index
        self.projection = orthogonal if projection_mode == 'orthogonal' else perspective

        # Populated by subclasses during query()/calc_normal().
        self.preds = None
        self.labels = None
        self.nmls = None
        self.labels_nml = None
        self.preds_surface = None  # with normal loss only

    def forward(self, points, images, calibs, transforms=None):
        '''
        run the full pipeline: encode images, then evaluate the implicit
        function at the given points.

        args:
            points: [B, 3, N] 3d points in world space
            images: [B, C, H, W] input images
            calibs: [B, 3, 4] calibration matrices for each image
            transforms: [B, 2, 3] image space coordinate transforms
        return:
            [B, C, N] prediction corresponding to the given points
        '''
        self.filter(images)
        self.query(points, calibs, transforms)
        return self.get_preds()

    def filter(self, images):
        '''
        apply a fully convolutional network to images.
        the resulting feature will be stored.
        no-op in the base class; subclasses override.

        args:
            images: [B, C, H, W]
        '''
        pass

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        given 3d points, we obtain 2d projection of these given the camera matrices.
        filter needs to be called beforehand.
        the prediction is stored to self.preds.
        no-op in the base class; subclasses override.

        args:
            points: [B, 3, N] 3d points in world space
            calibs: [B, 3, 4] calibration matrices for each image
            transforms: [B, 2, 3] image space coordinate transforms
            labels: [B, C, N] ground truth labels (for supervision only)
        return:
            [B, C, N] prediction
        '''
        pass

    def calc_normal(self, points, calibs, transforms=None, delta=0.1):
        '''
        return surface normal in 'model' space.
        it computes normal only in the last stack.
        note that the current implementation use forward difference.
        no-op in the base class; subclasses override.

        args:
            points: [B, 3, N] 3d points in world space
            calibs: [B, 3, 4] calibration matrices for each image
            transforms: [B, 2, 3] image space coordinate transforms
            delta: perturbation for finite difference
        '''
        pass

    def get_preds(self):
        '''
        return the current prediction.

        return:
            [B, C, N] prediction
        '''
        return self.preds

    def get_error(self, gamma=None):
        '''
        return the loss given the ground truth labels and prediction.

        args:
            gamma: optional extra argument forwarded to each criterion
                (some criteria accept a weighting factor)
        return:
            dict mapping loss name -> loss value
        '''
        # BUGFIX: the original referenced self.error_term, which is never
        # assigned in this class (__init__ stores the losses in
        # self.criteria), so this method always raised AttributeError.
        # Keep backward compatibility with any subclass that does define
        # error_term before dispatching through the criteria dict.
        if getattr(self, 'error_term', None) is not None:
            return self.error_term(self.preds, self.labels, gamma)
        errors = {}
        for name, criterion in self.criteria.items():
            if gamma is None:
                errors[name] = criterion(self.preds, self.labels)
            else:
                errors[name] = criterion(self.preds, self.labels, gamma)
        return errors