| |
|
|
| import pickle |
| import torch |
| from torch import nn |
|
|
| from detectron2.utils.file_io import PathManager |
|
|
| from .utils import normalize_embeddings |
|
|
|
|
| class VertexDirectEmbedder(nn.Module): |
| """ |
| Class responsible for embedding vertices. Vertex embeddings take |
| the form of a tensor of size [N, D], where |
| N = number of vertices |
| D = number of dimensions in the embedding space |
| """ |
|
|
| def __init__(self, num_vertices: int, embed_dim: int): |
| """ |
| Initialize embedder, set random embeddings |
| |
| Args: |
| num_vertices (int): number of vertices to embed |
| embed_dim (int): number of dimensions in the embedding space |
| """ |
| super(VertexDirectEmbedder, self).__init__() |
| self.embeddings = nn.Parameter(torch.Tensor(num_vertices, embed_dim)) |
| self.reset_parameters() |
|
|
| @torch.no_grad() |
| def reset_parameters(self): |
| """ |
| Reset embeddings to random values |
| """ |
| self.embeddings.zero_() |
|
|
| def forward(self) -> torch.Tensor: |
| """ |
| Produce vertex embeddings, a tensor of shape [N, D] where: |
| N = number of vertices |
| D = number of dimensions in the embedding space |
| |
| Return: |
| Full vertex embeddings, a tensor of shape [N, D] |
| """ |
| return normalize_embeddings(self.embeddings) |
|
|
| @torch.no_grad() |
| def load(self, fpath: str): |
| """ |
| Load data from a file |
| |
| Args: |
| fpath (str): file path to load data from |
| """ |
| with PathManager.open(fpath, "rb") as hFile: |
| data = pickle.load(hFile) |
| for name in ["embeddings"]: |
| if name in data: |
| getattr(self, name).copy_( |
| torch.tensor(data[name]).float().to(device=getattr(self, name).device) |
| ) |
|
|