|
|
|
|
| import torch
|
| from torch import nn
|
|
|
| from detectron2.config import CfgNode
|
| from detectron2.layers import ConvTranspose2d, interpolate
|
|
|
| from ...structures import DensePoseChartPredictorOutput
|
| from ..utils import initialize_module_params
|
| from .registry import DENSEPOSE_PREDICTOR_REGISTRY
|
|
|
|
|
@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseChartPredictor(nn.Module):
    """
    Final layers of a chart-based DensePose model.

    The predictor consumes the DensePose head output and emits four tensors
    describing the result for the predefined body parts (patches / charts):

     * coarse segmentation, shape [N, K, Hout, Wout]
     * fine segmentation, shape [N, C, Hout, Wout]
     * U coordinates, shape [N, C, Hout, Wout]
     * V coordinates, shape [N, C, Hout, Wout]

    Notation:
     - N: number of instances
     - K: number of coarse segmentation channels, read from
         `cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS`
         (e.g. 2 = foreground / background,
         15 = one of 14 body parts / background)
     - C: number of fine segmentation channels, i.e.
         `cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES + 1`
         (e.g. 24 fine body parts / background)
     - Hout, Wout: height and width of the predictions

    Each output is produced by its own stride-2 transposed convolution
    followed by bilinear upscaling with the configured scale factor.
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Build the four prediction branches from configuration options.

        Args:
            cfg (CfgNode): configuration options; the values under
                `MODEL.ROI_DENSEPOSE_HEAD` (`NUM_COARSE_SEGM_CHANNELS`,
                `NUM_PATCHES`, `DECONV_KERNEL`, `UP_SCALE`) are read
            input_channels (int): size of the input tensor along the
                channel dimension
        """
        super().__init__()
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        coarse_channels = head_cfg.NUM_COARSE_SEGM_CHANNELS
        # One channel per patch plus one extra channel (background).
        patch_channels = head_cfg.NUM_PATCHES + 1
        deconv_kernel = head_cfg.DECONV_KERNEL
        # Same padding expression for every branch, so compute it once.
        deconv_padding = int(deconv_kernel / 2 - 1)

        def make_deconv(out_channels):
            # All branches share the input width, kernel, stride and padding;
            # only the number of output channels differs.
            return ConvTranspose2d(
                input_channels,
                out_channels,
                deconv_kernel,
                stride=2,
                padding=deconv_padding,
            )

        # NOTE: the creation order below fixes both the parameter registration
        # order and the order in which random initial weights are drawn.
        # coarse segmentation
        self.ann_index_lowres = make_deconv(coarse_channels)
        # fine segmentation
        self.index_uv_lowres = make_deconv(patch_channels)
        # U coordinates
        self.u_lowres = make_deconv(patch_channels)
        # V coordinates
        self.v_lowres = make_deconv(patch_channels)

        self.scale_factor = head_cfg.UP_SCALE
        initialize_module_params(self)

    def interp2d(self, tensor_nchw: torch.Tensor):
        """
        Upscale a tensor with bilinear interpolation.

        Args:
            tensor_nchw (tensor): tensor of shape (N, C, H, W)
        Return:
            tensor of shape (N, C, Hout, Wout), where Hout and Wout result
            from applying `self.scale_factor` to H and W
        """
        upscaled = interpolate(
            tensor_nchw,
            scale_factor=self.scale_factor,
            mode="bilinear",
            align_corners=False,
        )
        return upscaled

    def forward(self, head_outputs: torch.Tensor):
        """
        Run the predictor on DensePose head outputs.

        Args:
            head_outputs (tensor): DensePose head outputs, tensor of shape
                [N, D, H, W]
        Return:
            An instance of DensePoseChartPredictorOutput holding the upscaled
            coarse segmentation, fine segmentation, U and V predictions
        """
        # Each branch: low-resolution transposed convolution, then upscaling.
        coarse_segm = self.interp2d(self.ann_index_lowres(head_outputs))
        fine_segm = self.interp2d(self.index_uv_lowres(head_outputs))
        u_coords = self.interp2d(self.u_lowres(head_outputs))
        v_coords = self.interp2d(self.v_lowres(head_outputs))
        return DensePoseChartPredictorOutput(
            coarse_segm=coarse_segm,
            fine_segm=fine_segm,
            u=u_coords,
            v=v_coords,
        )
|
|
|