|
|
"""MegaLoc: One Retrieval to Place Them All |
|
|
|
|
|
This module implements the MegaLoc model for visual place recognition. |
|
|
The model combines a DINOv2 Vision Transformer backbone with an optimal transport-based
|
|
feature aggregation module. |
|
|
|
|
|
Paper: https://arxiv.org/abs/2502.17237 |
|
|
License: MIT |
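
Example (illustrative; randomly initialized weights, shape check only):
    >>> model = MegaLoc().eval()
    >>> model(torch.randn(1, 3, 322, 322)).shape
    torch.Size([1, 8448])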
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchvision.transforms.functional as tfm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_otp_solver(log_a, log_b, M, num_iters: int = 20, reg: float = 1.0) -> torch.Tensor: |
|
|
r"""Sinkhorn matrix scaling algorithm for Differentiable Optimal Transport problem. |
|
|
    Solves the entropy-regularized transport problem in log-space and returns the log of the OT matrix.
|
|
Args: |
|
|
log_a : torch.Tensor |
|
|
            Log of the source marginal weights.
|
|
log_b : torch.Tensor |
|
|
            Log of the target marginal weights.
|
|
M : torch.Tensor |
|
|
            Metric cost matrix.
|
|
        num_iters : int, default=20
|
|
The number of iterations. |
|
|
reg : float, default=1.0 |
|
|
            Regularization value.
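
    Example (illustrative; uniform marginals over a random score matrix):
        >>> B, m, n = 2, 5, 5
        >>> log_w = torch.full((B, m), -math.log(m))
        >>> log_P = log_otp_solver(log_w, log_w, torch.randn(B, m, n))
        >>> torch.allclose(log_P.exp().sum(dim=1), log_w.exp(), atol=1e-5)
        True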
|
|
""" |
|
|
    M = M / reg  # scale scores by the regularization temperature

    u, v = torch.zeros_like(log_a), torch.zeros_like(log_b)

    # Alternate row and column normalizations in log-space (Sinkhorn iterations).
    for _ in range(num_iters):
        u = log_a - torch.logsumexp(M + v.unsqueeze(1), dim=2)
        v = log_b - torch.logsumexp(M + u.unsqueeze(2), dim=1)

    return M + u.unsqueeze(2) + v.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_matching_probs(S, dustbin_score=1.0, num_iters=3, reg=1.0): |
|
|
"""sinkhorn""" |
|
|
batch_size, m, n = S.size() |
|
|
|
|
|
    # Augment the score matrix with a dustbin row that absorbs unmatched features.
    S_aug = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
|
|
S_aug[:, :m, :n] = S |
|
|
S_aug[:, m, :] = dustbin_score |
|
|
|
|
|
|
|
|
    # Uniform marginals in log-space; the dustbin row receives the excess mass.
    norm = -torch.tensor(math.log(n + m), device=S.device)
|
|
log_a, log_b = norm.expand(m + 1).contiguous(), norm.expand(n).contiguous() |
|
|
    log_a[-1] = log_a[-1] + math.log(n - m)  # extra dustbin mass; assumes n > m
|
|
log_a, log_b = log_a.expand(batch_size, -1), log_b.expand(batch_size, -1) |
|
|
log_P = log_otp_solver(log_a, log_b, S_aug, num_iters=num_iters, reg=reg) |
|
|
return log_P - norm |
|
|
|
|
|
|
|
|
class FeatureAggregator(nn.Module): |
|
|
"""Optimal transport-based aggregation of local features into global descriptor. |
|
|
|
|
|
This module aggregates local patch features into a compact global representation |
|
|
using differentiable optimal transport. |
|
|
|
|
|
Args: |
|
|
num_channels: Number of input feature channels (from backbone) |
|
|
num_clusters: Number of cluster centers |
|
|
cluster_dim: Dimensionality of cluster descriptors |
|
|
token_dim: Dimensionality of global scene token |
|
|
mlp_dim: Hidden dimension for MLPs |
|
|
dropout: Dropout probability (0 to disable) |
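
    Example (illustrative; shapes assume a 768-channel ViT-B backbone):
        >>> agg = FeatureAggregator(num_channels=768, num_clusters=64, cluster_dim=128, token_dim=256)
        >>> feats = torch.randn(2, 768, 37, 37)  # [B, C, H, W] patch features
        >>> token = torch.randn(2, 768)  # [B, C] CLS token
        >>> agg((feats, token)).shape
        torch.Size([2, 8448])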
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
num_channels=1536, |
|
|
num_clusters=64, |
|
|
cluster_dim=128, |
|
|
token_dim=256, |
|
|
mlp_dim=512, |
|
|
dropout=0.3, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
|
|
|
self.num_channels = num_channels |
|
|
self.num_clusters = num_clusters |
|
|
self.cluster_dim = cluster_dim |
|
|
self.token_dim = token_dim |
|
|
self.mlp_dim = mlp_dim |
|
|
|
|
|
        dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
|
|
|
|
|
|
|
self.token_features = nn.Sequential( |
|
|
nn.Linear(self.num_channels, self.mlp_dim), nn.ReLU(), nn.Linear(self.mlp_dim, self.token_dim) |
|
|
) |
|
|
|
|
|
self.cluster_features = nn.Sequential( |
|
|
nn.Conv2d(self.num_channels, self.mlp_dim, 1), |
|
|
dropout, |
|
|
nn.ReLU(), |
|
|
nn.Conv2d(self.mlp_dim, self.cluster_dim, 1), |
|
|
) |
|
|
|
|
|
self.score = nn.Sequential( |
|
|
nn.Conv2d(self.num_channels, self.mlp_dim, 1), |
|
|
dropout, |
|
|
nn.ReLU(), |
|
|
nn.Conv2d(self.mlp_dim, self.num_clusters, 1), |
|
|
) |
|
|
|
|
|
self.dust_bin = nn.Parameter(torch.tensor(1.0)) |
|
|
|
|
|
def forward(self, x): |
|
|
""" |
|
|
Args: |
|
|
x: Tuple of (features, token) |
|
|
features: [B, C, H, W] spatial feature map |
|
|
token: [B, C] global CLS token |
|
|
|
|
|
Returns: |
|
|
Global descriptor [B, num_clusters * cluster_dim + token_dim] |
|
|
""" |
|
|
        x, t = x  # spatial features [B, C, H, W] and global CLS token [B, C]

        f = self.cluster_features(x).flatten(2)  # local descriptors [B, cluster_dim, H*W]
        p = self.score(x).flatten(2)  # cluster-assignment scores [B, num_clusters, H*W]
        t = self.token_features(t)  # global scene token [B, token_dim]

        # Optimal-transport soft assignment of features to clusters (plus a dustbin row).
        p = get_matching_probs(p, self.dust_bin, 3)
        p = torch.exp(p)
        p = p[:, :-1, :]  # drop the dustbin row
|
|
|
|
|
        # Broadcast so each local descriptor is weighted by its per-cluster probability.
        p = p.unsqueeze(1).repeat(1, self.cluster_dim, 1, 1)
        f = f.unsqueeze(2).repeat(1, 1, self.num_clusters, 1)
|
|
|
|
|
        # Concatenate the normalized global token with the normalized cluster aggregates.
        f = torch.cat(
|
|
[ |
|
|
F.normalize(t, p=2, dim=-1), |
|
|
F.normalize((f * p).sum(dim=-1), p=2, dim=1).flatten(1), |
|
|
], |
|
|
dim=-1, |
|
|
) |
|
|
|
|
|
return F.normalize(f, p=2, dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PatchEmbedding(nn.Module): |
|
|
"""Convert image patches to embeddings using a convolutional layer.""" |
|
|
|
|
|
def __init__(self, image_size: int = 518, patch_size: int = 14, in_channels: int = 3, embed_dim: int = 768): |
|
|
super().__init__() |
|
|
self.image_size = image_size |
|
|
self.patch_size = patch_size |
|
|
self.num_patches = (image_size // patch_size) ** 2 |
|
|
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) |
|
|
|
|
|
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)  # [B, embed_dim, H/patch, W/patch]
        x = x.flatten(2)  # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)  # [B, num_patches, embed_dim]
        return x
|
|
|
|
|
|
|
|
class LayerScale(nn.Module): |
|
|
"""Learnable per-channel scaling as used in CaiT and DINOv2.""" |
|
|
|
|
|
def __init__(self, dim: int, init_value: float = 1e-5): |
|
|
super().__init__() |
|
|
self.gamma = nn.Parameter(init_value * torch.ones(dim)) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
return x * self.gamma |
|
|
|
|
|
|
|
|
class MultiHeadAttention(nn.Module): |
|
|
"""Multi-head self-attention module.""" |
|
|
|
|
|
def __init__( |
|
|
self, dim: int, num_heads: int = 12, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0 |
|
|
): |
|
|
super().__init__() |
|
|
self.num_heads = num_heads |
|
|
self.head_dim = dim // num_heads |
|
|
self.scale = self.head_dim**-0.5 |
|
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|
|
self.attn_drop = nn.Dropout(attn_drop) |
|
|
self.proj = nn.Linear(dim, dim) |
|
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
B, N, C = x.shape |
|
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) |
|
|
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, num_heads, N, head_dim]
|
|
q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
attn = attn.softmax(dim=-1) |
|
|
attn = self.attn_drop(attn) |
|
|
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
|
|
x = self.proj(x) |
|
|
x = self.proj_drop(x) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class MLP(nn.Module): |
|
|
"""MLP module with GELU activation.""" |
|
|
|
|
|
def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None, drop: float = 0.0): |
|
|
super().__init__() |
|
|
out_features = out_features or in_features |
|
|
hidden_features = hidden_features or in_features |
|
|
|
|
|
self.fc1 = nn.Linear(in_features, hidden_features) |
|
|
self.act = nn.GELU() |
|
|
self.fc2 = nn.Linear(hidden_features, out_features) |
|
|
self.drop = nn.Dropout(drop) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
x = self.fc1(x) |
|
|
x = self.act(x) |
|
|
x = self.drop(x) |
|
|
x = self.fc2(x) |
|
|
x = self.drop(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module): |
|
|
"""Vision Transformer block with LayerScale.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_heads: int, |
|
|
mlp_ratio: float = 4.0, |
|
|
qkv_bias: bool = True, |
|
|
drop: float = 0.0, |
|
|
attn_drop: float = 0.0, |
|
|
init_values: float = 1e-5, |
|
|
): |
|
|
super().__init__() |
|
|
self.norm1 = nn.LayerNorm(dim, eps=1e-6) |
|
|
self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) |
|
|
self.ls1 = LayerScale(dim, init_value=init_values) |
|
|
|
|
|
self.norm2 = nn.LayerNorm(dim, eps=1e-6) |
|
|
self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop) |
|
|
self.ls2 = LayerScale(dim, init_value=init_values) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
x = x + self.ls1(self.attn(self.norm1(x))) |
|
|
x = x + self.ls2(self.mlp(self.norm2(x))) |
|
|
return x |
|
|
|
|
|
|
|
|
class DINOv2(nn.Module): |
|
|
"""DINOv2 Vision Transformer backbone for feature extraction. |
|
|
|
|
|
This implements a ViT-B/14 architecture compatible with DINOv2 weights. |
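
    Example (illustrative; a 518x518 input yields a 37x37 patch grid):
        >>> backbone = DINOv2()
        >>> feats, cls_tok = backbone(torch.randn(1, 3, 518, 518))
        >>> feats.shape, cls_tok.shape
        (torch.Size([1, 768, 37, 37]), torch.Size([1, 768]))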
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
image_size: int = 518, |
|
|
patch_size: int = 14, |
|
|
in_channels: int = 3, |
|
|
embed_dim: int = 768, |
|
|
depth: int = 12, |
|
|
num_heads: int = 12, |
|
|
mlp_ratio: float = 4.0, |
|
|
qkv_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
self.patch_size = patch_size |
|
|
self.embed_dim = embed_dim |
|
|
self.num_channels = embed_dim |
|
|
|
|
|
self.patch_embed = PatchEmbedding( |
|
|
image_size=image_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim |
|
|
) |
|
|
|
|
|
self.interpolate_offset = 0.1 |
|
|
self.interpolate_antialias = False |
|
|
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
|
|
num_patches = (image_size // patch_size) ** 2 |
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) |
|
|
|
|
|
self.blocks = nn.ModuleList( |
|
|
[ |
|
|
TransformerBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias) |
|
|
for _ in range(depth) |
|
|
] |
|
|
) |
|
|
|
|
|
self.norm = nn.LayerNorm(embed_dim, eps=1e-6) |
|
|
|
|
|
def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: |
|
|
"""Interpolate positional encoding for different input sizes.""" |
|
|
previous_dtype = x.dtype |
|
|
npatch = x.shape[1] - 1 |
|
|
N = self.pos_embed.shape[1] - 1 |
|
|
|
|
|
        # No interpolation needed at the pretraining resolution.
        if npatch == N and w == h:
|
|
return self.pos_embed |
|
|
|
|
|
pos_embed = self.pos_embed.float() |
|
|
class_pos_embed = pos_embed[:, 0] |
|
|
patch_pos_embed = pos_embed[:, 1:] |
|
|
|
|
|
dim = x.shape[-1] |
|
|
w0 = w // self.patch_size |
|
|
h0 = h // self.patch_size |
|
|
        M = int(math.sqrt(N))  # side length of the square pretraining patch grid
|
|
|
|
|
        # The small offset (DINOv2 convention) avoids rounding the interpolated grid down.
        sx = float(w0 + self.interpolate_offset) / M
|
|
sy = float(h0 + self.interpolate_offset) / M |
|
|
|
|
|
patch_pos_embed = F.interpolate( |
|
|
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), |
|
|
scale_factor=(sx, sy), |
|
|
mode="bicubic", |
|
|
antialias=self.interpolate_antialias, |
|
|
) |
|
|
|
|
|
assert (w0, h0) == patch_pos_embed.shape[-2:] |
|
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) |
|
|
|
|
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) |
|
|
|
|
|
def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Extract features from images. |
|
|
|
|
|
Args: |
|
|
images: Input images [B, 3, H, W] where H, W are multiples of 14 |
|
|
|
|
|
Returns: |
|
|
Tuple of (patch_features [B, 768, H//14, W//14], cls_token [B, 768]) |
|
|
""" |
|
|
B, _, H, W = images.shape |
|
|
|
|
|
x = self.patch_embed(images) |
|
|
cls_tokens = self.cls_token.expand(B, -1, -1) |
|
|
x = torch.cat((cls_tokens, x), dim=1) |
|
|
x = x + self.interpolate_pos_encoding(x, H, W) |
|
|
|
|
|
for block in self.blocks: |
|
|
x = block(x) |
|
|
|
|
|
x = self.norm(x) |
|
|
|
|
|
cls_token = x[:, 0] |
|
|
patch_tokens = x[:, 1:] |
|
|
patch_features = patch_tokens.reshape(B, H // self.patch_size, W // self.patch_size, self.embed_dim).permute( |
|
|
0, 3, 1, 2 |
|
|
) |
|
|
|
|
|
return patch_features, cls_token |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class L2Norm(nn.Module):
    """L2-normalize the input along a given dimension."""

    def __init__(self, dim=1):
|
|
super().__init__() |
|
|
self.dim = dim |
|
|
|
|
|
def forward(self, x): |
|
|
return F.normalize(x, p=2.0, dim=self.dim) |
|
|
|
|
|
|
|
|
class Aggregator(nn.Module):
    """Wrap FeatureAggregator with a linear projection to the final descriptor dimension."""

    def __init__(self, feat_dim, agg_config, salad_out_dim):
|
|
super().__init__() |
|
|
self.agg = FeatureAggregator(**agg_config) |
|
|
self.linear = nn.Linear(salad_out_dim, feat_dim) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.agg(x) |
|
|
return self.linear(x) |
|
|
|
|
|
|
|
|
class MegaLoc(nn.Module): |
|
|
"""MegaLoc: Unified visual place recognition model. |
|
|
|
|
|
Combines a DINOv2 Vision Transformer backbone with optimal transport-based |
|
|
feature aggregation to produce compact, discriminative image descriptors |
|
|
for place recognition and image retrieval tasks. |
|
|
|
|
|
Args: |
|
|
feat_dim: Output descriptor dimensionality (default: 8448) |
|
|
num_clusters: Number of cluster centers for aggregation (default: 64) |
|
|
cluster_dim: Dimensionality of cluster descriptors (default: 256) |
|
|
token_dim: Dimensionality of global scene token (default: 256) |
|
|
mlp_dim: Hidden dimension for MLPs (default: 512) |
|
|
|
|
|
Example: |
|
|
>>> model = torch.hub.load("gmberton/MegaLoc", "get_trained_model") |
|
|
>>> model.eval() |
|
|
>>> descriptor = model(image) # [B, 8448] |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
feat_dim: int = 8448, |
|
|
num_clusters: int = 64, |
|
|
cluster_dim: int = 256, |
|
|
token_dim: int = 256, |
|
|
mlp_dim: int = 512, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.backbone = DINOv2() |
|
|
self.salad_out_dim = num_clusters * cluster_dim + token_dim |
|
|
self.aggregator = Aggregator( |
|
|
feat_dim=feat_dim, |
|
|
agg_config={ |
|
|
"num_channels": self.backbone.num_channels, |
|
|
"num_clusters": num_clusters, |
|
|
"cluster_dim": cluster_dim, |
|
|
"token_dim": token_dim, |
|
|
"mlp_dim": mlp_dim, |
|
|
}, |
|
|
salad_out_dim=self.salad_out_dim, |
|
|
) |
|
|
self.feat_dim = feat_dim |
|
|
self.l2norm = L2Norm() |
|
|
|
|
|
def forward(self, images: torch.Tensor) -> torch.Tensor: |
|
|
"""Extract global descriptor from images. |
|
|
|
|
|
Args: |
|
|
images: Input images [B, 3, H, W] |
|
|
|
|
|
Returns: |
|
|
L2-normalized descriptors [B, feat_dim] |
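
        Example (illustrative; randomly initialized weights, shape check only):
            >>> model = MegaLoc().eval()
            >>> model(torch.randn(1, 3, 224, 224)).shape
            torch.Size([1, 8448])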
|
|
""" |
|
|
        b, c, h, w = images.shape
        # The DINOv2 backbone requires H and W to be multiples of the 14-pixel patch size.
        if h % 14 != 0 or w % 14 != 0:
            h = round(h / 14) * 14
            w = round(w / 14) * 14
            images = tfm.resize(images, [h, w], antialias=True)
|
|
features = self.aggregator(self.backbone(images)) |
|
|
features = self.l2norm(features) |
|
|
return features |
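

if __name__ == "__main__":
    # Minimal smoke test (illustrative): randomly initialized weights, not the
    # trained model. Trained weights are available via torch.hub, e.g.
    # torch.hub.load("gmberton/MegaLoc", "get_trained_model").
    model = MegaLoc().eval()
    with torch.no_grad():
        descriptors = model(torch.randn(2, 3, 322, 322))
    print(descriptors.shape)  # torch.Size([2, 8448])
    print(descriptors.norm(dim=-1))  # ~1.0 everywhere: descriptors are L2-normalized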
|
|
|