# colmap/feature_matcher_utilities.py
# Feature extraction and rotation-aware matching utilities (SuperPoint/SIFT + LightGlue).
import torch
import numpy as np
import cv2
from lightglue import LightGlue
from lightglue.utils import rbd
from lightglue import SuperPoint, SIFT
from lightglue.utils import load_image
def unrotate_kps_W(kps_rot, k, H, W):
    """Map keypoints detected on a 90*k-degree-rotated image back to the
    original (H, W) image frame.

    kps_rot: (N, 2) array/tensor of (x, y) keypoints in the rotated frame.
    k: per-keypoint rotation index array (0..3, quarter turns as applied by
       torch.rot90 on the source image).
    H, W: height and width of the ORIGINAL (unrotated) image.
    Returns an (N, 2) NumPy array of un-rotated (x, y) coordinates.
    """
    # Accept torch tensors transparently.
    if hasattr(kps_rot, 'cpu'):
        kps_rot = kps_rot.cpu().numpy()
    if hasattr(k, 'cpu'):
        k = k.cpu().numpy()
    # Drop singleton/batch dimensions if present.
    if k.ndim > 1:
        k = k.squeeze()
    if kps_rot.ndim > 2:
        kps_rot = kps_rot.squeeze()

    xr = kps_rot[:, 0]
    yr = kps_rot[:, 1]

    # Inverse transform for each quarter-turn count (keypoints whose k falls
    # outside 0..3 are left at the zero initialisation, as before).
    inverse = {
        0: lambda xs, ys: (xs, ys),
        1: lambda xs, ys: ((W - 1) - ys, xs),
        2: lambda xs, ys: ((W - 1) - xs, (H - 1) - ys),
        3: lambda xs, ys: (ys, (H - 1) - xs),
    }
    x = np.zeros_like(xr)
    y = np.zeros_like(yr)
    for rot, undo in inverse.items():
        sel = (k == rot)
        x[sel], y[sel] = undo(xr[sel], yr[sel])
    return np.stack([x, y], axis=-1)
def extract_keypoints(path_to_image0, features='superpoint', rotations=(0, 1, 2, 3)):
    """Extract local features from an image, optionally at several 90-degree rotations.

    path_to_image0: path to the image file (loaded via lightglue's load_image).
    features: 'sift' (single extraction, no rotations) or 'superpoint'
        (one extraction per entry in `rotations`). Any other value raises.
    rotations: quarter-turn counts passed to torch.rot90 (default all four).

    Returns:
        For 'sift': (feats, h, w) — a single feature dict and the image size.
        For 'superpoint': (feats_merged, feats, h, w) where `feats` maps each
        rotation k to its feature dict and `feats_merged` concatenates
        keypoints/scores/descriptors over rotations (dim=1) plus a 'rotations'
        tensor recording which rotation each keypoint came from.

    NOTE(review): merged keypoints stay in each ROTATED frame's coordinates;
    presumably unrotate_kps_W is used downstream to map them back — confirm.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load image as a (3, H, W) float tensor in [0, 1].
    timg = load_image(path_to_image0).to(device)
    _, h, w = timg.shape

    if features == 'sift':
        extractor = SIFT(max_num_keypoints=2048).eval().to(device)
        feats = extractor.extract(timg)
        return feats, h, w

    if features != 'superpoint':
        # Previously an unknown value crashed later with a NameError; fail fast.
        raise ValueError(f"Unsupported feature type: {features!r}")

    extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)

    # Extract features once per requested rotation.
    feats = {}
    for k in rotations:
        timg_rotated = torch.rot90(timg, k, dims=(1, 2))
        feats[k] = extractor.extract(timg_rotated)

    # Merge the per-rotation features along the keypoint dimension, tagging
    # each keypoint with the rotation it was detected under.
    all_keypoints = []
    all_scores = []
    all_descriptors = []
    all_rotations = []
    for k, feat in feats.items():
        num_kpts = feat['keypoints'].shape[1]  # keypoints have shape (1, N, 2)
        rot_indices = torch.full((1, num_kpts), k, dtype=torch.long, device=device)
        all_keypoints.append(feat['keypoints'])
        all_scores.append(feat['keypoint_scores'])
        all_descriptors.append(feat['descriptors'])
        all_rotations.append(rot_indices)

    feats_merged = {
        'keypoints': torch.cat(all_keypoints, dim=1),
        'keypoint_scores': torch.cat(all_scores, dim=1),
        'descriptors': torch.cat(all_descriptors, dim=1),
        'rotations': torch.cat(all_rotations, dim=1),
    }
    return feats_merged, feats, h, w
def lightglue_matching(feats0, feats1, matcher=None):
    """Match two feature sets with LightGlue and return the match indices.

    feats0, feats1: batched feature dicts as produced by a lightglue extractor.
    matcher: optional pre-built LightGlue matcher. When None, a SuperPoint
        LightGlue model is constructed on the fly — this is expensive, so pass
        a matcher explicitly when calling in a loop.

    Returns the 'matches' tensor: (M, 2) index pairs into the two keypoint sets.
    """
    if matcher is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        matcher = LightGlue(features='superpoint').eval().to(device)
    out_k = matcher({'image0': feats0, 'image1': feats1})
    # Only the matcher output needs its batch dim removed; the original code
    # also ran rbd over feats0/feats1 and discarded the results.
    out_k = rbd(out_k)
    return out_k['matches']
def feature_matching(feats0, feats1, matcher=None, exhaustive=True):
    """Match two multi-rotation feature sets and return indices into the
    merged (concatenated-over-rotations) keypoint layout.

    feats0, feats1: dicts mapping rotation k (0..3) to that rotation's feature
        dict, as produced by extract_keypoints (each with a (1, N_k, 2)
        'keypoints' tensor).
    matcher: optional pre-built LightGlue matcher (built on demand otherwise).
    exhaustive: when True, additionally match the three remaining relative
        rotation pairs and stack all matches together.

    Returns an (M, 2) uint32 array of (idx0, idx1) match pairs, or None when
    the best rotation yields no matches.
    """
    # Find which rotation of image 1 best aligns with the unrotated image 0.
    best_rot = 0
    best_num_matches = 0
    matches_tensor = None
    for rot in range(4):
        candidate = lightglue_matching(feats0[0], feats1[rot], matcher=matcher)
        if len(candidate) > best_num_matches:
            best_num_matches = len(candidate)
            best_rot = rot
            matches_tensor = candidate
    if matches_tensor is None or len(matches_tensor) == 0:
        return None
    matches_np = matches_tensor.cpu().numpy().astype(np.uint32)

    # Matcher indices are local to feats1[best_rot]; shift them past the
    # keypoints of all earlier rotations in the merged layout.
    for k in range(best_rot):
        matches_np[:, 1] += feats1[k]['keypoints'].shape[1]
    if not exhaustive:
        return matches_np

    all_matches = [matches_np]
    # Match the remaining relative rotations, keeping the best alignment:
    # feats0[rot_i] pairs with feats1[(best_rot + rot_i) % 4].
    for rot_i in (1, 2, 3):
        rot_j = (best_rot + rot_i) % 4
        matches_tensor_rot = lightglue_matching(feats0[rot_i], feats1[rot_j], matcher=matcher)
        matches_np_i = matches_tensor_rot.cpu().numpy().astype(np.uint32)
        # Offset both columns into the merged keypoint layout.
        for k in range(rot_i):
            matches_np_i[:, 0] += feats0[k]['keypoints'].shape[1]
        for k in range(rot_j):
            matches_np_i[:, 1] += feats1[k]['keypoints'].shape[1]
        all_matches.append(matches_np_i)

    # matches_np is guaranteed non-empty here, so a plain vstack is safe.
    return np.vstack(all_matches)