| """Code for the MegaLoc model. |
| Much of the code in this file is from SALAD https://github.com/serizba/salad |
| """ |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as tfm |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
class MegaLocModel(nn.Module, PyTorchModelHubMixin):
    """MegaLoc: a DINOv2 backbone followed by SALAD-style aggregation.

    Maps a batch of RGB images to one L2-normalized global descriptor per
    image of size ``feat_dim``.
    """

    def __init__(
        self,
        feat_dim=8448,
        num_clusters=64,
        cluster_dim=256,
        token_dim=256,
        mlp_dim=512,
    ):
        super().__init__()
        self.backbone = DINOv2()
        # SALAD output size: one cluster_dim vector per cluster plus the
        # global token.
        self.salad_out_dim = num_clusters * cluster_dim + token_dim
        agg_config = {
            "num_channels": self.backbone.num_channels,
            "num_clusters": num_clusters,
            "cluster_dim": cluster_dim,
            "token_dim": token_dim,
            "mlp_dim": mlp_dim,
        }
        self.aggregator = Aggregator(
            feat_dim=feat_dim,
            agg_config=agg_config,
            salad_out_dim=self.salad_out_dim,
        )
        self.feat_dim = feat_dim
        self.l2norm = L2Norm()

    def forward(self, images):
        """Return [B, feat_dim] L2-normalized descriptors for `images` [B, 3, H, W]."""
        _, _, height, width = images.shape
        if height % 14 != 0 or width % 14 != 0:
            # DINOv2 uses 14x14 patches, so both sides must be multiples of 14;
            # resize to the nearest valid size.
            new_size = [round(height / 14) * 14, round(width / 14) * 14]
            images = tfm.functional.resize(images, new_size, antialias=True)
        descriptors = self.aggregator(self.backbone(images))
        return self.l2norm(descriptors)
|
|
|
|
class L2Norm(nn.Module):
    """Scale the input to unit L2 norm along one dimension.

    Thin module wrapper around ``torch.nn.functional.normalize`` so the
    operation can live inside an ``nn.Sequential`` / model graph.
    """

    def __init__(self, dim=1):
        super().__init__()
        # Dimension along which each slice is normalized (default: channel dim).
        self.dim = dim

    def forward(self, x):
        """Return `x` with unit L2 norm along ``self.dim``."""
        return nn.functional.normalize(x, dim=self.dim, p=2.0)
|
|
|
|
class Aggregator(nn.Module):
    """SALAD aggregation followed by a linear projection to ``feat_dim``."""

    def __init__(self, feat_dim, agg_config, salad_out_dim):
        super().__init__()
        # SALAD collapses the backbone's (feature map, token) pair into a
        # single (salad_out_dim,) descriptor per image.
        self.agg = SALAD(**agg_config)
        self.linear = nn.Linear(salad_out_dim, feat_dim)

    def forward(self, x):
        """Aggregate backbone output `x` and project it to the final dimension."""
        return self.linear(self.agg(x))
|
|
|
|
class DINOv2(nn.Module):
    """DINOv2 ViT-B/14 backbone with only the last few blocks trainable.

    The pretrained model is fetched via ``torch.hub``; all transformer
    blocks except the final ``num_trainable_blocks`` run under
    ``torch.no_grad`` and are detached from the autograd graph.
    """

    def __init__(self, num_trainable_blocks=4, norm_layer=True, return_token=True):
        super().__init__()
        # Loads pretrained ViT-B/14 weights (downloads on first use).
        self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
        self.num_channels = 768  # embedding dim of ViT-B
        self.num_trainable_blocks = num_trainable_blocks
        self.norm_layer = norm_layer
        self.return_token = return_token

    def forward(self, x):
        """
        Run the backbone on a batch of images.

        Parameters:
            x (torch.Tensor): Input tensor [B, 3, H, W]; H and W must be
                divisible by 14 (the ViT patch size).

        Returns:
            f (torch.Tensor): Feature map [B, C, H // 14, W // 14].
            t (torch.Tensor): CLS token [B, C]; only returned when
                ``return_token`` is True.
        """
        B, _, H, W = x.shape

        tokens = self.model.prepare_tokens_with_masks(x)

        # Frozen prefix: no gradients, then cut the graph explicitly.
        # NOTE: the negative slice means num_trainable_blocks == 0 would
        # freeze nothing (blocks[:-0] is empty) — preserved from the original.
        with torch.no_grad():
            for block in self.model.blocks[: -self.num_trainable_blocks]:
                tokens = block(tokens)
            tokens = tokens.detach()

        # Trainable suffix.
        for block in self.model.blocks[-self.num_trainable_blocks :]:
            tokens = block(tokens)

        if self.norm_layer:
            tokens = self.model.norm(tokens)

        cls_token = tokens[:, 0]
        patch_tokens = tokens[:, 1:]

        # [B, N, C] -> [B, C, H // 14, W // 14]
        feature_map = patch_tokens.reshape(B, H // 14, W // 14, self.num_channels)
        feature_map = feature_map.permute(0, 3, 1, 2)

        if self.return_token:
            return feature_map, cls_token
        return feature_map
|
|
|
|
| |
| |
def log_otp_solver(log_a, log_b, M, num_iters: int = 20, reg: float = 1.0) -> torch.Tensor:
    r"""Sinkhorn matrix scaling algorithm for the differentiable Optimal Transport problem.

    Runs log-domain Sinkhorn iterations and returns the log of the transport
    plan for the given cost matrix and (log) marginals.

    Args:
        log_a : torch.Tensor
            Log of the source marginal weights, shape [B, m].
        log_b : torch.Tensor
            Log of the target marginal weights, shape [B, n].
        M : torch.Tensor
            Metric cost matrix, shape [B, m, n].
        num_iters : int, default=20
            The number of Sinkhorn iterations.
        reg : float, default=1.0
            Entropic regularization value (the cost is divided by it).

    Returns:
        torch.Tensor: log transport plan, shape [B, m, n].
    """
    M = M / reg

    # Log-domain scaling vectors (log u, log v), initialized to zeros.
    u, v = torch.zeros_like(log_a), torch.zeros_like(log_b)

    for _ in range(num_iters):
        # Alternately rescale rows then columns to match the marginals.
        # logsumexp(dim=k) already drops the reduced dim, so no squeeze is
        # needed; the original trailing .squeeze() was redundant and corrupted
        # shapes when a marginal had size 1 (e.g. n == 1 made v collapse to
        # [B] and broadcast log_b - v into [B, B]).
        u = log_a - torch.logsumexp(M + v.unsqueeze(1), dim=2)
        v = log_b - torch.logsumexp(M + u.unsqueeze(2), dim=1)

    return M + u.unsqueeze(2) + v.unsqueeze(1)
|
|
|
|
| |
| |
def get_matching_probs(S, dustbin_score=1.0, num_iters=3, reg=1.0):
    """Sinkhorn on a score matrix augmented with a learnable dustbin row.

    Appends one extra row filled with ``dustbin_score`` to `S` and solves
    the entropic OT problem with near-uniform marginals, giving the dustbin
    row the surplus mass.

    Args:
        S: score matrix [B, m, n].
        dustbin_score: scalar (or 0-dim tensor) written into the extra row.
        num_iters: Sinkhorn iterations.
        reg: regularization passed through to the solver.

    Returns:
        Log assignment matrix of shape [B, m + 1, n].
    """
    batch_size, m, n = S.size()

    # Build the augmented score matrix with the dustbin as the last row.
    S_aug = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
    S_aug[:, :m, :n] = S
    S_aug[:, m, :] = dustbin_score

    # Uniform log-marginals over n + m units of mass; the dustbin row absorbs
    # the n - m surplus.
    # NOTE(review): math.log(n - m) requires n > m — confirm callers always
    # provide more columns (locations) than rows (clusters).
    norm = -torch.tensor(math.log(n + m), device=S.device)
    row_marginal = norm.expand(m + 1).contiguous()
    col_marginal = norm.expand(n).contiguous()
    row_marginal[-1] = row_marginal[-1] + math.log(n - m)
    row_marginal = row_marginal.expand(batch_size, -1)
    col_marginal = col_marginal.expand(batch_size, -1)

    log_P = log_otp_solver(row_marginal, col_marginal, S_aug, num_iters=num_iters, reg=reg)
    return log_P - norm
|
|
|
|
class SALAD(nn.Module):
    """
    Sinkhorn Algorithm for Locally Aggregated Descriptors (SALAD).

    Attributes:
        num_channels (int): The number of channels of the inputs (d).
        num_clusters (int): The number of clusters in the model (m).
        cluster_dim (int): The number of channels of the clusters (l).
        token_dim (int): The dimension of the global scene token (g).
        mlp_dim (int): Hidden width of the projection heads.
        dropout (float): The dropout rate.
    """

    def __init__(
        self,
        num_channels=1536,
        num_clusters=64,
        cluster_dim=128,
        token_dim=256,
        mlp_dim=512,
        dropout=0.3,
    ) -> None:
        super().__init__()

        self.num_channels = num_channels
        self.num_clusters = num_clusters
        self.cluster_dim = cluster_dim
        self.token_dim = token_dim
        self.mlp_dim = mlp_dim

        # One shared dropout module used by both conv heads (dropout holds no
        # state, so sharing the instance is safe).
        drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # MLP projecting the global token t_{n+1} down to token_dim.
        self.token_features = nn.Sequential(
            nn.Linear(self.num_channels, self.mlp_dim),
            nn.ReLU(),
            nn.Linear(self.mlp_dim, self.token_dim),
        )
        # 1x1-conv head producing per-location features for aggregation.
        self.cluster_features = nn.Sequential(
            nn.Conv2d(self.num_channels, self.mlp_dim, 1),
            drop,
            nn.ReLU(),
            nn.Conv2d(self.mlp_dim, self.cluster_dim, 1),
        )
        # 1x1-conv head producing per-location cluster-assignment scores.
        self.score = nn.Sequential(
            nn.Conv2d(self.num_channels, self.mlp_dim, 1),
            drop,
            nn.ReLU(),
            nn.Conv2d(self.mlp_dim, self.num_clusters, 1),
        )
        # Learnable dustbin score for locations assigned to no cluster.
        self.dust_bin = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """
        Args:
            x (tuple): (features, token) where
                features (torch.Tensor): feature map [B, C, H // 14, W // 14].
                token (torch.Tensor): global token [B, C].

        Returns:
            torch.Tensor: The global descriptor [B, m*l + g].
        """
        features, token = x

        local_feats = self.cluster_features(features).flatten(2)  # [B, l, N]
        scores = self.score(features).flatten(2)                  # [B, m, N]
        token = self.token_features(token)                        # [B, g]

        # Soft-assign each location to clusters via Sinkhorn OT, then drop
        # the dustbin row.
        assign = torch.exp(get_matching_probs(scores, self.dust_bin, 3))
        assign = assign[:, :-1, :]  # [B, m, N]

        # Assignment-weighted sum of local features per cluster:
        # [B, l, 1, N] * [B, 1, m, N] --sum over N--> [B, l, m]
        # (broadcasting here is numerically identical to the repeat-based
        # formulation: same element products, same summation order).
        aggregated = (local_feats.unsqueeze(2) * assign.unsqueeze(1)).sum(dim=-1)

        descriptor = torch.cat(
            [
                nn.functional.normalize(token, p=2, dim=-1),
                nn.functional.normalize(aggregated, p=2, dim=1).flatten(1),
            ],
            dim=-1,
        )
        return nn.functional.normalize(descriptor, p=2, dim=-1)
|
|