| | import math |
| | import torch |
| | import torch.nn as nn |
| |
|
| | |
| | |
def log_otp_solver(log_a, log_b, M, num_iters: int = 20, reg: float = 1.0) -> torch.Tensor:
    r"""Sinkhorn matrix scaling algorithm for Differentiable Optimal Transport problem.

    Runs log-domain Sinkhorn iterations and returns the *log* of the transport plan.
    ``M / reg`` is used as the log-kernel, so larger entries of ``M`` receive more
    mass (``M`` acts as a score / similarity rather than a cost).

    Args:
        log_a : torch.Tensor
            Log source weights, expected shape ``(B, m)``.
        log_b : torch.Tensor
            Log target weights, expected shape ``(B, n)``.
        M : torch.Tensor
            Score matrix, expected shape ``(B, m, n)``.
        num_iters : int, default=20
            The number of Sinkhorn iterations.
        reg : float, default=1.0
            Entropic regularization value; ``M`` is divided by it.

    Returns:
        torch.Tensor of shape ``(B, m, n)`` equal to
        ``M / reg + u[:, :, None] + v[:, None, :]`` (the log transport plan).
        When ``num_iters >= 1`` the last update is the column update, so ``exp`` of
        the result sums over dim 1 to ``exp(log_b)`` (up to rounding).
    """
    log_K = M / reg

    u, v = torch.zeros_like(log_a), torch.zeros_like(log_b)

    for _ in range(num_iters):
        # logsumexp over the reduced dim already yields (B, m) / (B, n). Do not
        # squeeze: that would drop a size-1 m or n axis and broadcast u/v wrongly.
        u = log_a - torch.logsumexp(log_K + v.unsqueeze(1), dim=2)
        v = log_b - torch.logsumexp(log_K + u.unsqueeze(2), dim=1)

    return log_K + u.unsqueeze(2) + v.unsqueeze(1)
| |
|
| | |
| | |
def get_matching_probs(S, dustbin_score=1.0, num_iters=3, reg=1.0):
    """Compute log matching probabilities between features and clusters, with a dustbin.

    A constant dustbin row filled with ``dustbin_score`` is appended to the score
    matrix. Then ``log_otp_solver`` is run with a mass of ``1 / (n + m)`` on every
    cluster row and every feature column. The dustbin row receives the leftover
    ``(n - m) / (n + m)``, so both marginals sum to ``n / (n + m)``.

    Args:
        S: Score tensor of shape ``(batch_size, m, n)`` (m rows, n columns).
        dustbin_score: Score (float or 0-d tensor, e.g. a learnable parameter)
            written into every entry of the dustbin row.
        num_iters: Number of Sinkhorn iterations.
        reg: Entropic regularization passed to the solver.

    Returns:
        Tensor of shape ``(batch_size, m + 1, n)``: the solver's log plan shifted
        by ``log(n + m)``. When ``num_iters >= 1``, ``exp`` of each column sums to 1.
        The last row is the dustbin.

    Raises:
        ValueError: if ``n < m`` (the dustbin would need negative mass).
    """
    batch_size, m, n = S.size()
    if n < m:
        raise ValueError(
            f"get_matching_probs needs at least as many columns as rows "
            f"(got m={m}, n={n})"
        )

    # Append the dustbin row to the scores.
    S_aug = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
    S_aug[:, :m, :n] = S
    S_aug[:, m, :] = dustbin_score

    # log(1 / (n + m)); created in torch's default dtype (no dtype is passed).
    norm = -torch.tensor(math.log(n + m), device=S.device)
    log_a, log_b = norm.expand(m + 1).contiguous(), norm.expand(n).contiguous()
    # The dustbin takes the leftover (n - m) / (n + m) of the row mass. With n == m
    # it gets none, i.e. log 0 = -inf (math.log(0) would raise instead).
    log_a[-1] = log_a[-1] + (math.log(n - m) if n > m else -math.inf)
    log_a, log_b = log_a.expand(batch_size, -1), log_b.expand(batch_size, -1)
    log_P = log_otp_solver(
        log_a,
        log_b,
        S_aug,
        num_iters=num_iters,
        reg=reg,
    )
    # Undo the 1 / (n + m) scaling.
    return log_P - norm
| |
|
| |
|
class SALAD(nn.Module):
    """
    This class represents the Sinkhorn Algorithm for Locally Aggregated Descriptors (SALAD) model.

    Attributes:
        num_channels (int): The number of channels of the inputs (d).
        num_clusters (int): The number of clusters in the model (m).
        cluster_dim (int): The number of channels of the clusters (l).
        token_dim (int): The dimension of the global scene token (g).

    Args:
        num_channels, num_clusters, cluster_dim, token_dim: Stored as the
            attributes above.
        dropout (float): Dropout rate used inside the cluster-feature and score
            heads; a value <= 0 disables dropout. Not stored as an attribute.
    """
    def __init__(self,
                 num_channels=1536,
                 num_clusters=64,
                 cluster_dim=128,
                 token_dim=256,
                 dropout=0.3,
                 ) -> None:
        super().__init__()

        self.num_channels = num_channels
        self.num_clusters = num_clusters
        self.cluster_dim = cluster_dim
        self.token_dim = token_dim

        # One module instance, reused in both 1x1-conv heads below.
        dropout_layer = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # MLP: global token num_channels -> 512 -> token_dim.
        self.token_features = nn.Sequential(
            nn.Linear(self.num_channels, 512),
            nn.ReLU(),
            nn.Linear(512, self.token_dim),
        )

        # 1x1 convs: per-location features num_channels -> 512 -> cluster_dim.
        self.cluster_features = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            dropout_layer,
            nn.ReLU(),
            nn.Conv2d(512, self.cluster_dim, 1),
        )

        # 1x1 convs: per-location scores num_channels -> 512 -> num_clusters.
        self.score = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            dropout_layer,
            nn.ReLU(),
            nn.Conv2d(512, self.num_clusters, 1),
        )

        # Learnable score for the dustbin row passed to get_matching_probs.
        self.dust_bin = nn.Parameter(torch.tensor(1.))

    def forward(self, x):
        """
        x (tuple): A tuple containing two elements, f and t.
            (torch.Tensor): The feature tensors (t_i) [B, C, H // 14, W // 14].
            (torch.Tensor): The token tensor (t_{n+1}) [B, C].

        Returns:
            f (torch.Tensor): The global descriptor [B, m*l + g]
        """
        x, t = x

        f = self.cluster_features(x).flatten(2)  # [B, l, N], N = flattened locations
        p = self.score(x).flatten(2)             # [B, m, N]
        t = self.token_features(t)               # [B, g]

        # Log assignment of each location over m clusters + dustbin (3 Sinkhorn iters).
        p = get_matching_probs(p, self.dust_bin, 3)
        p = torch.exp(p)
        # Drop the dustbin row; mass assigned to it is discarded.
        p = p[:, :-1, :]

        # agg[b, i, k] = sum_n f[b, i, n] * p[b, k, n]  -> [B, l, m].
        # Broadcasting gives the same values as .repeat()-ing both operands to
        # [B, l, m, N], without allocating those two extra copies.
        agg = (f.unsqueeze(2) * p.unsqueeze(1)).sum(dim=-1)

        f = torch.cat([
            nn.functional.normalize(t, p=2, dim=-1),
            # L2-normalize each cluster's l-dim vector, then flatten to [B, l*m].
            nn.functional.normalize(agg, p=2, dim=1).flatten(1),
        ], dim=-1)

        return nn.functional.normalize(f, p=2, dim=-1)