|
|
|
|
| import logging
|
| import numpy as np
|
| import pickle
|
| from enum import Enum
|
| from typing import Optional
|
| import torch
|
| from torch import nn
|
|
|
| from detectron2.config import CfgNode
|
| from detectron2.utils.file_io import PathManager
|
|
|
| from .vertex_direct_embedder import VertexDirectEmbedder
|
| from .vertex_feature_embedder import VertexFeatureEmbedder
|
|
|
|
|
class EmbedderType(Enum):
    """
    Strategies for mapping mesh vertices into the embedding space.

    Each member's value is the string used to select it (for example via
    `EmbedderType("vertex_direct")`):
     - VERTEX_DIRECT ("vertex_direct"): direct vertex embedding
     - VERTEX_FEATURE ("vertex_feature"): embedding vertex features
    """

    VERTEX_DIRECT = "vertex_direct"
    VERTEX_FEATURE = "vertex_feature"
|
|
|
|
|
def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Build the embedder described by a configuration node.

    Args:
        embedder_spec (CfgNode): embedder configuration; the options read here
            are TYPE, NUM_VERTICES, INIT_FILE and IS_TRAINABLE, plus FEATURE_DIM
            and FEATURES_TRAINABLE for the "vertex_feature" type
        embedder_dim (int): embedding space dimensionality
    Return:
        An embedder instance for the specified configuration
    Raises ValueError, in case of unexpected embedder type
    """
    kind = EmbedderType(embedder_spec.TYPE)
    if kind == EmbedderType.VERTEX_DIRECT:
        module = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif kind == EmbedderType.VERTEX_FEATURE:
        module = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        raise ValueError(f"Unexpected embedder type {kind}")

    # Either kind of embedder is initialized from a file when one is configured;
    # an empty string means "no init file".
    init_file = embedder_spec.INIT_FILE
    if init_file != "":
        module.load(init_file)

    # Freeze all parameters of embedders that are configured as non-trainable.
    frozen = not embedder_spec.IS_TRAINABLE
    if frozen:
        module.requires_grad_(False)

    return module
|
|
|
|
|
class Embedder(nn.Module):
    """
    Embedder module that serves as a container for embedders to use with different
    meshes. Extends Module to automatically save / load state dict.

    The embedder for a mesh named `name` is registered as the submodule
    `embedder_{name}`; the names of all registered meshes are kept in
    `mesh_names`.
    """

    # Prefix under which the embedder parameters are looked up in the "model"
    # entry of a full model checkpoint.
    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Initialize mesh embedders. An embedder for mesh `i` is stored in a submodule
        "embedder_{i}".

        If `cfg.MODEL.WEIGHTS` is a non-empty string, embedder weights are
        additionally loaded from that checkpoint.

        Args:
            cfg (CfgNode): configuration options
        """
        super().__init__()
        self.mesh_names = set()
        embedder_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
        logger = logging.getLogger(__name__)
        for mesh_name, embedder_spec in cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.items():
            # Lazy %-formatting: the message is only rendered if INFO is enabled.
            logger.info("Adding embedder embedder_%s with spec %s", mesh_name, embedder_spec)
            self.add_module(f"embedder_{mesh_name}", create_embedder(embedder_spec, embedder_dim))
            self.mesh_names.add(mesh_name)
        if cfg.MODEL.WEIGHTS != "":
            self.load_from_model_checkpoint(cfg.MODEL.WEIGHTS)

    @staticmethod
    def _select_prefixed_entries(model_state, prefix: str) -> dict:
        """
        Collect the entries of `model_state` whose key starts with `prefix`.

        The prefix is stripped from the returned keys and numpy arrays are
        converted to torch tensors; other values are passed through unchanged.
        """
        selected = {}
        for key, value in model_state.items():
            if not key.startswith(prefix):
                continue
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            selected[key[len(prefix) :]] = value
        return selected

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None):
        """
        Load embedder weights from a full model checkpoint.

        Entries stored under the "model" key of the checkpoint whose names start
        with `prefix` are loaded into this module (with the prefix stripped).
        Loading is non-strict, so missing or unexpected keys are tolerated.
        A checkpoint without a "model" entry is skipped with a warning.

        Args:
            fpath (str): checkpoint path; files ending in ".pkl" are read with
                `pickle`, all other files with `torch.load` (mapped to CPU)
            prefix (Optional[str]): key prefix of the embedder parameters in the
                checkpoint; defaults to `DEFAULT_MODEL_CHECKPOINT_PREFIX`
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX
        # SECURITY NOTE(review): unpickling can execute arbitrary code, and
        # torch.load also unpickles unless restricted to weights only. Only
        # point this at checkpoints from trusted sources.
        with PathManager.open(fpath, "rb") as hFile:
            if fpath.endswith(".pkl"):
                state_dict = pickle.load(hFile, encoding="latin1")
            else:
                state_dict = torch.load(hFile, map_location=torch.device("cpu"))
        if state_dict is None or "model" not in state_dict:
            logging.getLogger(__name__).warning(
                "Checkpoint %s has no 'model' entry, embedder weights were not loaded", fpath
            )
            return
        state_dict_local = self._select_prefixed_entries(state_dict["model"], prefix)
        # Non-strict: the checkpoint may cover a different set of meshes than
        # the one configured for this module.
        self.load_state_dict(state_dict_local, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Produce vertex embeddings for the specific mesh; vertex embeddings are
        a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space
        Args:
            mesh_name (str): name of a mesh for which to obtain vertex embeddings
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        """
        return getattr(self, f"embedder_{mesh_name}")()

    def has_embeddings(self, mesh_name: str) -> bool:
        """
        Check whether this module has an attribute `embedder_{mesh_name}`,
        i.e. whether an embedder is available for the given mesh.
        """
        return hasattr(self, f"embedder_{mesh_name}")
|
|
|