import functools

import cv2
import numpy as np
import torch
from lightglue import LightGlue
from lightglue.utils import rbd


def unrotate_kps_W(kps_rot, k, H, W):
    """Map keypoints from rotated images back to the original image frame (NumPy).

    This is the vectorised counterpart of :func:`unrotate_kps`. Every keypoint
    may come from a different rotation of the original image.

    Args:
        kps_rot: ``(x, y)`` keypoints in the rotated images. Any array or
            tensor that reshapes to ``(N, 2)`` is accepted, e.g. ``(1, N, 2)``.
            Torch tensors are detached and copied to the CPU.
        k: Number of 90° counter-clockwise ``torch.rot90(img, k, dims=(1, 2))``
            steps used for each keypoint. Either a scalar, which applies to all
            keypoints, or anything that reshapes to ``(N,)``. Values must be
            in 0..3.
        H, W: Height and width of the ORIGINAL (unrotated) image.

    Returns:
        ``np.ndarray`` of shape ``(N, 2)`` holding ``(x, y)`` in original
        image coordinates, with the same dtype as the input keypoints.

    Raises:
        ValueError: if ``k`` contains a value outside 0..3, or cannot be
            broadcast to one entry per keypoint.
    """
    if hasattr(kps_rot, 'cpu'):
        kps_rot = kps_rot.detach().cpu().numpy()
    if hasattr(k, 'cpu'):
        k = k.detach().cpu().numpy()

    # reshape (not squeeze) so a single keypoint stays (1, 2) instead of (2,).
    kps_rot = np.asarray(kps_rot).reshape(-1, 2)
    # One rotation index per keypoint; a scalar / size-1 k applies to all.
    k = np.broadcast_to(np.asarray(k).reshape(-1), (kps_rot.shape[0],))

    invalid = ~np.isin(k, (0, 1, 2, 3))
    if invalid.any():
        raise ValueError(f"k must be 0..3, got {np.unique(k[invalid]).tolist()}")

    x_r, y_r = kps_rot[:, 0], kps_rot[:, 1]
    # Start from the k == 0 (identity) mapping; copies keep the input untouched.
    x, y = x_r.copy(), y_r.copy()

    m = k == 1  # 90° CCW
    x[m], y[m] = (W - 1) - y_r[m], x_r[m]
    m = k == 2  # 180°
    x[m], y[m] = (W - 1) - x_r[m], (H - 1) - y_r[m]
    m = k == 3  # 270° CCW
    x[m], y[m] = y_r[m], (H - 1) - x_r[m]

    return np.stack([x, y], axis=-1)


def unrotate_kps(kps_rot, k, H, W):
    """Map keypoints from one rotated image back to the original frame (torch).

    ``k`` is how many times the original ``(C, H, W)`` image was rotated with
    ``torch.rot90(img, k, dims=(1, 2))`` (90° counter-clockwise steps) to
    produce the image the keypoints were detected in.

    Args:
        kps_rot: Tensor ``(..., 2)`` of ``(x, y)`` keypoints in the rotated
            image, e.g. ``(N, 2)`` or batched ``(1, N, 2)``.
        k: Rotation count; must be one of 0, 1, 2, 3.
        H, W: Height and width of the ORIGINAL (unrotated) image.

    Returns:
        A new tensor with the same shape as ``kps_rot``, holding ``(x, y)`` in
        original image coordinates.

    Raises:
        ValueError: if ``k`` is not in 0..3.
    """
    if k not in (0, 1, 2, 3):
        raise ValueError("k must be 0..3")
    k = int(k)

    x_r, y_r = kps_rot[..., 0], kps_rot[..., 1]
    if k == 0:
        x, y = x_r, y_r
    elif k == 1:  # 90° CCW
        x, y = (W - 1) - y_r, x_r
    elif k == 2:  # 180°
        x, y = (W - 1) - x_r, (H - 1) - y_r
    else:  # 270° CCW
        x, y = y_r, (H - 1) - x_r
    # torch.stack always allocates, so the result never aliases kps_rot.
    return torch.stack([x, y], dim=-1)


def _default_device():
    """Return 'cuda' when a CUDA device is available, otherwise 'cpu'."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'


@functools.lru_cache(maxsize=8)
def _load_extractor(features, device):
    """Build the local-feature extractor once per (features, device) and reuse it.

    Raises:
        ValueError: if ``features`` is neither 'superpoint' nor 'sift'.
    """
    from lightglue import SIFT, SuperPoint

    if features == 'superpoint':
        extractor = SuperPoint(max_num_keypoints=2048)
    elif features == 'sift':
        extractor = SIFT(max_num_keypoints=2048)
    else:
        raise ValueError(
            f"Unsupported features {features!r}; expected 'superpoint' or 'sift'"
        )
    return extractor.eval().to(device)


@functools.lru_cache(maxsize=8)
def _load_matcher(features, device):
    """Build the LightGlue matcher once per (features, device) and reuse it.

    Cached models stay resident on ``device`` for the lifetime of the process.
    """
    return LightGlue(features=features).eval().to(device)


def _to_device(feats, device):
    """Return a shallow copy of ``feats`` with every tensor value moved to ``device``."""
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in feats.items()
    }


def lightglue_keypoints(path_to_image0, features='superpoint', rotations=(0, 1, 2, 3)):
    """Extract keypoints on several 90° rotations of an image and merge them.

    Each requested rotation of the image is run through the extractor. The
    resulting keypoints are mapped back into the ORIGINAL image frame and
    concatenated. Descriptors come from the rotated images, presumably so that
    matching tolerates 90° rotations between views.

    Args:
        path_to_image0: Path of the image to load.
        features: 'superpoint' or 'sift'.
        rotations: Iterable of ``torch.rot90`` step counts. Values are taken
            modulo 4 (as ``torch.rot90`` does), and duplicates are used once.

    Returns:
        ``(feats_merged, h, w)``. ``h`` and ``w`` are the original image size.
        ``feats_merged`` holds batched tensors (leading dim 1):

        - ``'keypoints'`` (1, N, 2): ``(x, y)`` in original image coordinates.
        - ``'keypoint_scores'``: concatenated along dim 1.
        - ``'descriptors'``: concatenated along dim 1.
        - ``'rotations'`` (1, N): the rotation each keypoint was detected in.
        - ``'image_size'`` (1, 2): ``[w, h]`` of the original image.

    Raises:
        ValueError: if ``rotations`` is empty or ``features`` is unsupported.
    """
    from lightglue.utils import load_image

    ks = list(dict.fromkeys(int(k) % 4 for k in rotations))
    if not ks:
        raise ValueError("rotations must contain at least one value")

    device = _default_device()
    extractor = _load_extractor(features, device)

    # Image tensor (3, H, W) with values in [0, 1].
    timg = load_image(path_to_image0).to(device)
    _, h, w = timg.shape

    all_keypoints, all_scores, all_descriptors, all_rotations = [], [], [], []
    for k in ks:
        feat = extractor.extract(torch.rot90(timg, k, dims=(1, 2)))
        kpts = feat['keypoints']  # (1, N, 2), (x, y) in the ROTATED image
        num_kpts = kpts.shape[1]
        print(f"Extracted {num_kpts} keypoints for rotation {k}")

        # Un-rotate so that all keypoints share the original frame that
        # 'image_size' below describes.
        all_keypoints.append(unrotate_kps(kpts, k, h, w))
        all_scores.append(feat['keypoint_scores'])
        all_descriptors.append(feat['descriptors'])
        all_rotations.append(
            torch.full((1, num_kpts), k, dtype=torch.long, device=kpts.device)
        )

    # NOTE(review): only the keys below are merged; any other per-keypoint
    # outputs of the extractor are dropped.
    feats_merged = {
        'keypoints': torch.cat(all_keypoints, dim=1),
        'keypoint_scores': torch.cat(all_scores, dim=1),
        'descriptors': torch.cat(all_descriptors, dim=1),
        'rotations': torch.cat(all_rotations, dim=1),
        # [w, h] of the ORIGINAL image, matching the un-rotated keypoints.
        'image_size': torch.tensor([w, h], device=device).unsqueeze(0),
    }
    return feats_merged, h, w


def _plot_matches(feats0, feats1, matches, path_to_image0, path_to_image1):
    """Draw ``matches`` between two un-batched feature dicts over their images."""
    from lightglue import viz2d
    from lightglue.utils import load_image

    if len(matches) == 0:
        print("No matches found.")
        return
    kpts0 = feats0['keypoints'][matches[:, 0]]
    kpts1 = feats1['keypoints'][matches[:, 1]]
    viz2d.plot_images([load_image(path_to_image0), load_image(path_to_image1)])
    viz2d.plot_matches(kpts0.cpu(), kpts1.cpu(), color=None, lw=0.8)


def lightglue_matching(feats0, feats1, plot=False, features='superpoint',
                       path_to_image0=None, path_to_image1=None):
    """Match two feature dictionaries with LightGlue.

    Args:
        feats0, feats1: Batched feature dicts (leading batch dim of 1), for
            example the first element returned by :func:`lightglue_keypoints`.
        plot: If True, draw the matches over the images loaded from
            ``path_to_image0`` and ``path_to_image1``. The keypoints must be
            in those images' coordinate frames.
        features: LightGlue weight set. It should match the extractor that
            produced the features.
        path_to_image0, path_to_image1: Image paths; required when ``plot``.

    Returns:
        ``(K, 2)`` long tensor. Row ``i`` pairs keypoint ``matches[i, 0]`` of
        ``feats0`` with keypoint ``matches[i, 1]`` of ``feats1``.

    Raises:
        ValueError: if ``plot`` is True but an image path is missing.
    """
    if plot and (path_to_image0 is None or path_to_image1 is None):
        raise ValueError("plot=True requires path_to_image0 and path_to_image1")

    device = _default_device()
    matcher = _load_matcher(features, device)
    feats0 = _to_device(feats0, device)
    feats1 = _to_device(feats1, device)

    # Inference only: don't build an autograd graph through the matcher.
    with torch.no_grad():
        out = matcher({'image0': feats0, 'image1': feats1})
    # Remove the batch dimension from the inputs and from the matcher output.
    feats0, feats1, out = [rbd(x) for x in (feats0, feats1, out)]
    matches = out['matches']  # (K, 2) long
    print(f"LightGlue found {len(matches)} matches.")

    if plot:
        _plot_matches(feats0, feats1, matches, path_to_image0, path_to_image1)
    return matches