File size: 6,637 Bytes
e8ace62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import torch
import numpy as np
import cv2
from lightglue import LightGlue
from lightglue.utils import rbd
from lightglue import SuperPoint, SIFT
from lightglue.utils import load_image
def unrotate_kps_W(kps_rot, k, H, W):
    """Map keypoints detected on rotated images back to the unrotated image frame.

    The formulas invert ``torch.rot90(img, k, dims=(1, 2))`` applied to an
    image of height ``H`` and width ``W``.

    Args:
        kps_rot: Keypoint ``(x, y)`` coordinates in the rotated frame, as a
            NumPy array or an object exposing ``.cpu()`` (e.g. a torch
            tensor). Any shape that flattens to ``(N, 2)`` is accepted,
            including a leading batch dimension of size 1.
        k: Number of quarter turns applied to the image each keypoint came
            from. Either one value per keypoint (``(N,)`` or ``(1, N)``) or a
            single value that is applied to every keypoint.
        H: Height of the original (unrotated) image.
        W: Width of the original (unrotated) image.

    Returns:
        ``np.ndarray`` of shape ``(N, 2)`` with ``(x, y)`` in the original
        frame. Keypoints whose ``k`` is not one of 0, 1, 2, 3 are left at
        ``(0, 0)``.
    """
    # Ensure inputs are NumPy.
    if hasattr(kps_rot, 'cpu'):
        kps_rot = kps_rot.detach().cpu().numpy()
    if hasattr(k, 'cpu'):
        k = k.detach().cpu().numpy()

    # Flatten explicitly instead of squeeze(): squeeze() also removes the
    # keypoint axis when there is exactly one keypoint, which breaks the
    # column indexing below.
    kps_rot = np.asarray(kps_rot).reshape(-1, 2)
    k = np.asarray(k).reshape(-1)
    if k.size == 1:
        # A single rotation value applies to every keypoint.
        k = np.broadcast_to(k, (kps_rot.shape[0],))

    x_r = kps_rot[:, 0]
    y_r = kps_rot[:, 1]
    x = np.zeros_like(x_r)
    y = np.zeros_like(y_r)

    mask0 = (k == 0)
    x[mask0], y[mask0] = x_r[mask0], y_r[mask0]

    mask1 = (k == 1)
    x[mask1], y[mask1] = (W - 1) - y_r[mask1], x_r[mask1]

    mask2 = (k == 2)
    x[mask2], y[mask2] = (W - 1) - x_r[mask2], (H - 1) - y_r[mask2]

    mask3 = (k == 3)
    x[mask3], y[mask3] = y_r[mask3], (H - 1) - x_r[mask3]

    return np.stack([x, y], axis=-1)
def extract_keypoints(path_to_image0, features='superpoint', rotations=(0, 1, 2, 3)):
    """Extract local features from an image, optionally under 90-degree rotations.

    Args:
        path_to_image0: Image path handed to ``load_image``.
        features: ``'superpoint'`` or ``'sift'``.
        rotations: Iterable of quarter-turn counts ``k`` passed to
            ``torch.rot90``. Only used for ``'superpoint'``.

    Returns:
        For ``'sift'``: ``(feats, h, w)`` where ``feats`` is the extractor
        output for the unrotated image.

        For ``'superpoint'``: ``(feats_merged, feats, h, w)`` where ``feats``
        maps each ``k`` to the extractor output for the image rotated by
        ``k`` quarter turns, and ``feats_merged`` concatenates
        ``'keypoints'``, ``'keypoint_scores'`` and ``'descriptors'`` along
        dim 1 in the order of ``rotations`` and adds a ``'rotations'`` long
        tensor holding the ``k`` of every keypoint. Merged keypoint
        coordinates are left in their rotated frames.

        ``h`` and ``w`` are the last two dimensions of the loaded image
        tensor.

    Raises:
        ValueError: If ``features`` is not a supported name, or if
            ``rotations`` is empty for ``'superpoint'``.
    """
    # Validate before any expensive work (image load, model construction).
    if features not in ('sift', 'superpoint'):
        raise ValueError(
            f"Unsupported features {features!r}; expected 'sift' or 'superpoint'"
        )
    if features == 'superpoint':
        rotations = list(rotations)
        if not rotations:
            raise ValueError("rotations must contain at least one value")

    # --- Models on GPU when available ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- Load image as a 3-D torch tensor (channels, H, W) ---
    timg = load_image(path_to_image0).to(device)
    _, h, w = timg.shape

    if features == 'sift':
        extractor = SIFT(max_num_keypoints=2048).eval().to(device)
        return extractor.extract(timg), h, w

    extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)

    # --- Extract local features once per rotation ---
    feats = {
        k: extractor.extract(torch.rot90(timg, k, dims=(1, 2)))
        for k in rotations
    }

    # --- Merge along the keypoint dimension (dim=1), tagging each keypoint
    # with the rotation it came from ---
    rotation_labels = [
        torch.full((1, feat['keypoints'].shape[1]), k, dtype=torch.long, device=device)
        for k, feat in feats.items()
    ]
    feats_merged = {
        'keypoints': torch.cat([feat['keypoints'] for feat in feats.values()], dim=1),
        'keypoint_scores': torch.cat([feat['keypoint_scores'] for feat in feats.values()], dim=1),
        'descriptors': torch.cat([feat['descriptors'] for feat in feats.values()], dim=1),
        'rotations': torch.cat(rotation_labels, dim=1),
    }
    return feats_merged, feats, h, w
def lightglue_matching(feats0, feats1, matcher = None):
    """Match two feature sets with LightGlue and return the match indices.

    Args:
        feats0: Feature dict of the first image (batched extractor output).
        feats1: Feature dict of the second image.
        matcher: Callable taking ``{'image0': ..., 'image1': ...}``. When
            ``None``, a SuperPoint LightGlue matcher is created for this
            call; pass a matcher to avoid rebuilding it on every call.

    Returns:
        The ``'matches'`` entry of the matcher output with the batch
        dimension removed.
    """
    if matcher is None:
        # Put the fallback matcher on the device the features are on, so the
        # two can never disagree.
        device = feats0['keypoints'].device
        matcher = LightGlue(features='superpoint').eval().to(device)

    # Only the matches entry is returned, so no autograd graph is needed.
    with torch.no_grad():
        out_k = matcher({'image0': feats0, 'image1': feats1})

    # Remove the batch dimension from the matcher output only; the inputs are
    # not needed in unbatched form.
    return rbd(out_k)['matches']
def _rotation_offset(feats, rot):
    """Return the number of keypoints stored for rotations ``0 .. rot - 1``."""
    return sum(feats[k]['keypoints'].shape[1] for k in range(rot))


def feature_matching(feats0, feats1, matcher = None, exhaustive = True):
    """Match two images whose features were extracted under rotations 0..3.

    The unrotated features of image 0 are matched against every rotation of
    image 1 and the rotation with the most matches is kept. Match indices are
    then shifted so that they index the merged feature set, i.e. the
    per-rotation features concatenated in the order 0, 1, 2, 3 (as
    ``extract_keypoints`` builds it for its default rotations).

    Args:
        feats0: Dict mapping rotation ``k`` to the feature dict of image 0.
            Key 0 is always required; keys 1-3 only when ``exhaustive``.
        feats1: Dict mapping rotation ``k`` (0-3) to the feature dict of
            image 1.
        matcher: Matcher forwarded to ``lightglue_matching``. When ``None``,
            one SuperPoint LightGlue matcher is created and reused for all
            pairings of this call.
        exhaustive: When true, also match rotations 1-3 of image 0 against
            the correspondingly shifted rotations of image 1 and append
            those matches.

    Returns:
        ``np.ndarray`` of dtype ``uint32`` and shape ``(M, 2)`` holding
        merged-set indices, or ``None`` when no rotation of image 1 produced
        any match.
    """
    if matcher is None:
        # Build the fallback matcher once instead of once per pairing.
        device = feats0[0]['keypoints'].device
        matcher = LightGlue(features='superpoint').eval().to(device)

    # Find the best rotation alignment.
    best_rot = 0
    best_num_matches = 0
    best_matches = None
    for rot in range(4):
        candidate = lightglue_matching(feats0[0], feats1[rot], matcher = matcher)
        if len(candidate) > best_num_matches:
            best_num_matches = len(candidate)
            best_rot = rot
            best_matches = candidate

    if best_matches is None:
        return None

    # Shift image-1 indices past the keypoints of the earlier rotations.
    matches_np = best_matches.cpu().numpy().astype(np.uint32)
    matches_np[:, 1] += _rotation_offset(feats1, best_rot)

    if not exhaustive:
        return matches_np

    all_matches = [matches_np]

    # Match the remaining rotations of image 0 against the rotation of
    # image 1 that keeps the same relative alignment.
    for rot_i in (1, 2, 3):
        rot_j = (best_rot + rot_i) % 4
        matches_tensor_rot = lightglue_matching(feats0[rot_i], feats1[rot_j], matcher = matcher)
        matches_np_i = matches_tensor_rot.cpu().numpy().astype(np.uint32)
        matches_np_i[:, 0] += _rotation_offset(feats0, rot_i)
        matches_np_i[:, 1] += _rotation_offset(feats1, rot_j)
        all_matches.append(matches_np_i)
        print(f"Rotation {rot_i} vs {rot_j}: {len(matches_tensor_rot)} matches")

    # The first entry is guaranteed non-empty at this point, so stacking is
    # always well defined.
    return np.vstack(all_matches)
|