| | |
| |
|
| | import pickle |
| | import torch |
| | from torch import nn |
| |
|
| | from detectron2.utils.file_io import PathManager |
| |
|
| | from .utils import normalize_embeddings |
| |
|
| |
|
class VertexFeatureEmbedder(nn.Module):
    """
    Class responsible for embedding vertex features. Mapping from
    feature space to the embedding space is a tensor of size [K, D], where
        K = number of dimensions in the feature space
        D = number of dimensions in the embedding space
    Vertex features is a tensor of size [N, K], where
        N = number of vertices
        K = number of dimensions in the feature space
    Vertex embeddings are computed as F * E = tensor of size [N, D], where
    F are the vertex features and E is the feature-to-embedding mapping; the
    product is then passed through ``normalize_embeddings`` (see ``forward``).
    """

    def __init__(
        self, num_vertices: int, feature_dim: int, embed_dim: int, train_features: bool = False
    ):
        """
        Initialize embedder. Both the vertex features and the
        feature-to-embedding mapping are zero-initialized (see
        ``reset_parameters``); ``load`` can populate them from a file.

        Args:
            num_vertices (int): number of vertices to embed
            feature_dim (int): number of dimensions in the feature space
            embed_dim (int): number of dimensions in the embedding space
            train_features (bool): determines whether vertex features should
                be trained (default: False)
        """
        super().__init__()
        # Storage is allocated uninitialized here; reset_parameters() fills it.
        features = torch.empty(num_vertices, feature_dim)
        if train_features:
            self.features = nn.Parameter(features)
        else:
            # A buffer follows the module across devices and is saved in the
            # state_dict, but is not returned by .parameters(), so it is not
            # updated by an optimizer.
            self.register_buffer("features", features)
        self.embeddings = nn.Parameter(torch.empty(feature_dim, embed_dim))
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Zero-fill the vertex features and the feature-to-embedding mapping in place."""
        self.features.zero_()
        self.embeddings.zero_()

    def forward(self) -> torch.Tensor:
        """
        Produce vertex embeddings, a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space

        Returns:
            Full vertex embeddings, a tensor of shape [N, D]: the matrix product
            features @ embeddings passed through ``normalize_embeddings``.
        """
        return normalize_embeddings(torch.mm(self.features, self.embeddings))

    @torch.no_grad()
    def load(self, fpath: str):
        """
        Load data from a file.

        The file must contain a pickled object that supports ``in`` and item
        access (presumably a dict). If the key "features" and/or "embeddings"
        is present, its value is converted to a float32 tensor and copied in
        place into the attribute of the same name. Missing keys are skipped.

        Args:
            fpath (str): file path to load data from
        """
        # NOTE(security): pickle.load can execute arbitrary code embedded in the
        # file; only load files from trusted sources.
        with PathManager.open(fpath, "rb") as hFile:
            data = pickle.load(hFile)
        for name in ("features", "embeddings"):
            if name not in data:
                continue
            target = getattr(self, name)
            # as_tensor avoids the copy-construct warning (and an extra copy) that
            # torch.tensor incurs when the stored value is already a tensor;
            # copy_ below performs the actual copy into the module's storage.
            value = torch.as_tensor(data[name]).float().to(device=target.device)
            target.copy_(value)
| |
|