|
|
|
|
| from typing import Optional
|
| from torch import nn
|
|
|
| from detectron2.config import CfgNode
|
|
|
| from .cse.embedder import Embedder
|
| from .filter import DensePoseDataFilter
|
|
|
|
|
def build_densepose_predictor(cfg: CfgNode, input_channels: int):
    """
    Instantiate the DensePose predictor selected by the configuration.

    The predictor class is looked up by the name stored in
    ``cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME`` in
    ``DENSEPOSE_PREDICTOR_REGISTRY`` and constructed with
    ``(cfg, input_channels)``.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of the registered DensePose predictor
    """
    # Function-local import; presumably done to avoid a circular import at
    # module load time -- TODO confirm.
    from .predictors import DENSEPOSE_PREDICTOR_REGISTRY

    predictor_cls = DENSEPOSE_PREDICTOR_REGISTRY.get(
        cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME
    )
    return predictor_cls(cfg, input_channels)
|
|
|
|
|
def build_densepose_data_filter(cfg: CfgNode):
    """
    Create the DensePose data filter that selects data used for training.

    Args:
        cfg (CfgNode): configuration options
    Return:
        Callable: list(Tensor), list(Instances) -> list(Tensor), list(Instances)
        A ``DensePoseDataFilter`` instance, which takes feature tensors and
        proposals as input and returns the filtered features and proposals.
    """
    return DensePoseDataFilter(cfg)
|
|
|
|
|
def build_densepose_head(cfg: CfgNode, input_channels: int):
    """
    Build the DensePose ROI head selected by the configuration options.

    The head class is looked up by the name stored in
    ``cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME`` in ``ROI_DENSEPOSE_HEAD_REGISTRY``
    and constructed with ``(cfg, input_channels)``.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of the registered DensePose head
    """
    # Function-local import; presumably done to avoid a circular import at
    # module load time -- TODO confirm.
    from .roi_heads.registry import ROI_DENSEPOSE_HEAD_REGISTRY

    head_cls = ROI_DENSEPOSE_HEAD_REGISTRY.get(cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME)
    return head_cls(cfg, input_channels)
|
|
|
|
|
def build_densepose_losses(cfg: CfgNode):
    """
    Build the DensePose loss selected by the configuration options.

    The loss class is looked up by the name stored in
    ``cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME`` in ``DENSEPOSE_LOSS_REGISTRY``
    and constructed with ``cfg`` as its only argument.

    Args:
        cfg (CfgNode): configuration options
    Return:
        An instance of the registered DensePose loss
    """
    # Function-local import; presumably done to avoid a circular import at
    # module load time -- TODO confirm.
    from .losses import DENSEPOSE_LOSS_REGISTRY

    loss_cls = DENSEPOSE_LOSS_REGISTRY.get(cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME)
    return loss_cls(cfg)
|
|
|
|
|
def build_densepose_embedder(cfg: CfgNode) -> Optional[nn.Module]:
    """
    Build the embedder used to embed mesh vertices into an embedding space.

    The embedder contains sub-embedders, one for each mesh ID. It is only
    built when ``cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS`` is truthy
    (i.e. at least one embedder is configured).

    Args:
        cfg (CfgNode): configuration options
    Return:
        An ``Embedder`` module, or ``None`` when no embedders are configured
    """
    embedders_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS
    return Embedder(cfg) if embedders_cfg else None
|
|
|