| import os | |
| import torch | |
| import py3_wget | |
| from matching.im_models.lightglue import SIFT, SuperPoint | |
| from matching.utils import add_to_path | |
| from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher | |
| add_to_path(THIRD_PARTY_DIR.joinpath('SphereGlue')) | |
| from model.sphereglue import SphereGlue | |
| from utils.Utils import sphericalToCartesian | |
| def unit_cartesian(points): | |
| phi, theta = torch.split(torch.as_tensor(points), 1, dim=1) | |
| unitCartesian = sphericalToCartesian(phi, theta, 1).squeeze(dim=2) | |
| return unitCartesian | |
| class SphereGlueBase(BaseMatcher): | |
| """ | |
| This class is the parent for all methods that use LightGlue as a matcher, | |
| with different local features. It implements the forward which is the same | |
| regardless of the feature extractor of choice. | |
| Therefore this class should *NOT* be instatiated, as it needs its children to define | |
| the extractor and the matcher. | |
| """ | |
| def __init__(self, device="cpu", **kwargs): | |
| super().__init__(device, **kwargs) | |
| self.sphereglue_cfg = { | |
| "K": kwargs.get("K", 2), | |
| "GNN_layers": kwargs.get("GNN_layers", ["cross"]), | |
| "match_threshold": kwargs.get("match_threshold", 0.2), | |
| "sinkhorn_iterations": kwargs.get("sinkhorn_iterations", 20), | |
| "aggr": kwargs.get("aggr", "add"), | |
| "knn": kwargs.get("knn", 20), | |
| } | |
| self.skip_ransac = True | |
| def download_weights(self): | |
| if not os.path.isfile(self.model_path): | |
| print("Downloading SphereGlue weights") | |
| py3_wget.download_file(self.weights_url, self.model_path) | |
| def _forward(self, img0, img1): | |
| """ | |
| "extractor" and "matcher" are instantiated by the subclasses. | |
| """ | |
| feats0 = self.extractor.extract(img0) | |
| feats1 = self.extractor.extract(img1) | |
| unit_cartesian1 = unit_cartesian(feats0["keypoints"][0]).unsqueeze(dim=0).to(self.device) | |
| unit_cartesian2 = unit_cartesian(feats1["keypoints"][0]).unsqueeze(dim=0).to(self.device) | |
| inputs = { | |
| "h1": feats0["descriptors"], | |
| "h2": feats1["descriptors"], | |
| "scores1": feats0["keypoint_scores"], | |
| "scores2": feats1["keypoint_scores"], | |
| "unitCartesian1": unit_cartesian1, | |
| "unitCartesian2": unit_cartesian2, | |
| } | |
| outputs = self.matcher(inputs) | |
| kpts0, kpts1, matches = ( | |
| feats0["keypoints"].squeeze(dim=0), | |
| feats1["keypoints"].squeeze(dim=0), | |
| outputs["matches0"].squeeze(dim=0), | |
| ) | |
| desc0 = feats0["descriptors"].squeeze(dim=0) | |
| desc1 = feats1["descriptors"].squeeze(dim=0) | |
| mask = matches.ge(0) | |
| kpts0_idx = torch.masked_select(torch.arange(matches.shape[0]).to(mask.device), mask) | |
| kpts1_idx = torch.masked_select(matches, mask) | |
| mkpts0 = kpts0[kpts0_idx] | |
| mkpts1 = kpts1[kpts1_idx] | |
| return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1 | |
| class SiftSphereGlue(SphereGlueBase): | |
| model_path = WEIGHTS_DIR.joinpath("sift-sphereglue.pt") | |
| weights_url = "https://github.com/vishalsharbidar/SphereGlue/raw/refs/heads/main/model_weights/sift/autosaved.pt" | |
| def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs): | |
| super().__init__(device, **kwargs) | |
| self.download_weights() | |
| self.sphereglue_cfg.update({ | |
| "descriptor_dim": 128, | |
| "output_dim": 128*2, | |
| "max_kpts": max_num_keypoints | |
| }) | |
| self.extractor = SIFT(max_num_keypoints=max_num_keypoints).eval().to(self.device) | |
| self.matcher = SphereGlue(config=self.sphereglue_cfg).to(self.device) | |
| self.matcher.load_state_dict(torch.load(self.model_path, map_location=self.device)["MODEL_STATE_DICT"]) | |
| class SuperpointSphereGlue(SphereGlueBase): | |
| model_path = WEIGHTS_DIR.joinpath("superpoint-sphereglue.pt") | |
| weights_url = "https://github.com/vishalsharbidar/SphereGlue/raw/refs/heads/main/model_weights/superpoint/autosaved.pt" | |
| def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs): | |
| super().__init__(device, **kwargs) | |
| self.download_weights() | |
| self.sphereglue_cfg.update({ | |
| "descriptor_dim": 256, | |
| "output_dim": 256*2, | |
| "max_kpts": max_num_keypoints | |
| }) | |
| self.extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(self.device) | |
| self.matcher = SphereGlue(config=self.sphereglue_cfg).to(self.device) | |
| self.matcher.load_state_dict(torch.load(self.model_path, map_location=self.device)["MODEL_STATE_DICT"]) |