|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import os |
|
|
import torch |
|
|
from pathlib import Path |
|
|
import math |
|
|
import numpy as np |
|
|
|
|
|
from torch import nn |
|
|
from PIL import Image |
|
|
from torchvision.transforms import ToTensor |
|
|
from romatch.utils.kde import kde |
|
|
|
|
|
class BasicLayer(nn.Module):
    """
    Conv2d followed by a non-affine BatchNorm2d and an optional in-place ReLU.

    The three stages live in ``self.layer`` (an ``nn.Sequential``) at indices
    0, 1 and 2, so parameter / buffer names are ``layer.0.*`` and ``layer.1.*``.
    When ``relu`` is False the last stage is an ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        norm = nn.BatchNorm2d(out_channels, affine=False)
        if relu:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()
        # The Sequential layout (and its 0/1/2 indices) is part of the
        # state_dict contract, so keep the stages wrapped exactly like this.
        self.layer = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply conv -> batch-norm -> activation to ``x``."""
        out = self.layer(x)
        return out
|
|
|
|
|
class TinyRoMa(nn.Module):
    """
    Lightweight dense matcher built on top of an XFeat backbone.

    Implementation of architecture described in
    "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."

    The backbone yields a fine feature map (``x2``) and a fused coarse feature
    map (``block_fusion``) for each image. A global correlation between the two
    coarse maps gives an initial warp. ``coarse_matcher`` refines that warp,
    then ``fine_matcher`` refines it again at the fine resolution.
    """

    def __init__(self, xfeat = None,
                 freeze_xfeat = True,
                 sample_mode = "threshold_balanced",
                 symmetric = False,
                 exact_softmax = False):
        """
        Args:
            xfeat: XFeat network instance. Required despite the ``None``
                default. Its ``heatmap_head``, ``keypoint_head`` and
                ``fine_matcher`` attributes are deleted in place.
            freeze_xfeat: if True, the backbone is put in eval mode, run under
                ``torch.inference_mode`` and kept out of this module's
                registered parameters.
            sample_mode: strategy string for :meth:`sample`. The substrings
                ``"threshold"`` and ``"balanced"`` enable the two stages.
            symmetric: stored on the instance. Not read by the methods of
                this class (``visualize_warp`` takes its own flag).
            exact_softmax: if True, :meth:`pos_embed` always uses the full
                softmax instead of the low-resolution approximation in eval.

        Raises:
            ValueError: if ``xfeat`` is None.
        """
        super().__init__()
        if xfeat is None:
            raise ValueError("TinyRoMa requires an XFeat backbone passed as `xfeat`.")
        # XFeat's keypoint/heatmap heads and its own fine matcher are unused here.
        del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
        if freeze_xfeat:
            xfeat.train(False)
            # A plain list hides the backbone from nn.Module registration. Its
            # weights are excluded from parameters()/state_dict(), and
            # .to()/.cuda() on TinyRoMa will NOT move it; move it separately.
            self.xfeat = [xfeat]
        else:
            self.xfeat = nn.ModuleList([xfeat])
        self.freeze_xfeat = freeze_xfeat
        match_dim = 256
        # Input = cat(coarse feats A, warped coarse feats B, 2-ch warp); see forward().
        self.coarse_matcher = nn.Sequential(
            BasicLayer(64 + 64 + 2, match_dim),
            BasicLayer(match_dim, match_dim),
            BasicLayer(match_dim, match_dim),
            BasicLayer(match_dim, match_dim),
            nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
        fine_match_dim = 64
        # Input = cat(fine feats A, warped fine feats B, 2-ch upsampled warp).
        self.fine_matcher = nn.Sequential(
            BasicLayer(24 + 24 + 2, fine_match_dim),
            BasicLayer(fine_match_dim, fine_match_dim),
            BasicLayer(fine_match_dim, fine_match_dim),
            BasicLayer(fine_match_dim, fine_match_dim),
            nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0))
        self.sample_mode = sample_mode
        self.sample_thresh = 0.05
        self.symmetric = symmetric
        self.exact_softmax = exact_softmax

    @property
    def device(self):
        """Device of the matcher heads (not necessarily of a frozen backbone)."""
        return self.fine_matcher[-1].weight.device

    def preprocess_tensor(self, x):
        """Resize ``x`` so that H and W are divisible by 32, to avoid aliasing artifacts.

        Returns:
            ``(x_resized, rh, rw)``, where ``rh = H / H'`` and ``rw = W / W'``
            are the original-to-resized size ratios.

        Raises:
            ValueError: if either spatial side is smaller than 32 pixels.
        """
        H, W = x.shape[-2:]
        _H, _W = (H // 32) * 32, (W // 32) * 32
        if _H == 0 or _W == 0:
            raise ValueError(f"Images must be at least 32x32 pixels, got {H}x{W}.")
        rh, rw = H / _H, W / _W
        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
        return x, rh, rw

    def forward_single(self, x):
        """Run the XFeat backbone on ``x`` and return ``(fine_feats, coarse_feats)``.

        The input is converted to a single channel by averaging over dim 1 and
        then normalized with ``xfeat.norm``. ``fine_feats`` is ``block2``'s
        output. ``coarse_feats`` fuses blocks 3-5 after resizing them to
        block3's resolution.
        """
        # The backbone runs without autograd when frozen or when evaluating.
        use_inference = self.freeze_xfeat or not self.training
        xfeat = self.xfeat[0]
        with torch.inference_mode(use_inference):
            with torch.no_grad():
                x = x.mean(dim=1, keepdim=True)
                x = xfeat.norm(x)
            x1 = xfeat.block1(x)
            x2 = xfeat.block2(x1 + xfeat.skip1(x))
            x3 = xfeat.block3(x2)
            x4 = xfeat.block4(x3)
            x5 = xfeat.block5(x4)
            x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            feats = xfeat.block_fusion(x3 + x4 + x5)
        if use_inference:
            # Inference tensors cannot be saved for backward by the matcher
            # heads. Cloning outside the inference context turns them into
            # normal tensors (inside an outer inference_mode this stays cheap).
            return x2.clone(), feats.clone()
        return x2, feats

    def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
        """Map normalized [-1, 1] coordinates to pixel coordinates.

        Args:
            coords: a tensor with last dim 2 (image A only), a tensor with
                last dim 4 (A then B), or a ``(kpts_A, kpts_B)`` list/tuple.
            H_A, W_A: size of image A. H_B, W_B: size of image B.

        Returns:
            A single tensor for 2-column input; otherwise ``(kpts_A, kpts_B)``.
        """
        # The container check must come first: lists/tuples have no .shape.
        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        elif coords.shape[-1] == 2:
            return self._to_pixel_coordinates(coords, H_A, W_A)
        else:
            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
        return (self._to_pixel_coordinates(kpts_A, H_A, W_A),
                self._to_pixel_coordinates(kpts_B, H_B, W_B))

    def _to_pixel_coordinates(self, coords, H, W):
        """Map ``(x, y)`` in [-1, 1] to ``(W/2 * (x + 1), H/2 * (y + 1))``."""
        kpts = torch.stack((W / 2 * (coords[..., 0] + 1), H / 2 * (coords[..., 1] + 1)), dim=-1)
        return kpts

    def pos_embed(self, corr_volume: torch.Tensor):
        """Turn a correlation volume into an expected warp (soft-argmax).

        Args:
            corr_volume: ``(B, H1, W1, H0, W0)`` scores of each image-B cell
                against each image-A cell (see :meth:`corr_volume`).

        Returns:
            ``(B, 2, H0, W0)`` expected ``(x, y)`` positions in image B, in
            normalized [-1, 1] cell-centre coordinates.

        In training mode, or when ``exact_softmax`` is set, the full softmax
        over all H1*W1 cells is used. Otherwise the softmax runs only over
        every ``down``-th cell plus the single best-scoring cell.
        NOTE(review): this assumes H1 and W1 are multiples of ``down``.
        """
        B, H1, W1, H0, W0 = corr_volume.shape
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1 + 1 / W1, 1 - 1 / W1, W1),
                torch.linspace(-1 + 1 / H1, 1 - 1 / H1, H1),
                indexing="xy"),
            dim=-1).float().to(corr_volume).reshape(H1 * W1, 2)
        down = 4
        if not self.training and not self.exact_softmax:
            grid_lr = torch.stack(
                torch.meshgrid(
                    torch.linspace(-1 + down / W1, 1 - down / W1, W1 // down),
                    torch.linspace(-1 + down / H1, 1 - down / H1, H1 // down),
                    indexing="xy"),
                dim=-1).float().to(corr_volume).reshape(H1 * W1 // down**2, 2)
            best = corr_volume.reshape(B, H1 * W1, H0, W0).max(dim=1)
            lowres_logits = corr_volume[:, ::down, ::down].reshape(B, H1 * W1 // down**2, H0, W0)
            # The extra logit is the best cell's *score* (not its index), so it
            # competes fairly with the strided cells in the softmax.
            P_lowres = torch.cat((lowres_logits, best.values[:, None]), dim=1).softmax(dim=1)
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:, :-1], grid_lr)
            # Keep a singleton channel axis so the weight broadcasts over (x, y)
            # rather than being aligned with the batch axis.
            pos_embeddings = pos_embeddings + P_lowres[:, -1:] * grid[best.indices].permute(0, 3, 1, 2)
        else:
            P = corr_volume.reshape(B, H1 * W1, H0, W0).softmax(dim=1)
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
        return pos_embeddings

    def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
                       im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
        """Render image B warped into image A's frame, faded to white by certainty.

        ``warp`` is ``(H, W2, 4)``. The first W columns hold the A->B warp; when
        ``symmetric`` is set, W = W2 // 2 and the remaining columns hold B->A.
        Images come from ``im_A``/``im_B`` (PIL images or CHW tensors), or are
        loaded from ``im_A_path``/``im_B_path`` when ``im_A`` is None.
        If ``save_path`` is given, the result is also saved as an image.
        """
        device = warp.device
        H, W2, _ = warp.shape
        W = W2 // 2 if symmetric else W2
        if im_A is None:
            # Close the file handles once the RGB copies have been made.
            with Image.open(im_A_path) as img_A, Image.open(im_B_path) as img_B:
                im_A, im_B = img_A.convert("RGB"), img_B.convert("RGB")
        if not isinstance(im_A, torch.Tensor):
            im_A = im_A.resize((W, H))
            im_B = im_B.resize((W, H))
            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
            if symmetric:
                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
        else:
            if symmetric:
                x_A = im_A
            x_B = im_B
        im_A_transfer_rgb = F.grid_sample(
            x_B[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
        )[0]
        if symmetric:
            im_B_transfer_rgb = F.grid_sample(
                x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
            )[0]
            warp_im = torch.cat((im_A_transfer_rgb, im_B_transfer_rgb), dim=2)
            white_im = torch.ones((H, 2 * W), device=device)
        else:
            warp_im = im_A_transfer_rgb
            white_im = torch.ones((H, W), device=device)
        vis_im = certainty * warp_im + (1 - certainty) * white_im
        if save_path is not None:
            from romatch.utils import tensor_to_pil
            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
        return vis_im

    def corr_volume(self, feat0, feat1):
        """
        Scaled dot-product correlation between every pair of feature cells.

        input:
            feat0 -> torch.Tensor(B, C, H0, W0)
            feat1 -> torch.Tensor(B, C, H1, W1)
        return:
            corr_volume -> torch.Tensor(B, H1, W1, H0, W0), divided by sqrt(C)
        """
        B, C, H0, W0 = feat0.shape
        B, C, H1, W1 = feat1.shape
        # reshape (not view) also accepts non-contiguous inputs.
        feat0 = feat0.reshape(B, C, H0 * W0)
        feat1 = feat1.reshape(B, C, H1 * W1)
        corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0, W0) / math.sqrt(C)
        return corr_volume

    @torch.inference_mode()
    def match_from_path(self, im0_path, im1_path):
        """Load two images from disk and return the unbatched ``(warp, cert)``."""
        device = self.device
        # ToTensor reads the pixels inside the context, so the files can be closed.
        with Image.open(im0_path) as img0, Image.open(im1_path) as img1:
            im0 = ToTensor()(img0)[None].to(device)
            im1 = ToTensor()(img1)[None].to(device)
        return self.match(im0, im1, batched=False)

    @torch.inference_mode()
    def match(self, im0, im1, *args, batched = True):
        """Dense matching at image-A resolution.

        Args:
            im0, im1: file paths, PIL images, or ``(B, C, H, W)`` tensors.
                Paths and PIL images imply ``batched=False``.
            *args: accepted for call compatibility; ignored.
            batched: if False, the leading batch dimension is dropped.

        Returns:
            ``warp`` ``(B, H0, W0, 4)``: the normalized grid of image A
            concatenated with the matched normalized position in image B.
            ``cert`` ``(B, H0, W0)``: sigmoid of the fine certainty logits.

        Side effect: leaves the module in eval mode (``self.train(False)``).
        """
        if isinstance(im0, (str, Path)):
            return self.match_from_path(im0, im1)
        elif isinstance(im0, Image.Image):
            batched = False
            device = self.device
            im0 = ToTensor()(im0)[None].to(device)
            im1 = ToTensor()(im1)[None].to(device)

        B, _, H0, W0 = im0.shape
        self.train(False)
        corresps = self.forward({"im_A": im0, "im_B": im1})

        flow = F.interpolate(
            corresps[4]["flow"],
            size=(H0, W0),
            mode="bilinear", align_corners=False).permute(0, 2, 3, 1).reshape(B, H0, W0, 2)
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1 + 1 / W0, 1 - 1 / W0, W0),
                torch.linspace(-1 + 1 / H0, 1 - 1 / H0, H0),
                indexing="xy"),
            dim=-1).float().to(flow.device).expand(B, H0, W0, 2)

        certainty = F.interpolate(corresps[4]["certainty"], size=(H0, W0), mode="bilinear", align_corners=False)
        warp, cert = torch.cat((grid, flow), dim=-1), certainty[:, 0].sigmoid()
        if batched:
            return warp, cert
        return warp[0], cert[0]

    def sample(
        self,
        matches,
        certainty,
        num=5_000,
    ):
        """Draw up to ``num`` matches from a dense warp.

        ``matches`` must be 3-D ``(H, W, 4)``; ``certainty`` is flattened to
        one weight per match (presumably ``(H, W)``).
        - ``"threshold"`` in ``sample_mode``: certainties above
          ``self.sample_thresh`` are raised to 1 before sampling.
        - First stage: multinomial sampling without replacement, proportional
          to certainty. In ``"balanced"`` mode it draws 4x ``num`` candidates.
        - ``"balanced"``: re-sample the candidates with weight 1 / (1 + KDE
          density). Candidates whose density is below 10 get weight 1e-7.

        Returns:
            ``(sampled_matches, sampled_certainty)``.
        """
        H, W, _ = matches.shape  # also validates that matches is 3-D
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1
        matches, certainty = (
            matches.reshape(-1, 4),
            certainty.reshape(-1),
        )
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(certainty,
                                         num_samples=min(expansion_factor * num, len(certainty)),
                                         replacement=False)
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty
        use_half = True if matches.device.type == "cuda" else False
        # A coarser KDE keeps the CPU path affordable.
        down = 1 if matches.device.type == "cuda" else 8
        density = kde(good_matches, std=0.1, half=use_half, down=down)
        p = 1 / (density + 1)
        p[density < 10] = 1e-7  # nearly suppress isolated candidates
        balanced_samples = torch.multinomial(p,
                                             num_samples=min(num, len(good_certainty)),
                                             replacement=False)
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch):
        """
        Dense matching of ``batch["im_A"]`` against ``batch["im_B"]``.

        input:
            batch["im_A"], batch["im_B"] -> torch.Tensor(B, C, H, W) grayscale or rgb images
        return:
            dict with keys 8 (coarse) and 4 (fine); presumably the key is the
            feature stride. Each value holds
            "flow"      -> (B, 2, h, w) positions in image B, normalized [-1, 1]
            "certainty" -> (B, 1, h, w) logits (``match`` applies a sigmoid)
        """
        im0 = batch["im_A"]
        im1 = batch["im_B"]
        corresps = {}
        im0, _, _ = self.preprocess_tensor(im0)
        im1, _, _ = self.preprocess_tensor(im1)
        B, C, H0, W0 = im0.shape
        B, C, H1, W1 = im1.shape
        # Scales the predicted (dx, dy, certainty) deltas: the offsets are
        # presumably in image-B pixels, and 2/W, 2/H converts them to [-1, 1] units.
        to_normalized = torch.tensor((2 / W1, 2 / H1, 1), device=im0.device)[None, :, None, None]

        if im0.shape[-2:] == im1.shape[-2:]:
            # Same size: a single backbone pass over the concatenated batch.
            x = torch.cat([im0, im1], dim=0)
            x = self.forward_single(x)
            feats_x0_c, feats_x1_c = x[1].chunk(2)
            feats_x0_f, feats_x1_f = x[0].chunk(2)
        else:
            feats_x0_f, feats_x0_c = self.forward_single(im0)
            feats_x1_f, feats_x1_c = self.forward_single(im1)

        # Coarse level: global correlation -> soft warp -> residual refinement.
        corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
        coarse_warp = self.pos_embed(corr_volume)
        coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:, -1:])), dim=1)
        feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[..., :2], mode='bilinear', align_corners=False)
        coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
        coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
        corresps[8] = {"flow": coarse_matches[:, :2], "certainty": coarse_matches[:, 2:]}

        # Fine level: upsample the coarse estimate (detached, so the fine loss
        # does not backprop into the coarse head) and refine it again.
        coarse_matches_up = F.interpolate(coarse_matches, size=feats_x0_f.shape[-2:], mode="bilinear", align_corners=False)
        coarse_matches_up_detach = coarse_matches_up.detach()
        feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[..., :2], mode='bilinear', align_corners=False)
        fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:, :2]), dim=1))
        fine_matches = coarse_matches_up_detach + fine_matches_delta * to_normalized
        corresps[4] = {"flow": fine_matches[:, :2], "certainty": fine_matches[:, 2:]}
        return corresps