| |
|
|
| import logging |
| import numpy as np |
| import pickle |
| from enum import Enum |
| from typing import Optional |
| import torch |
| from torch import nn |
|
|
| from detectron2.config import CfgNode |
| from detectron2.utils.file_io import PathManager |
|
|
| from .vertex_direct_embedder import VertexDirectEmbedder |
| from .vertex_feature_embedder import VertexFeatureEmbedder |
|
|
|
|
class EmbedderType(Enum):
    """
    Strategy used to map mesh vertices into the embedding space.

    Members:
        VERTEX_DIRECT ("vertex_direct"): vertices are embedded directly
        VERTEX_FEATURE ("vertex_feature"): vertex features are embedded
    """

    VERTEX_DIRECT = "vertex_direct"
    VERTEX_FEATURE = "vertex_feature"
|
|
|
|
def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Build a single mesh embedder from its configuration node.

    Args:
        embedder_spec (CfgNode): embedder configuration; reads TYPE, NUM_VERTICES,
            INIT_FILE and IS_TRAINABLE, plus FEATURE_DIM and FEATURES_TRAINABLE
            for the "vertex_feature" type
        embedder_dim (int): embedding space dimensionality
    Return:
        The constructed embedder module, initialized from INIT_FILE when that is
        non-empty, with gradients disabled when IS_TRAINABLE is false
    Raises:
        ValueError: if TYPE does not name a known embedder type
    """
    # EmbedderType(...) itself raises ValueError for unknown type strings.
    embedder_type = EmbedderType(embedder_spec.TYPE)

    if embedder_type == EmbedderType.VERTEX_DIRECT:
        module = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif embedder_type == EmbedderType.VERTEX_FEATURE:
        module = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        # Only reachable if EmbedderType gains a member without a builder here.
        raise ValueError(f"Unexpected embedder type {embedder_type}")

    # Both embedder kinds share the same optional initialization step.
    if embedder_spec.INIT_FILE != "":
        module.load(embedder_spec.INIT_FILE)

    if not embedder_spec.IS_TRAINABLE:
        # Disable gradients for every parameter of the embedder.
        module.requires_grad_(False)

    return module
|
|
|
|
class Embedder(nn.Module):
    """
    Embedder module that serves as a container for embedders to use with different
    meshes. The embedder for mesh `name` is registered as submodule
    "embedder_{name}", so its parameters are saved / loaded together with the
    state dict of this module.
    """

    # Key prefix under which embedder weights are stored in full model checkpoints
    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Initialize mesh embedders. An embedder for mesh `i` is stored in a submodule
        "embedder_{i}".

        Args:
            cfg (CfgNode): configuration options; reads
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE,
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS (mapping mesh name -> spec)
                and MODEL.WEIGHTS (checkpoint to initialize from, if non-empty)
        """
        super().__init__()
        self.mesh_names = set()
        embedder_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
        logger = logging.getLogger(__name__)
        for mesh_name, embedder_spec in cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.items():
            logger.info(f"Adding embedder embedder_{mesh_name} with spec {embedder_spec}")
            self.add_module(f"embedder_{mesh_name}", create_embedder(embedder_spec, embedder_dim))
            self.mesh_names.add(mesh_name)
        if cfg.MODEL.WEIGHTS != "":
            self.load_from_model_checkpoint(cfg.MODEL.WEIGHTS)

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None):
        """
        Load embedder weights from a full model checkpoint.

        Only entries of checkpoint["model"] whose keys start with `prefix` are used;
        the prefix is stripped before loading, and numpy arrays are converted to
        tensors. Loading is non-strict, so missing / unexpected keys do not raise.
        If the checkpoint has no "model" entry, nothing is loaded (this is logged).

        Args:
            fpath (str): checkpoint path; ".pkl" files are read with pickle
                (latin1 encoding), anything else with torch.load mapped to CPU
            prefix (Optional[str]): key prefix of embedder weights in the checkpoint;
                defaults to DEFAULT_MODEL_CHECKPOINT_PREFIX
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX
        logger = logging.getLogger(__name__)

        # NOTE(security): pickle.load and torch.load can execute arbitrary code
        # embedded in the file; only load checkpoints from trusted sources.
        with PathManager.open(fpath, "rb") as checkpoint_file:
            if fpath.endswith(".pkl"):
                state_dict = pickle.load(checkpoint_file, encoding="latin1")
            else:
                state_dict = torch.load(checkpoint_file, map_location=torch.device("cpu"))

        if state_dict is None or "model" not in state_dict:
            # Previously this case was silently ignored; keep it non-fatal but visible.
            logger.info(
                "Checkpoint %s has no 'model' entry; embedder weights not loaded", fpath
            )
            return

        state_dict_local = {}
        for key, value in state_dict["model"].items():
            if not key.startswith(prefix):
                continue
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            state_dict_local[key[len(prefix) :]] = value

        if not state_dict_local:
            logger.info(
                "Checkpoint %s contains no keys with prefix '%s'; embedder weights not loaded",
                fpath,
                prefix,
            )
        # Non-strict loading to allow finetuning on a different set of meshes.
        self.load_state_dict(state_dict_local, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Produce vertex embeddings for the specific mesh; vertex embeddings are
        a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space
        Args:
            mesh_name (str): name of a mesh for which to obtain vertex embeddings
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        Raises:
            AttributeError: if no embedder is registered for `mesh_name`
        """
        return getattr(self, f"embedder_{mesh_name}")()

    def has_embeddings(self, mesh_name: str) -> bool:
        """
        Return True if an embedder submodule is registered for `mesh_name`.
        """
        return hasattr(self, f"embedder_{mesh_name}")
|
|