| | """Module containing RANSAC modules.""" |
| | import math |
| | from typing import Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from kornia.geometry import ( |
| | find_fundamental, |
| | find_homography_dlt, |
| | find_homography_dlt_iterated, |
| | symmetrical_epipolar_distance, |
| | ) |
| | from kornia.geometry.homography import symmetric_transfer_error |
| |
|
| | __all__ = ["RANSAC"] |
| |
|
| |
|
class RANSAC(nn.Module):
    """Module for robust geometry estimation with RANSAC.

    https://en.wikipedia.org/wiki/Random_sample_consensus

    Args:
        model_type: type of model to estimate, either "homography" or "fundamental".
        inl_th: error threshold below (or at) which a correspondence counts as an inlier.
        batch_size: number of minimal samples (candidate models) generated at once.
        max_iter: maximum number of batches to generate. At most
            ``batch_size * max_iter`` candidate models are tried.
        confidence: desired confidence of the result, used for the early stopping.
        max_lo_iters: maximum number of local optimization (polishing) iterations.

    Raises:
        NotImplementedError: if ``model_type`` is not one of ``supported_models``.
    """
    supported_models = ['homography', 'fundamental']

    def __init__(self,
                 model_type: str = 'homography',
                 inl_th: float = 2.0,
                 batch_size: int = 2048,
                 max_iter: int = 10,
                 confidence: float = 0.99,
                 max_lo_iters: int = 5):
        super().__init__()
        self.inl_th = inl_th
        self.max_iter = max_iter
        self.batch_size = batch_size
        self.model_type = model_type
        self.confidence = confidence
        self.max_lo_iters = max_lo_iters
        # Select the error metric, the minimal solver, the polishing solver and
        # the minimal sample size that go with the requested model.
        if model_type == 'homography':
            self.error_fn = symmetric_transfer_error
            self.minimal_solver = find_homography_dlt
            self.polisher_solver = find_homography_dlt_iterated
            self.minimal_sample_size = 4
        elif model_type == 'fundamental':
            self.error_fn = symmetrical_epipolar_distance
            self.minimal_solver = find_fundamental
            self.minimal_sample_size = 8
            # The same solver is reused for polishing on the full inlier set.
            self.polisher_solver = find_fundamental
        else:
            raise NotImplementedError(f"{model_type} is unknown. Try one of {self.supported_models}")

    def sample(self,
               sample_size: int,
               pop_size: int,
               batch_size: int,
               device: torch.device = torch.device('cpu')) -> torch.Tensor:
        """Draw a batch of minimal samples.

        Unlike traditional RANSAC, samples are drawn in batches to benefit from
        parallel processing, especially on GPU.

        Args:
            sample_size: number of indices per sample.
            pop_size: number of correspondences to draw from.
            batch_size: number of samples to draw.
            device: device on which the indices are created.

        Returns:
            Index tensor of shape ``(batch_size, sample_size)``; indices within a
            row are distinct and lie in ``[0, pop_size)``.
        """
        # Top-k over i.i.d. random scores yields a random subset without
        # repetition for every row.
        rand = torch.rand(batch_size, pop_size, device=device)
        _, out = rand.topk(k=sample_size, dim=1)
        return out

    @staticmethod
    def max_samples_by_conf(n_inl: int, num_tc: int, sample_size: int, conf: float) -> float:
        """Return the number of samples needed to reach the requested confidence.

        Formula used to update the iteration budget in order to stop earlier, see
        https://en.wikipedia.org/wiki/Random_sample_consensus.

        Args:
            n_inl: number of inliers of the best model so far.
            num_tc: total number of tentative correspondences.
            sample_size: size of the minimal sample.
            conf: desired confidence.

        Returns:
            ``1.0`` when all correspondences are inliers, ``math.inf`` when no
            finite number of samples can reach the confidence (no inliers, an
            underflowing all-inlier probability, or ``conf >= 1``), otherwise
            ``log(1 - conf) / log(1 - (n_inl / num_tc) ** sample_size)``.
        """
        if n_inl >= num_tc:
            return 1.0
        if conf >= 1.0:
            # log(1 - conf) is undefined; full certainty is never reached.
            return math.inf
        p_all_inliers = math.pow(n_inl / num_tc, sample_size)
        if p_all_inliers <= 0.0:
            return math.inf
        # log1p keeps precision for tiny probabilities, where
        # log(1 - p) would evaluate to log(1.0) == 0 and divide by zero.
        return math.log(1.0 - conf) / math.log1p(-p_all_inliers)

    def estimate_model_from_minsample(self,
                                      kp1: torch.Tensor,
                                      kp2: torch.Tensor) -> torch.Tensor:
        """Run the minimal solver on a batch of minimal samples with unit weights.

        Args:
            kp1: sampled keypoints, the first two dimensions are
                ``(batch_size, sample_size)``.
            kp2: sampled keypoints corresponding to ``kp1``.

        Returns:
            Whatever the minimal solver returns for the batch
            (presumably one 3x3 model per sample -- confirm against the solver).
        """
        batch_size, sample_size = kp1.shape[:2]
        unit_weights = torch.ones(batch_size, sample_size, dtype=kp1.dtype, device=kp1.device)
        return self.minimal_solver(kp1, kp2, unit_weights)

    def verify(self,
               kp1: torch.Tensor,
               kp2: torch.Tensor,
               models: torch.Tensor, inl_th: float) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Score every candidate model and return the best one.

        Args:
            kp1: keypoints, ``(N, 2)`` or ``(1, N, 2)``.
            kp2: keypoints corresponding to ``kp1``, same layout.
            models: batch of candidate models, one per entry of dimension 0.
            inl_th: error threshold for a correspondence to be an inlier.

        Returns:
            A copy of the model with the most inliers, its boolean inlier mask
            and its inlier count as a Python float.
        """
        if len(kp1.shape) == 2:
            kp1 = kp1[None]
        if len(kp2.shape) == 2:
            kp2 = kp2[None]
        batch_size = models.shape[0]
        # Broadcast the correspondences so every model is evaluated on all of them.
        errors = self.error_fn(kp1.expand(batch_size, -1, 2),
                               kp2.expand(batch_size, -1, 2),
                               models)
        inl = (errors <= inl_th)
        # Cast the mask to the keypoint dtype/device to count inliers per model.
        models_score = inl.to(kp1).sum(dim=1)
        best_model_idx = models_score.argmax()
        best_model_score = models_score[best_model_idx].item()
        model_best = models[best_model_idx].clone()
        inliers_best = inl[best_model_idx]
        return model_best, inliers_best, best_model_score

    def remove_bad_samples(self, kp1: torch.Tensor, kp2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Hook for rejecting degenerate minimal samples.

        Currently a no-op: both inputs are returned unchanged.
        """
        return kp1, kp2

    def remove_bad_models(self, models: torch.Tensor) -> torch.Tensor:
        """Drop models having a main-diagonal entry with magnitude <= 1e-4.

        NOTE(review): this filter is applied for every model type; for
        fundamental matrices a near-zero diagonal may be legitimate -- confirm.

        Args:
            models: batch of square matrices, ``(B, 3, 3)`` in this module.

        Returns:
            The subset of ``models`` that passes the check.
        """
        main_diagonal = torch.diagonal(models, dim1=1, dim2=2)
        mask = main_diagonal.abs().min(dim=1)[0] > 1e-4
        return models[mask]

    def polish_model(self,
                     kp1: torch.Tensor,
                     kp2: torch.Tensor,
                     inliers: torch.Tensor) -> torch.Tensor:
        """Re-estimate the model from all current inliers with unit weights.

        Args:
            kp1: all keypoints, ``(N, 2)``.
            kp2: keypoints corresponding to ``kp1``, ``(N, 2)``.
            inliers: boolean mask selecting the inlier correspondences.

        Returns:
            Output of the polishing solver for a batch of size one.
        """
        kp1_inl = kp1[inliers][None]
        kp2_inl = kp2[inliers][None]
        num_inl = kp1_inl.size(1)
        unit_weights = torch.ones(1, num_inl, dtype=kp1_inl.dtype, device=kp1_inl.device)
        return self.polisher_solver(kp1_inl, kp2_inl, unit_weights)

    def forward(self,
                kp1: torch.Tensor,
                kp2: torch.Tensor,
                weights: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Main forward method to execute the RANSAC algorithm.

        Args:
            kp1 (torch.Tensor): source image keypoints :math:`(N, 2)`.
            kp2 (torch.Tensor): keypoints corresponding to ``kp1`` in the other
              image :math:`(N, 2)`.
            weights (torch.Tensor): optional correspondences weights. Not used now.

        Returns:
            - Estimated model, shape of :math:`(3, 3)`. All zeros if no model
              with more inliers than the minimal sample size was found.
            - The boolean inlier mask, shape of :math:`(N,)`, where N is number
              of input correspondences.

        Raises:
            TypeError: if ``kp1`` or ``kp2`` is not a tensor.
            ValueError: if the inputs are not 2-dimensional, differ in length, or
              contain fewer correspondences than the minimal sample size.
        """
        if not isinstance(kp1, torch.Tensor):
            raise TypeError(f"Input kp1 is not torch.Tensor. Got {type(kp1)}")
        if not isinstance(kp2, torch.Tensor):
            raise TypeError(f"Input kp2 is not torch.Tensor. Got {type(kp2)}")
        if not len(kp1.shape) == 2:
            raise ValueError(f"Invalid kp1 shape, we expect Nx2 Got: {kp1.shape}")
        if not len(kp2.shape) == 2:
            raise ValueError(f"Invalid kp2 shape, we expect Nx2 Got: {kp2.shape}")
        if (kp1.shape[0] != kp2.shape[0]) or (kp1.shape[0] < self.minimal_sample_size):
            raise ValueError(
                f"kp1 and kp2 should be equal shape at least [{self.minimal_sample_size}, 2], "
                f"got {kp1.shape}, {kp2.shape}")

        # A model must explain more points than the minimal sample to be accepted.
        best_score_total: float = float(self.minimal_sample_size)
        num_tc: int = len(kp1)
        best_model_total = torch.zeros(3, 3, dtype=kp1.dtype, device=kp1.device)
        # Same (N,) layout as the mask produced by ``verify``.
        inliers_best_total: torch.Tensor = torch.zeros(num_tc, device=kp1.device, dtype=torch.bool)
        for i in range(self.max_iter):
            # Sample a batch of minimal sets.
            idxs = self.sample(self.minimal_sample_size, num_tc, self.batch_size, kp1.device)
            kp1_sampled = kp1[idxs]
            kp2_sampled = kp2[idxs]

            kp1_sampled, kp2_sampled = self.remove_bad_samples(kp1_sampled, kp2_sampled)
            # Estimate one model per minimal set and discard degenerate ones.
            models = self.estimate_model_from_minsample(kp1_sampled, kp2_sampled)
            models = self.remove_bad_models(models)
            if (models is None) or (len(models) == 0):
                continue
            # Score the models on all correspondences.
            model, inliers, model_score = self.verify(kp1, kp2, models, self.inl_th)

            if model_score > best_score_total:
                # Local optimization: refit on the inliers while the score improves.
                for lo_step in range(self.max_lo_iters):
                    model_lo = self.polish_model(kp1, kp2, inliers)
                    if (model_lo is None) or (len(model_lo) == 0):
                        continue
                    _, inliers_lo, score_lo = self.verify(kp1, kp2, model_lo, self.inl_th)
                    if score_lo > model_score:
                        model = model_lo.clone()[0]
                        inliers = inliers_lo.clone()
                        model_score = score_lo
                    else:
                        break
                # Keep the new so-far-the-best model.
                best_model_total = model.clone()
                inliers_best_total = inliers.clone()
                best_score_total = model_score

            # Early stopping: enough samples were drawn for the requested confidence.
            max_samples = self.max_samples_by_conf(int(best_score_total),
                                                   num_tc,
                                                   self.minimal_sample_size,
                                                   self.confidence)
            # An infinite budget means the confidence cannot be reached yet.
            if not math.isinf(max_samples) and (i + 1) * self.batch_size >= int(max_samples):
                break

        return best_model_total, inliers_best_total
| |
|