| | |
| |
|
| | import torch |
| | from torch.nn import functional as F |
| |
|
| |
|
| | def squared_euclidean_distance_matrix(pts1: torch.Tensor, pts2: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Get squared Euclidean Distance Matrix |
| | Computes pairwise squared Euclidean distances between points |
| | |
| | Args: |
| | pts1: Tensor [M x D], M is the number of points, D is feature dimensionality |
| | pts2: Tensor [N x D], N is the number of points, D is feature dimensionality |
| | |
| | Return: |
| | Tensor [M, N]: matrix of squared Euclidean distances; at index (m, n) |
| | it contains || pts1[m] - pts2[n] ||^2 |
| | """ |
| | edm = torch.mm(-2 * pts1, pts2.t()) |
| | edm += (pts1 * pts1).sum(1, keepdim=True) + (pts2 * pts2).sum(1, keepdim=True).t() |
| | return edm.contiguous() |
| |
|
| |
|
| | def normalize_embeddings(embeddings: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor: |
| | """ |
| | Normalize N D-dimensional embedding vectors arranged in a tensor [N, D] |
| | |
| | Args: |
| | embeddings (tensor [N, D]): N D-dimensional embedding vectors |
| | epsilon (float): minimum value for a vector norm |
| | Return: |
| | Normalized embeddings (tensor [N, D]), such that L2 vector norms are all equal to 1. |
| | """ |
| | return embeddings / torch.clamp(embeddings.norm(p=None, dim=1, keepdim=True), min=epsilon) |
| |
|
| |
|
| | def get_closest_vertices_mask_from_ES( |
| | E: torch.Tensor, |
| | S: torch.Tensor, |
| | h: int, |
| | w: int, |
| | mesh_vertex_embeddings: torch.Tensor, |
| | device: torch.device, |
| | ): |
| | """ |
| | Interpolate Embeddings and Segmentations to the size of a given bounding box, |
| | and compute closest vertices and the segmentation mask |
| | |
| | Args: |
| | E (tensor [1, D, H, W]): D-dimensional embedding vectors for every point of the |
| | default-sized box |
| | S (tensor [1, 2, H, W]): 2-dimensional segmentation mask for every point of the |
| | default-sized box |
| | h (int): height of the target bounding box |
| | w (int): width of the target bounding box |
| | mesh_vertex_embeddings (tensor [N, D]): vertex embeddings for a chosen mesh |
| | N is the number of vertices in the mesh, D is feature dimensionality |
| | device (torch.device): device to move the tensors to |
| | Return: |
| | Closest Vertices (tensor [h, w]), int, for every point of the resulting box |
| | Segmentation mask (tensor [h, w]), boolean, for every point of the resulting box |
| | """ |
| | embedding_resized = F.interpolate(E, size=(h, w), mode="bilinear")[0].to(device) |
| | coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0].to(device) |
| | mask = coarse_segm_resized.argmax(0) > 0 |
| | closest_vertices = torch.zeros(mask.shape, dtype=torch.long, device=device) |
| | all_embeddings = embedding_resized[:, mask].t() |
| | size_chunk = 10_000 |
| | edm = [] |
| | if len(all_embeddings) == 0: |
| | return closest_vertices, mask |
| | for chunk in range((len(all_embeddings) - 1) // size_chunk + 1): |
| | chunk_embeddings = all_embeddings[size_chunk * chunk : size_chunk * (chunk + 1)] |
| | edm.append( |
| | torch.argmin( |
| | squared_euclidean_distance_matrix(chunk_embeddings, mesh_vertex_embeddings), dim=1 |
| | ) |
| | ) |
| | closest_vertices[mask] = torch.cat(edm) |
| | return closest_vertices, mask |
| |
|