| import copy |
| import time |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional, Tuple |
|
|
| import numpy as np |
| import open3d as o3d |
| import torch |
try:
    from easydict import EasyDict as edict
except ImportError:
    # easydict is optional. Only a *missing* package triggers the fallback, so a
    # broken installation raises instead of being silently replaced.
    class edict(dict):
        """Minimal EasyDict fallback (dot access).

        Attribute reads and writes map onto dict items. Unlike EasyDict, nested
        dicts are not converted to ``edict`` recursively.
        """

        def __getattr__(self, k):
            try:
                return self[k]
            except KeyError as e:
                # Keep getattr(obj, name, default) / hasattr() semantics working.
                raise AttributeError(k) from e

        def __setattr__(self, k, v):
            self[k] = v
|
|
| from tools import metrics |
| from r3pm_net.config_loader import get_method_paths |
|
|
|
|
@dataclass
class _PredatorRunner:
    """Ready-to-run OverlapPredator state produced by ``_init_runner``.

    Bundles the loaded model with the settings needed to collate new
    (source, target) pairs for it. A single instance is cached in the
    module-level ``_RUNNER`` and reused by ``predator_reg_and_eval``.
    """

    predator_root: Path  # OverlapPredator repository root (put on sys.path by _init_runner)
    config_path: Path  # YAML config the model was built from
    weights_path: Path  # checkpoint actually loaded (explicit path or the config's `pretrain`)
    device: torch.device  # device holding the model; input batches are moved here
    config: edict  # config variant whose network loaded the checkpoint
    model: torch.nn.Module  # KPFCNN instance, already switched to eval mode
    neighborhood_limits: np.ndarray  # per-layer neighbour caps handed to collate_fn_descriptor
    input_num_points: int  # max points kept per cloud before collation (<= 0 keeps all)
|
|
|
|
# Lazily built by predator_reg_and_eval and reused across calls.
_RUNNER: Optional[_PredatorRunner] = None
# Per-method path overrides. `or {}` also covers an explicit `predator: null`
# entry; otherwise the `_METHOD_CFG.get(...)` default arguments below would fail.
_METHOD_CFG = get_method_paths().get("predator") or {}
|
|
|
|
| def _build_kpconv_architecture(num_layers: int) -> list: |
| |
| arch = ["simple", "resnetb"] |
| for _ in range(num_layers - 1): |
| arch += ["resnetb_strided", "resnetb", "resnetb"] |
| for _ in range(num_layers - 2): |
| arch += ["nearest_upsample", "unary"] |
| arch += ["nearest_upsample", "last_unary"] |
| return arch |
|
|
|
|
def _get_predator_architecture(cfg_in: edict) -> list:
    """
    Return the KPConv block list to build KPFCNN with for ``cfg_in``.

    OverlapPredator defines dataset-specific architectures in `configs/models.py`.
    We try to use that (it must match the released checkpoints), and fall back to
    the demo-style architecture builder if unavailable.

    The lookup stays best-effort:
    - a missing `configs.models` module, or an unknown ``cfg_in.dataset``, falls
      back silently;
    - any other error while importing `configs.models` also falls back, but emits
      a ``RuntimeWarning`` instead of being swallowed.

    The returned list is a fresh copy, so callers can mutate it without touching
    the module-level table in `configs.models`.
    """
    import warnings

    arch_dict = None
    try:
        from configs.models import architectures as arch_dict
    except ImportError:
        pass  # module (or name) not present: use the demo-style builder
    except Exception as exc:  # broken configs/models.py: still fall back, but say so
        warnings.warn(
            f"Could not import OverlapPredator architectures from configs.models ({exc!r}); "
            "falling back to the demo-style KPConv architecture.",
            RuntimeWarning,
        )

    dataset_name = getattr(cfg_in, "dataset", None)
    if arch_dict is not None:
        try:
            if dataset_name in arch_dict:
                return list(arch_dict[dataset_name])
        except TypeError:
            # `architectures` is not a container, or the dataset name is unhashable.
            pass

    return _build_kpconv_architecture(int(getattr(cfg_in, "num_layers", 3)))
|
|
|
|
| def _resolve_path(predator_root: Path, p: str | Path) -> Path: |
| p = Path(p) |
| return p if p.is_absolute() else (predator_root / p) |
|
|
|
|
| def _maybe_downsample_xyz(xyz: np.ndarray, max_points: int) -> np.ndarray: |
| if max_points <= 0 or xyz.shape[0] <= max_points: |
| return xyz |
| idx = np.random.permutation(xyz.shape[0])[:max_points] |
| return xyz[idx] |
|
|
|
|
def _to_o3d_feature(desc: np.ndarray) -> "o3d.pipelines.registration.Feature":
    """Wrap a descriptor matrix (one row per point) as an Open3D ``Feature``.

    Open3D stores features column-wise, so the float32 copy is transposed before
    assignment.
    """
    column_major = np.asarray(desc, dtype=np.float32).T
    feature = o3d.pipelines.registration.Feature()
    feature.data = column_major
    return feature
|
|
|
|
def _ransac_pose_estimation(
    src_xyz: np.ndarray,
    tgt_xyz: np.ndarray,
    src_desc: np.ndarray,
    tgt_desc: np.ndarray,
    *,
    distance_threshold: float = 0.05,
    ransac_n: int = 3,
    mutual: bool = False,
    max_iteration: int = 50000,
    confidence: float = 1.0,
) -> np.ndarray:
    """
    Estimate the rigid source->target transform with Open3D feature-matching RANSAC.

    Args:
        src_xyz, tgt_xyz: point coordinates, shape (N, 3) / (M, 3).
        src_desc, tgt_desc: descriptors with one row per point, aligned with the
            coordinate arrays.
        distance_threshold: inlier distance. It is used both as RANSAC's max
            correspondence distance and by the distance-based correspondence checker.
        ransac_n: number of correspondences sampled per RANSAC hypothesis.
        mutual: if True, keep only mutually nearest feature matches.
        max_iteration: RANSAC iteration cap.
        confidence: early-stop confidence in [0, 1], as used by the
            ``o3d.pipelines.registration`` API.
            - The previous hard-coded ``RANSACConvergenceCriteria(50000, 1000)``
              passed 1000 here, which Open3D clamps to 1.0 (never stop early).
            - 1.0 is therefore the default, which preserves that behaviour.

    Returns:
        The estimated 4x4 transformation as a float64 array.
    """
    reg = o3d.pipelines.registration

    # Vector3dVector converts float64 (n, 3) arrays. Cast explicitly rather than
    # handing it the float32 arrays produced by torch's `.numpy()`.
    src_pcd = o3d.geometry.PointCloud(
        o3d.utility.Vector3dVector(np.ascontiguousarray(src_xyz, dtype=np.float64))
    )
    tgt_pcd = o3d.geometry.PointCloud(
        o3d.utility.Vector3dVector(np.ascontiguousarray(tgt_xyz, dtype=np.float64))
    )

    src_feat = _to_o3d_feature(src_desc)
    tgt_feat = _to_o3d_feature(tgt_desc)

    estimation = reg.TransformationEstimationPointToPoint(False)
    checkers = [
        reg.CorrespondenceCheckerBasedOnEdgeLength(0.9),
        reg.CorrespondenceCheckerBasedOnDistance(distance_threshold),
    ]
    criteria = reg.RANSACConvergenceCriteria(max_iteration, confidence)

    try:
        result = reg.registration_ransac_based_on_feature_matching(
            src_pcd,
            tgt_pcd,
            src_feat,
            tgt_feat,
            mutual,
            distance_threshold,
            estimation,
            ransac_n,
            checkers,
            criteria,
        )
    except TypeError:
        # Retry with keyword arguments in case this Open3D build binds the
        # parameters in a different positional order.
        result = reg.registration_ransac_based_on_feature_matching(
            source=src_pcd,
            target=tgt_pcd,
            source_feature=src_feat,
            target_feature=tgt_feat,
            mutual_filter=mutual,
            max_correspondence_distance=distance_threshold,
            estimation_method=estimation,
            ransac_n=ransac_n,
            checkers=checkers,
            criteria=criteria,
        )

    return np.asarray(result.transformation, dtype=np.float64)
|
|
|
|
def _init_runner(
    predator_root: Path,
    config_path: Path,
    weights_path: Optional[Path],
    *,
    device: Optional[str | torch.device] = None,
    input_num_points: Optional[int] = None,
    calibrate_neighborhood_limits: bool = True,
) -> _PredatorRunner:
    """
    Load an OverlapPredator config and checkpoint, and prepare everything needed for inference.

    Args:
        predator_root: OverlapPredator repository root. It is prepended to
            ``sys.path`` so the repo's top-level ``lib`` / ``datasets`` / ``models``
            modules can be imported.
        config_path: YAML config read with OverlapPredator's ``load_config``.
        weights_path: checkpoint to load; a relative path is taken relative to
            ``predator_root``. When falsy, the config's ``pretrain`` entry is used.
        device: torch device or device string. When ``None``, CUDA is used if the
            config's ``gpu_mode`` is truthy (a missing entry counts as truthy) and
            CUDA is available.
        input_num_points: max points per cloud fed to the network. Defaults to the
            config's ``num_points``, or 1024.
        calibrate_neighborhood_limits: if True, estimate per-layer neighbour caps
            with ``calibrate_neighbors`` on a synthetic pair. Otherwise use 256 for
            every layer.

    Returns:
        A ``_PredatorRunner`` holding the eval-mode model and its collate settings.

    Raises:
        RuntimeError: if no candidate config yields a model that both matches the
            checkpoint's tensor shapes and shares at least one parameter name with it.
    """
    import sys
    import warnings

    # NOTE(review): OverlapPredator's module names (`lib`, `datasets`, `models`,
    # `configs`) are generic. A same-named package already present in
    # sys.modules would shadow the repo's module; confirm in your environment.
    if str(predator_root) not in sys.path:
        sys.path.insert(0, str(predator_root))

    from lib.utils import load_config
    from datasets.my_dataloader import calibrate_neighbors, collate_fn_descriptor
    from models.architectures import KPFCNN

    cfg = edict(load_config(str(config_path)))

    if device is None:
        use_cuda = bool(getattr(cfg, "gpu_mode", True)) and torch.cuda.is_available()
        device_t = torch.device("cuda" if use_cuda else "cpu")
    else:
        device_t = device if isinstance(device, torch.device) else torch.device(device)

    ckpt_path = _resolve_path(predator_root, weights_path if weights_path else cfg.pretrain)
    state = torch.load(str(ckpt_path), map_location=device_t)
    state_dict = state["state_dict"] if isinstance(state, dict) and "state_dict" in state else state

    def _try_build_and_load(cfg_in: edict) -> Optional[torch.nn.Module]:
        """Build KPFCNN for ``cfg_in`` and load the checkpoint; return None if it does not fit."""
        cfg_in.device = device_t
        cfg_in.architecture = _get_predator_architecture(cfg_in)
        m = KPFCNN(cfg_in).to(device_t)
        m.eval()
        try:
            # strict=False ignores missing/unexpected keys but still raises
            # RuntimeError on tensor-shape mismatches, which is what we probe for.
            incompatible = m.load_state_dict(state_dict, strict=False)
        except RuntimeError:
            return None
        # With strict=False, a checkpoint whose keys match nothing "loads" without
        # error and leaves the network randomly initialised; treat that as failure.
        if len(incompatible.unexpected_keys) >= len(state_dict):
            return None
        if incompatible.missing_keys:
            warnings.warn(
                f"OverlapPredator checkpoint '{ckpt_path}' did not provide "
                f"{len(incompatible.missing_keys)} model tensor(s), e.g. "
                f"{list(incompatible.missing_keys[:3])}; they keep their initial values.",
                RuntimeWarning,
            )
        return m

    # The checkpoint may have been trained with a different `first_feats_dim`
    # than the YAML states, so also try a few common widths (each at most once).
    cfg_candidates: list[edict] = [cfg]
    first_fd = int(getattr(cfg, "first_feats_dim", 0) or 0)
    seen_dims = {first_fd}
    for cand in (first_fd // 2, 256, 128, 64):
        if cand and cand not in seen_dims:
            seen_dims.add(cand)
            c = edict(dict(cfg))
            c.first_feats_dim = int(cand)
            cfg_candidates.append(c)

    model = None
    chosen_cfg = None
    for c in cfg_candidates:
        m = _try_build_and_load(c)
        if m is not None:
            model, chosen_cfg = m, c
            break

    if model is None or chosen_cfg is None:
        raise RuntimeError(
            f"Failed to load OverlapPredator weights at '{ckpt_path}'. "
            f"Config '{config_path}' seems incompatible with checkpoint tensor shapes "
            f"or parameter names."
        )

    if input_num_points is None:
        input_num_points = int(getattr(cfg, "num_points", 1024))

    if calibrate_neighborhood_limits:
        # NOTE(review): calibration runs on synthetic Gaussian points rather than
        # real scans, so the limits only approximate real-data neighbourhoods.

        class _SinglePairDataset:
            """One-item dataset yielding a synthetic pair in the collate tuple layout."""

            def __init__(self, config):
                self.config = config

            def __len__(self):
                return 1

            def __getitem__(self, _):
                n = max(64, int(input_num_points))
                src = np.random.randn(n, 3).astype(np.float32)
                tgt = np.random.randn(n, 3).astype(np.float32)
                src_feats = np.ones((n, 1), dtype=np.float32)
                tgt_feats = np.ones((n, 1), dtype=np.float32)
                rot = np.eye(3, dtype=np.float32)
                trans = np.zeros((3, 1), dtype=np.float32)
                matching_inds = torch.ones(1, 2).long()
                sample = torch.ones(1)
                gt = np.eye(4, dtype=np.float32)
                return src, tgt, src_feats, tgt_feats, rot, trans, matching_inds, src, tgt, sample, gt

        dummy_ds = _SinglePairDataset(chosen_cfg)
        neighborhood_limits = calibrate_neighbors(dummy_ds, chosen_cfg, collate_fn=collate_fn_descriptor)
    else:
        n_layers = int(getattr(chosen_cfg, "num_layers", 5) or 5)
        neighborhood_limits = np.asarray([256] * n_layers, dtype=np.int32)

    return _PredatorRunner(
        predator_root=predator_root,
        config_path=config_path,
        weights_path=ckpt_path,
        device=device_t,
        config=chosen_cfg,
        model=model,
        neighborhood_limits=neighborhood_limits,
        input_num_points=int(input_num_points),
    )
|
|
|
|
def predator_reg_and_eval(
    source: "o3d.geometry.PointCloud",
    target: "o3d.geometry.PointCloud",
    *,
    gt_transformation: Optional[np.ndarray] = None,
    predator_root: str | Path = _METHOD_CFG.get("root", "/home/ykashefbahrami/master_thesis/OverlapPredator"),
    config_path: str | Path = _METHOD_CFG.get("config_path", "/home/ykashefbahrami/master_thesis/OverlapPredator/configs/test/modelnet.yaml"),
    weights_path: Optional[str | Path] = _METHOD_CFG.get("weights_path", None),
    ransac_n_points: int = 1000,
    ransac_distance_threshold: float = 0.05,
    ransac_n: int = 3,
    sampling: str = "prob",
    mutual: bool = False,
    device: Optional[str | torch.device] = None,
    input_num_points: Optional[int] = 1024,
) -> Tuple["o3d.geometry.PointCloud", tuple]:
    """
    Run OverlapPredator on a (source, target) pair and evaluate with the same
    metric outputs as the Learning3D harness in this repo.

    The model is built once and cached in ``_RUNNER``. It is rebuilt when
    ``predator_root`` or ``config_path`` change, or when an explicitly given
    ``weights_path`` / ``device`` / ``input_num_points`` differs from the cached one.
    A ``None`` for one of these three means "accept whatever is cached".

    Args:
        source, target: point clouds to register (source is moved onto target).
        gt_transformation: optional ground-truth 4x4 transform forwarded to the metrics.
        predator_root, config_path, weights_path: OverlapPredator repo, YAML config
            and checkpoint (``None`` uses the config's ``pretrain`` entry).
        ransac_n_points: max points per cloud handed to RANSAC after score-based sampling.
        ransac_distance_threshold, ransac_n, mutual: RANSAC settings.
        sampling: how the RANSAC points are chosen.
            - ``"topk"``: highest overlap*saliency scores.
            - ``"random"``: uniform.
            - anything else (default ``"prob"``): proportional to the scores.
        device: torch device for the model (``None`` = choose from the config).
        input_num_points: max points per cloud fed to the network (random
            downsampling beforehand).

    Returns:
        ``(pc_result, eval_results)``.
        - ``pc_result``: a copy of ``source`` transformed by the estimate.
        - ``eval_results``: what ``metrics.all_evaluations`` returns.
        The elapsed time passed to the metrics covers the network forward pass,
        sampling and RANSAC, but not the collate/subsampling step.
    """
    global _RUNNER
    predator_root_p = Path(predator_root).resolve()
    config_path_p = Path(config_path).resolve()
    weights_path_p = Path(weights_path).resolve() if weights_path is not None else None

    def _runner_matches(runner: _PredatorRunner) -> bool:
        """True if the cached runner was built for the settings requested now."""
        if runner.predator_root != predator_root_p or runner.config_path != config_path_p:
            return False
        if weights_path_p is not None and runner.weights_path != weights_path_p:
            return False
        if device is not None and runner.device != torch.device(device):
            return False
        if input_num_points is not None and runner.input_num_points != int(input_num_points):
            return False
        return True

    if _RUNNER is None or not _runner_matches(_RUNNER):
        # NOTE(review): modules already imported from a previous predator_root stay
        # in sys.modules, so switching roots within one process may reuse them.
        _RUNNER = _init_runner(
            predator_root_p,
            config_path_p,
            weights_path_p,
            device=device,
            input_num_points=input_num_points,
        )

    # Only importable after _init_runner has put predator_root on sys.path.
    from datasets.my_dataloader import collate_fn_descriptor

    src_xyz = np.asarray(source.points, dtype=np.float32)
    tgt_xyz = np.asarray(target.points, dtype=np.float32)
    src_xyz = _maybe_downsample_xyz(src_xyz, _RUNNER.input_num_points)
    tgt_xyz = _maybe_downsample_xyz(tgt_xyz, _RUNNER.input_num_points)

    # Constant features plus placeholder pose/correspondence entries fill the
    # tuple layout that collate_fn_descriptor expects; only geometry matters here.
    src_feats = np.ones((src_xyz.shape[0], 1), dtype=np.float32)
    tgt_feats = np.ones((tgt_xyz.shape[0], 1), dtype=np.float32)
    rot = np.eye(3, dtype=np.float32)
    trans = np.zeros((3, 1), dtype=np.float32)
    matching_inds = torch.ones(1, 2).long()
    sample = torch.ones(1)
    gt = np.asarray(gt_transformation, dtype=np.float32) if gt_transformation is not None else np.eye(4, dtype=np.float32)

    batch = collate_fn_descriptor(
        [(src_xyz, tgt_xyz, src_feats, tgt_feats, rot, trans, matching_inds, src_xyz, tgt_xyz, sample, gt)],
        config=_RUNNER.config,
        neighborhood_limits=_RUNNER.neighborhood_limits,
    )

    for k, v in list(batch.items()):
        if isinstance(v, list):
            batch[k] = [t.to(_RUNNER.device) for t in v]
        elif torch.is_tensor(v):
            batch[k] = v.to(_RUNNER.device)

    start = time.time()
    with torch.no_grad():
        feats, scores_overlap, scores_saliency = _RUNNER.model(batch)
    feats = feats.detach().cpu()
    scores_overlap = scores_overlap.detach().cpu()
    scores_saliency = scores_saliency.detach().cpu()

    # The first-level points are stacked source-then-target; split by the source length.
    pcd = batch["points"][0].detach().cpu()
    len_src = int(batch["stack_lengths"][0][0].detach().cpu().item())
    src_pcd = pcd[:len_src]
    tgt_pcd = pcd[len_src:]

    src_desc = feats[:len_src].numpy()
    tgt_desc = feats[len_src:].numpy()
    src_scores = (scores_overlap[:len_src] * scores_saliency[:len_src]).numpy().flatten()
    tgt_scores = (scores_overlap[len_src:] * scores_saliency[len_src:]).numpy().flatten()

    def _sample_idx(scores: np.ndarray, n: int) -> np.ndarray:
        """Pick up to ``n`` point indices according to ``sampling``."""
        n_all = scores.shape[0]
        if n_all <= n:
            return np.arange(n_all)
        if sampling == "topk":
            return np.argsort(-scores)[:n]
        if sampling == "random":
            return np.random.permutation(n_all)[:n]
        # "prob" (and any unrecognised value): sample without replacement with
        # probability proportional to the overlap*saliency score.
        weights = np.asarray(scores, dtype=np.float64)
        total = float(weights.sum())
        # np.random.choice(replace=False, p=...) raises unless the weights are
        # non-negative with a finite positive total and at least n non-zero
        # entries; fall back to uniform sampling otherwise.
        usable = (
            np.isfinite(total)
            and total > 0.0
            and bool((weights >= 0.0).all())
            and np.count_nonzero(weights) >= n
        )
        if not usable:
            return np.random.permutation(n_all)[:n]
        return np.random.choice(n_all, size=n, replace=False, p=weights / total)

    src_idx = _sample_idx(src_scores, ransac_n_points)
    tgt_idx = _sample_idx(tgt_scores, ransac_n_points)

    tsfm = _ransac_pose_estimation(
        src_pcd[src_idx].numpy(),
        tgt_pcd[tgt_idx].numpy(),
        src_desc[src_idx],
        tgt_desc[tgt_idx],
        distance_threshold=ransac_distance_threshold,
        ransac_n=ransac_n,
        mutual=mutual,
    )
    end = time.time()

    # Evaluate on the full-resolution source, not the downsampled network input.
    pc_result = copy.deepcopy(source).transform(tsfm)
    eval_results = metrics.all_evaluations(
        source,
        target,
        pc_result,
        end - start,
        gt_transformation=gt_transformation,
        est_transformation=tsfm,
        corres=None,
    )

    return pc_result, eval_results
|
|
|
|