initial commit
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +32 -0
- .gitignore +12 -0
- README.md +251 -0
- assets/r3pmnet_overview.png +3 -0
- assets/sioux_cranfield.png +3 -0
- assets/sioux_scans.png +3 -0
- assets/success_cases.png +3 -0
- assets/teaser.png +3 -0
- config/default.yaml +24 -0
- config/eval.yaml +28 -0
- dataloader/README.md +71 -0
- dataloader/__init__.py +1 -0
- dataloader/args.txt +14 -0
- dataloader/data_dict_generator.py +119 -0
- dataloader/dataset_generator.py +258 -0
- dataloader/user_data.py +127 -0
- environment.yml +16 -0
- pyproject.toml +25 -0
- r3pm_net/__init__.py +5 -0
- r3pm_net/config_loader.py +164 -0
- r3pm_net/feature_extractor.py +8 -0
- r3pm_net/model.py +382 -0
- r3pm_net/paths.py +11 -0
- scripts/eval_modelnet40.py +335 -0
- scripts/eval_sioux_cranfield.py +302 -0
- scripts/eval_sioux_scans.py +341 -0
- scripts/modelnet40.sh +45 -0
- scripts/sioux_cranfield.sh +46 -0
- scripts/sioux_scans.sh +45 -0
- src/train.py +366 -0
- thirdparty/__init__.py +1 -0
- thirdparty/learning3d/data_utils/__init__.py +4 -0
- thirdparty/learning3d/data_utils/dataloaders.py +454 -0
- thirdparty/learning3d/data_utils/user_data.py +119 -0
- thirdparty/learning3d/examples/test_curvenet.py +118 -0
- thirdparty/learning3d/examples/test_dcp.py +139 -0
- thirdparty/learning3d/examples/test_deepgmr.py +144 -0
- thirdparty/learning3d/examples/test_flownet.py +113 -0
- thirdparty/learning3d/examples/test_masknet.py +159 -0
- thirdparty/learning3d/examples/test_masknet2.py +162 -0
- thirdparty/learning3d/examples/test_pcn.py +118 -0
- thirdparty/learning3d/examples/test_pcrnet.py +120 -0
- thirdparty/learning3d/examples/test_pnlk.py +121 -0
- thirdparty/learning3d/examples/test_pointconv.py +126 -0
- thirdparty/learning3d/examples/test_pointnet.py +121 -0
- thirdparty/learning3d/examples/test_prnet.py +126 -0
- thirdparty/learning3d/examples/test_rpmnet.py +120 -0
- thirdparty/learning3d/examples/train_PointNetLK.py +240 -0
- thirdparty/learning3d/examples/train_dcp.py +249 -0
- thirdparty/learning3d/examples/train_deepgmr.py +244 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,35 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/r3pmnet_overview.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/sioux_cranfield.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/sioux_scans.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/success_cases.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
thirdparty/learning3d/pretrained/exp_classifier/models/best_model_snap.t7 filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
thirdparty/learning3d/pretrained/exp_classifier/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
thirdparty/learning3d/pretrained/exp_classifier/models/best_ptnet_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
thirdparty/learning3d/pretrained/exp_curvenet/models/model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
thirdparty/learning3d/pretrained/exp_dcp/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
thirdparty/learning3d/pretrained/exp_flownet/models/model.best.t7 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
thirdparty/learning3d/pretrained/exp_ipcrnet/models/best_model_v1.t7 filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
thirdparty/learning3d/pretrained/exp_ipcrnet/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
thirdparty/learning3d/pretrained/exp_ipcrnet/models/best_ptnet_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
thirdparty/learning3d/pretrained/exp_masknet/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.01.t7 filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.6.t7 filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.7.t7 filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.8.t7 filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.9.t7 filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_100.t7 filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_200.t7 filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_300.t7 filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_400.t7 filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_500.t7 filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
thirdparty/learning3d/pretrained/exp_pcn/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
thirdparty/learning3d/pretrained/exp_pnlk/models/best_model_snap.t7 filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
thirdparty/learning3d/pretrained/exp_pnlk/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
thirdparty/learning3d/pretrained/exp_pnlk/models/best_ptnet_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
thirdparty/learning3d/pretrained/exp_prnet/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
thirdparty/learning3d/pretrained/exp_prnet/models/model.99.t7 filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
thirdparty/learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*:*Zone.Identifier
|
| 3 |
+
.ipynb_checkpoints/
|
| 4 |
+
*.ipynb
|
| 5 |
+
|
| 6 |
+
checkpoints/
|
| 7 |
+
data/
|
| 8 |
+
results/
|
| 9 |
+
registration_plys/
|
| 10 |
+
logs/
|
| 11 |
+
notebooks/
|
| 12 |
+
kernels/
|
README.md
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- # R3PM-Net
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
This repository contains the official implementation of the paper:
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<strong><a href="https://arxiv.org/abs/2604.05060">R3PM-Net: Real-time, Robust, Real-world Point Matching Network</a></strong><br>
|
| 9 |
+
<strong>(AI4RWC@CVPRW 2026 - Oral Presentation)</strong>
|
| 10 |
+
</p> -->
|
| 11 |
+
|
| 12 |
+
<p align="center">
|
| 13 |
+
|
| 14 |
+
<h1 align="center">R3PM-Net: Real-time, Robust, Real-world Point Matching Network</h1>
|
| 15 |
+
<p align="center"> <strong>AI4RWC@CVPRW 2026 - Oral Presentation</strong></p>
|
| 16 |
+
<h3 align="center"><a href="https://arxiv.org/abs/2604.05060">Paper</a> | <a href="https://yasiikb.github.io/R3PM-Net/">Project Page</a> | <a href="https://huggingface.co/datasets/YasiiKB/R3PM-Net">Dataset</a></h3>
|
| 17 |
+
<div align="center"></div>
|
| 18 |
+
</p>
|
| 19 |
+
<p align="center"> <img src="assets/r3pmnet_overview.png" width="95%"> </p>
|
| 20 |
+
<p align="left"><i>Figure 1. Overview of the R3PM-Net Architecture. R3PM-Net employs a global-aware feature extraction module with shared weights to learn geometric similarities across a full receptive field.</i></p>
|
| 21 |
+
|
| 22 |
+
## Introduction
|
| 23 |
+
|
| 24 |
+
R3PM-Net is a lightweight, global-aware, object-level point matching network designed to bridge the gap between approaches trained and evaluated on clean, dense, synthetic and real-world industrial point cloud data by prioritizing both generalizability and real-time efficiency.
|
| 25 |
+
|
| 26 |
+
<p align="center"> <img src="assets/teaser.png" width="40%"> </p>
|
| 27 |
+
<p align="left"><i>Figure 2. Examples of R3PM-Net performance on the Sioux-Cranfield dataset.</i></p>
|
| 28 |
+
|
| 29 |
+
### Datasets
|
| 30 |
+
|
| 31 |
+
We propose two datasets; **Sioux-Cranfield** and **Sioux-Scans**, to address the gap between synthetic datasets and real-world industrial data.
|
| 32 |
+
|
| 33 |
+
<p align="center">
|
| 34 |
+
<table>
|
| 35 |
+
<tr>
|
| 36 |
+
<td align="center">
|
| 37 |
+
<img src="assets/sioux_cranfield.png" height="250">
|
| 38 |
+
<br>
|
| 39 |
+
<sub><b>Sioux-Cranfield</b></sub>
|
| 40 |
+
</td>
|
| 41 |
+
<td align="center">
|
| 42 |
+
<img src="assets/sioux_scans.png" height="250">
|
| 43 |
+
<br>
|
| 44 |
+
<sub><b>Sioux-Scans</b></sub>
|
| 45 |
+
</td>
|
| 46 |
+
</tr>
|
| 47 |
+
</table>
|
| 48 |
+
</p>
|
| 49 |
+
<p align="left"><i>Figure 3. CAD models of the Sioux-Cranfield dataset (Left). The first six belong to the Cranfield Assembly benchmark and the rest are contributions of this paper (Sioux dataset). Sioux-Scans point cloud data (Right). Target (blue) and Source (yellow) point clouds for seven distinct objects.</i></p>
|
| 50 |
+
|
| 51 |
+
## Environment Setup
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
# 1. Create environment
|
| 55 |
+
conda env create -f environment.yml
|
| 56 |
+
conda activate r3pm_net
|
| 57 |
+
|
| 58 |
+
# Optionally, install the dependencies and run manually:
|
| 59 |
+
pip install -e .
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
To run the evaluations, please refer to each method's repo to set up the environment:
|
| 63 |
+
[Predator](https://github.com/prs-eth/OverlapPredator),
|
| 64 |
+
[GeoTransformer](https://github.com/qinzheng93/geotransformer),
|
| 65 |
+
[LoGDesc](https://github.com/karim416/LoGDesc), and
|
| 66 |
+
[RegTR](https://github.com/yewzijian/regtr).
|
| 67 |
+
|
| 68 |
+
Everything must be installed into the **same** conda enviromnet.
|
| 69 |
+
|
| 70 |
+
## Data Preparation
|
| 71 |
+
|
| 72 |
+
### ModelNet40
|
| 73 |
+
|
| 74 |
+
Download the dataset from [ModelNet40](http://modelnet.cs.princeton.edu/ModelNet40.zip) and extract it to:
|
| 75 |
+
|
| 76 |
+
```
|
| 77 |
+
data/ModelNet40
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
To save time, download the downsampled ModelNet40 test set from [ModelNet40_Downsampled](https://huggingface.co/datasets/YasiiKB/R3PM-Net/blob/main/down_sampled_modelnet40.zip) and put it in:
|
| 81 |
+
|
| 82 |
+
```
|
| 83 |
+
data/down_sampled_modelnet40
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### Sioux-Cranfield
|
| 87 |
+
|
| 88 |
+
Download the dataset from [Sioux_Cranfiled](https://huggingface.co/datasets/YasiiKB/R3PM-Net/blob/main/sioux_cranfield.zip) and put it in:
|
| 89 |
+
|
| 90 |
+
```
|
| 91 |
+
data/sioux_cranfield
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
### Sioux-Scans
|
| 95 |
+
|
| 96 |
+
Download the dataset from [Sioux_Scans](https://huggingface.co/datasets/YasiiKB/R3PM-Net/blob/main/sioux_scans.zip) and put it in:
|
| 97 |
+
|
| 98 |
+
```
|
| 99 |
+
data/sioux_scans
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
### Fine-tune
|
| 103 |
+
|
| 104 |
+
Download the pickle files (.pkl) from [here](https://huggingface.co/datasets/YasiiKB/R3PM-Net/blob/main/simulators.zip) and put them in:
|
| 105 |
+
|
| 106 |
+
```
|
| 107 |
+
data/simulators
|
| 108 |
+
```
|
| 109 |
+
These pickle files are created from a subset of the Sioux-Cranfield containing the "teeth", "cube", "lime" and "lego" CAD models. There are 320 point cloud pairs, with 80-20 train-test split.
|
| 110 |
+
|
| 111 |
+
Optionally, to create your own datasets, use the scripts in `dataloader`, refering to the README file in that directory.
|
| 112 |
+
|
| 113 |
+
## Pre-trained Models
|
| 114 |
+
|
| 115 |
+
Please download the pretrained model of each method from their repo (links provided above) and follow their instructions as to where to put them.
|
| 116 |
+
|
| 117 |
+
We use RPMNet's pre-trained model (*clean-trained*) for our Zero-shot version. Download it from [here](https://github.com/vinits5/learning3d/tree/master/pretrained/exp_rpmnet/models) and put it in:
|
| 118 |
+
|
| 119 |
+
```
|
| 120 |
+
checkpoints/
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
*Note:* You need to fine-tune the model yourself (see bleow) to get the fine-tuned weights which then you can put in the same directory.
|
| 124 |
+
|
| 125 |
+
## Folder Structure
|
| 126 |
+
|
| 127 |
+
```text
|
| 128 |
+
r3pm_net/
|
| 129 |
+
├── assets/
|
| 130 |
+
├── config/
|
| 131 |
+
│ ├── default.yaml # Training defaults
|
| 132 |
+
│ └── eval.yaml # Paths for evaluation scripts
|
| 133 |
+
├── checkpoints/ # Pre-trained models' weights
|
| 134 |
+
├── data/
|
| 135 |
+
│ ├── down_sampled_modelnet40/
|
| 136 |
+
│ ├── ModelNet40/
|
| 137 |
+
│ ├── sioux_cranfield/
|
| 138 |
+
│ └── sioux_scans/
|
| 139 |
+
├── dataloader/ # Dataset dict generation & loaders
|
| 140 |
+
├── logs/ # Experiment logs
|
| 141 |
+
├── r3pm_net/ # Core package (model, feature extractor, config)
|
| 142 |
+
├── scripts/ # SLURM/Bash and evaluation scripts
|
| 143 |
+
│ ├── eval_modelnet40.py
|
| 144 |
+
│ ├── eval_sioux_cranfield.py
|
| 145 |
+
│ ├── eval_sioux_scans.py
|
| 146 |
+
│ ├── modelnet40.sh
|
| 147 |
+
│ ├── sioux_cranfield.sh
|
| 148 |
+
│ └── sioux_scans.sh
|
| 149 |
+
├── src/
|
| 150 |
+
│ └── train.py # Training
|
| 151 |
+
├── thirdparty/learning3d/ # learning3d (RPMNet, losses, ops, …)
|
| 152 |
+
├── tools/ # Registration eval, metrics, visualization
|
| 153 |
+
├── environment.yml
|
| 154 |
+
├── pyproject.toml
|
| 155 |
+
└── README.md
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
## Train
|
| 159 |
+
|
| 160 |
+
To train the model using `data/simulators` or your own dataset run:
|
| 161 |
+
```bash
|
| 162 |
+
python src/train.py
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
## Evaluation
|
| 166 |
+
|
| 167 |
+
Scripts are provided in `scripts/` to reproduce results.
|
| 168 |
+
|
| 169 |
+
**ModelNet40**
|
| 170 |
+
|
| 171 |
+
```bash
|
| 172 |
+
bash scripts/modelnet40.sh
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
**Sioux-Cranfield**
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
bash scripts/sioux_cranfield.sh
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
**Sioux-Scans**
|
| 182 |
+
This evaluates the proposed hybrid Coarse-to-Fine Registration approach.
|
| 183 |
+
|
| 184 |
+
```bash
|
| 185 |
+
bash scripts/sioux_scans.sh
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
### Manual Execution
|
| 189 |
+
|
| 190 |
+
For example for evaluation on `Sioux-Cranfield`, run:
|
| 191 |
+
|
| 192 |
+
```bash
|
| 193 |
+
python scripts/eval_sioux_cranfield.py
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
## Results
|
| 197 |
+
*IMPORTANT NOTE: Unfortunately, we cannot release the feature-extraction model and the fine-tuned weights. Therefore, to re-poduce these results you need to implement the feature extractor (based on the paper) and fine-tune it with the provided data.*
|
| 198 |
+
|
| 199 |
+
### ModelNet40
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
| Method | RRE [°] ↓ | RTE [cm] ↓ | CD [cm] ↓ | Fitness ↑ | In. RMSE [cm] ↓ | Time [s] ↓ |
|
| 203 |
+
| ------------------- | ----------------- | ----------------- | ----------------- | ----------------- | ------------------ | ----------------- |
|
| 204 |
+
| RPMNet | 30.898 | **0.002** | 0.153 | *0.998* | 0.094 | *0.021* |
|
| 205 |
+
| Predator | 7.262 | 0.028 | *0.045* | **1.000** | *0.026* | 0.071 |
|
| 206 |
+
| GeoTransformer | 50.357 | 0.215 | 0.255 | 0.921 | 0.101 | 0.065 |
|
| 207 |
+
| RegTR | **1.712** | *0.007* | **0.017** | **1.000** | **0.009** | 0.045 |
|
| 208 |
+
| LoGDesc | 42.762 | 0.158 | 0.183 | 0.978 | 0.097 | 0.075 |
|
| 209 |
+
| **R3PM-Net (ours)** | *5.198* | 0.010 | 0.052 | **1.000** | 0.029 | **0.007** |
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
> **Notes:** **Best** results are in bold; *Second-best* results are underlined.
|
| 213 |
+
|
| 214 |
+
### Sioux-Cranfield
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
| Method | RRE [°] ↓ | RTE [cm] ↓ | CD [cm] ↓ | Fitness ↑ | In. RMSE [cm] ↓ | Time [s] ↓ |
|
| 218 |
+
| ------------------- | ----------------- | ----------------- | ----------------- | ----------------- | ------------------ | ----------------- |
|
| 219 |
+
| RPMNet | 32.217 | **0.002** | 0.160 | *0.997* | 0.098 | 0.021 |
|
| 220 |
+
| Predator | 16.448 | 0.044 | 0.072 | **1.000** | 0.042 | 0.071 |
|
| 221 |
+
| GeoTrans. | 45.582 | 0.183 | 0.297 | 0.906 | 0.111 | 0.065 |
|
| 222 |
+
| RegTR | **1.311** | *0.004* | **0.023** | **1.000** | **0.012** | 0.045 |
|
| 223 |
+
| LoGDesc | 121.224 | 0.773 | 0.692 | 0.718 | 0.224 | 0.075 |
|
| 224 |
+
| **R3PM-Net (ours)** | *5.451* | 0.006 | *0.054* | **1.000** | *0.030* | **0.006** |
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
### Sioux-Scans
|
| 228 |
+
<p align="center"> <img src="assets/success_cases.png" width="85%"> </p>
|
| 229 |
+
|
| 230 |
+
<p align="left"><i>Figure 4. Qualitative registration results of R3PM-Net on real-world event-camera data. It successfully aligns the "teeth" and "cube" models. The fine-tuned version also solves the "lime" and "house".</i></p>
|
| 231 |
+
|
| 232 |
+
## Acknowledgement
|
| 233 |
+
|
| 234 |
+
We adapted some codes from some awesome repositories including [Learning3D](https://github.com/vinits5/learning3d) and [RPMNet](https://github.com/yewzijian/RPMNet). Thanks for making the codes publicly available.
|
| 235 |
+
|
| 236 |
+
## Citation
|
| 237 |
+
|
| 238 |
+
If you find this repository useful, please consider citing:
|
| 239 |
+
|
| 240 |
+
```bibtex
|
| 241 |
+
@misc{kashefbahrami2026r3pmnetrealtimerobustrealworld,
|
| 242 |
+
title={R3PM-Net: Real-time, Robust, Real-world Point Matching Network},
|
| 243 |
+
author={Yasaman Kashefbahrami and Erkut Akdag and Panagiotis Meletis and Evgeniya Balmashnova and Dip Goswami and Egor Bondarau},
|
| 244 |
+
year={2026},
|
| 245 |
+
eprint={2604.05060},
|
| 246 |
+
archivePrefix={arXiv},
|
| 247 |
+
primaryClass={cs.CV},
|
| 248 |
+
url={https://arxiv.org/abs/2604.05060},
|
| 249 |
+
}
|
| 250 |
+
```
|
| 251 |
+
|
assets/r3pmnet_overview.png
ADDED
|
Git LFS Details
|
assets/sioux_cranfield.png
ADDED
|
Git LFS Details
|
assets/sioux_scans.png
ADDED
|
Git LFS Details
|
assets/success_cases.png
ADDED
|
Git LFS Details
|
assets/teaser.png
ADDED
|
Git LFS Details
|
config/default.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default training and data paths for R3PM-Net
|
| 2 |
+
|
| 3 |
+
exp_name: exp_r3pmnet
|
| 4 |
+
eval: false
|
| 5 |
+
save_dir: ""
|
| 6 |
+
|
| 7 |
+
fine_tune_feature_extractor: tune
|
| 8 |
+
transfer_weights: "" # Optional: leave empty to skip loading
|
| 9 |
+
emb_dims: 1024
|
| 10 |
+
symfn: max
|
| 11 |
+
|
| 12 |
+
seed: 1234
|
| 13 |
+
workers: 4
|
| 14 |
+
batch_size: 5
|
| 15 |
+
epochs: 2
|
| 16 |
+
start_epoch: 0
|
| 17 |
+
optimizer: Adam
|
| 18 |
+
resume: ""
|
| 19 |
+
pretrained: ""
|
| 20 |
+
device: cuda:0
|
| 21 |
+
|
| 22 |
+
# Pickled Registration Dataset dicts (keys: template, source, transformation)
|
| 23 |
+
train_dict_path: data/simulators/data_dict_train.pkl
|
| 24 |
+
test_dict_path: data/simulators/data_dict_test.pkl
|
config/eval.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Paths for scripts/evaluation (loaded by scripts/*.py).
|
| 2 |
+
|
| 3 |
+
data_root: data
|
| 4 |
+
pretrained_rpmnet_dir: checkpoints
|
| 5 |
+
|
| 6 |
+
modelnet40:
|
| 7 |
+
dataset_path: data/ModelNet40
|
| 8 |
+
cache_dir: data/down_sampled_modelnet40
|
| 9 |
+
|
| 10 |
+
sioux:
|
| 11 |
+
base_dir: data
|
| 12 |
+
|
| 13 |
+
methods:
|
| 14 |
+
geotransformer:
|
| 15 |
+
root: GeoTransformer
|
| 16 |
+
exp_subdir: GeoTransformer/experiments/geotransformer.modelnet.rpmnet.stage4.gse.k3.max.oacl.stage2.sinkhorn
|
| 17 |
+
weights_path: GeoTransformer/weights/geotransformer-modelnet.pth.tar
|
| 18 |
+
predator:
|
| 19 |
+
root: OverlapPredator
|
| 20 |
+
config_path: OverlapPredator/configs/test/modelnet.yaml
|
| 21 |
+
weights_path: null
|
| 22 |
+
logdesc:
|
| 23 |
+
root: LoGDesc
|
| 24 |
+
weights_path: LoGDesc/pre-trained/best_model.pth
|
| 25 |
+
regtr:
|
| 26 |
+
root: RegTR
|
| 27 |
+
ckpt_path: RegTR/trained_models/modelnet/ckpt/model-best.pth
|
| 28 |
+
config_path: RegTR/trained_models/modelnet/config.yaml
|
dataloader/README.md
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dataloaders
|
| 2 |
+
|
| 3 |
+
This directory contains scripts to generate Simulated datasets for training and testing.
|
| 4 |
+
It uses functionalities from `tools` folder to:
|
| 5 |
+
|
| 6 |
+
`generate_dataset`
|
| 7 |
+
|
| 8 |
+
- load and downsample data
|
| 9 |
+
- compute normals if needed
|
| 10 |
+
- apply random transformations
|
| 11 |
+
- add augmentations (noise, outliers and occlusion)
|
| 12 |
+
- save point clouds
|
| 13 |
+
|
| 14 |
+
**Note: Two input point clouds must have same length.**
|
| 15 |
+
|
| 16 |
+
`generate_dataset_dict`
|
| 17 |
+
|
| 18 |
+
- save generated dataset in dictionaries suitable to train models (following Learning3d requirments)
|
| 19 |
+
- checks dimensions (to meet Learning3d requirments)
|
| 20 |
+
|
| 21 |
+
`combine_dataset_dict`
|
| 22 |
+
|
| 23 |
+
- shuffle and combine all generated dictionaries into one
|
| 24 |
+
- split train and test sets
|
| 25 |
+
|
| 26 |
+
## How to generate datasets?
|
| 27 |
+
|
| 28 |
+
Modify the `args.txt` file to contain the correct paths and other specifications e.g. downsampling rate, noise level, etc. Other default arguments in `data_dict_generator.py` can also be changed.
|
| 29 |
+
|
| 30 |
+
1. Generate transformed target point clouds + GT transforms
|
| 31 |
+
Change `--action in dataloader/args.txt` to `generate_dataset` and run:
|
| 32 |
+
```
|
| 33 |
+
python dataloader/data_dict_generator.py @dataloader/args.txt
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
2. Generate train/test .pkl dicts
|
| 37 |
+
Change `--action in dataloader/args.txt` to `generate_dataset_dict`, then run the above script again.
|
| 38 |
+
|
| 39 |
+
3. Combine multiple object dicts
|
| 40 |
+
Set `--action combine_dataset_dict` and run again to get train and test `dict.pkl` files.
|
| 41 |
+
|
| 42 |
+
### Manual Run (without args.txt)
|
| 43 |
+
Optinally you can manually run:
|
| 44 |
+
```
|
| 45 |
+
python dataloader/data_dict_generator.py \
|
| 46 |
+
--pcdPath /path/to/source_scan.pcd \
|
| 47 |
+
--cadPath /path/to/object.stl \
|
| 48 |
+
--name teeth \
|
| 49 |
+
--action generate_dataset \
|
| 50 |
+
--every_k_points 100 \
|
| 51 |
+
--num_transformation 50 \
|
| 52 |
+
--angles 0 90 180 \
|
| 53 |
+
--translation_range -1 1 \
|
| 54 |
+
--index 0 \
|
| 55 |
+
--noise_level 0 \
|
| 56 |
+
--outlier_level 0 \
|
| 57 |
+
--outlier_bounds -0.05 0.05 \
|
| 58 |
+
--occ_level 0 \
|
| 59 |
+
--save_path data/simulators
|
| 60 |
+
```
|
| 61 |
+
then:
|
| 62 |
+
```
|
| 63 |
+
python dataloader/data_dict_generator.py \
|
| 64 |
+
--pcdPath /path/to/source_scan.pcd \
|
| 65 |
+
--cadPath /path/to/object.stl \
|
| 66 |
+
--name teeth \
|
| 67 |
+
--action generate_dataset_dict \
|
| 68 |
+
--dataset_size 50 \
|
| 69 |
+
--index 0 \
|
| 70 |
+
--save_path data/simulators
|
| 71 |
+
```
|
dataloader/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Dataset loaders and generators
|
dataloader/args.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--pcdPath data/sioux_scans/teeth_clean.ply
|
| 2 |
+
--cadPath data/sioux_cranfield/teeth.stl
|
| 3 |
+
--action combine_dataset_dict
|
| 4 |
+
--name teeth
|
| 5 |
+
--every_k_points 100
|
| 6 |
+
--num_transformation 50
|
| 7 |
+
--angles 0 90 180
|
| 8 |
+
--translation_range -1 1
|
| 9 |
+
--dataset_size 50
|
| 10 |
+
--index 0
|
| 11 |
+
--noise_level 0
|
| 12 |
+
--outlier_level 0
|
| 13 |
+
--outlier_bounds -0.05 0.05
|
| 14 |
+
--occ_level 0
|
dataloader/data_dict_generator.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import shlex
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 14 |
+
|
| 15 |
+
from tools import data
|
| 16 |
+
from dataloader.dataset_generator import combine_dataset_dict, generate_dataset, generate_dataset_dict
|
| 17 |
+
|
| 18 |
+
# Configure logging
|
| 19 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 20 |
+
|
| 21 |
+
def main():
|
| 22 |
+
# Set up the argument parser
|
| 23 |
+
parser = argparse.ArgumentParser(description='Automate dataset generation and processing.')
|
| 24 |
+
|
| 25 |
+
# Define arguments (change these as needed)
|
| 26 |
+
parser.add_argument('--pcdPath', type=str, required=True, help='Path to the PCD file')
|
| 27 |
+
parser.add_argument('--cadPath', type=str, required=True, help='Path to the CAD file')
|
| 28 |
+
parser.add_argument('--action', type=str, choices=['generate_dataset', 'generate_dataset_dict', 'combine_dataset_dict'], required=True, help='Action to perform')
|
| 29 |
+
parser.add_argument('--compute_normals', action='store_true', help='Flag to compute normals')
|
| 30 |
+
parser.add_argument('--every_k_points', type=int, default=1, help='Sampling rate for points')
|
| 31 |
+
parser.add_argument('--save', action='store_true', help='Flag to save the generated dataset')
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
'--save_path',
|
| 34 |
+
type=str,
|
| 35 |
+
default='data/simulators',
|
| 36 |
+
help='Directory to save generated datasets (relative to repo root if not absolute)',
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument('--name', type=str, required=True, help='Name identifier for the dataset (e.g., teeth, cube, etc.)')
|
| 39 |
+
|
| 40 |
+
# Additional parameters for dataset generation (change these as needed)
|
| 41 |
+
parser.add_argument('--num_transformation', type=int, default=50, help='Number of transformations')
|
| 42 |
+
parser.add_argument('--angles', type=int, nargs='+', default=list(range(0, 360, 10)), help='Rotation angles')
|
| 43 |
+
parser.add_argument('--translation_range', type=float, nargs=2, default=(-1, 1), help='Translation range')
|
| 44 |
+
parser.add_argument('--dataset_size', type=int, default=400, help='Size of the dataset to generate')
|
| 45 |
+
parser.add_argument('--index', type=int, default=0, help='Index for dataset generation')
|
| 46 |
+
parser.add_argument('--noise_level', type=float, default=0, help='Noise level')
|
| 47 |
+
parser.add_argument('--outlier_level', type=float, default=0, help='Outlier level')
|
| 48 |
+
parser.add_argument('--outlier_bounds', type=float, nargs=2, default=(-10, 10), help='Outlier bounds')
|
| 49 |
+
parser.add_argument('--occ_level', type=float, default=0, help='Occlusion level')
|
| 50 |
+
|
| 51 |
+
# Parse the arguments
|
| 52 |
+
|
| 53 |
+
# Check if an argument file is being used
|
| 54 |
+
if sys.argv[1].startswith('@'):
|
| 55 |
+
args_file = sys.argv[1][1:] # Strip the '@' from the filename
|
| 56 |
+
with open(args_file, 'r') as file:
|
| 57 |
+
# Read and split arguments from the file
|
| 58 |
+
args = parser.parse_args(shlex.split(file.read()))
|
| 59 |
+
else:
|
| 60 |
+
args = parser.parse_args()
|
| 61 |
+
|
| 62 |
+
# Print out the arguments to verify
|
| 63 |
+
print(vars(args))
|
| 64 |
+
|
| 65 |
+
# Load the data
|
| 66 |
+
np.random.seed(42)
|
| 67 |
+
if args.compute_normals:
|
| 68 |
+
_, cad, _, cad_normals = data.load_data(args.pcdPath, args.cadPath, every_k_points=args.every_k_points, same_length=True, compute_normals=True)
|
| 69 |
+
suffix = '_with_normals'
|
| 70 |
+
else:
|
| 71 |
+
_, cad = data.load_data(args.pcdPath, args.cadPath, every_k_points=args.every_k_points, same_length=True)
|
| 72 |
+
cad_normals = None
|
| 73 |
+
suffix = ''
|
| 74 |
+
source = copy.deepcopy(cad)
|
| 75 |
+
|
| 76 |
+
rp = Path(args.save_path)
|
| 77 |
+
if not rp.is_absolute():
|
| 78 |
+
rp = _REPO_ROOT / args.save_path
|
| 79 |
+
ROOT_DIR = str(rp.resolve())
|
| 80 |
+
if not ROOT_DIR.endswith(os.sep):
|
| 81 |
+
ROOT_DIR += os.sep
|
| 82 |
+
|
| 83 |
+
# Perform the selected action
|
| 84 |
+
if args.action == 'generate_dataset':
|
| 85 |
+
logging.info('Generating dataset...')
|
| 86 |
+
generate_dataset(source, args.pcdPath, args.cadPath, args.num_transformation, args.angles, args.translation_range, args.index, args.noise_level, args.outlier_level, args.outlier_bounds, args.occ_level, save_dir=ROOT_DIR)
|
| 87 |
+
|
| 88 |
+
elif args.action == 'generate_dataset_dict':
|
| 89 |
+
logging.info('Generating dataset dictionary...')
|
| 90 |
+
output_train_file = f'{ROOT_DIR}data_dict_train_{args.name}{suffix}.pkl'
|
| 91 |
+
output_test_file = f'{ROOT_DIR}data_dict_test_{args.name}{suffix}.pkl'
|
| 92 |
+
generate_dataset_dict(source, args.dataset_size, args.index, output_train_file, output_test_file, cad_normals)
|
| 93 |
+
|
| 94 |
+
elif args.action == 'combine_dataset_dict':
|
| 95 |
+
logging.info('Combining dataset dictionaries...')
|
| 96 |
+
train_files = [
|
| 97 |
+
f'{ROOT_DIR}data_dict_train_teeth{suffix}.pkl'
|
| 98 |
+
# f'{ROOT_DIR}data_dict_train_elephant{suffix}.pkl',
|
| 99 |
+
# f'{ROOT_DIR}data_dict_train_house{suffix}.pkl',
|
| 100 |
+
# f'{ROOT_DIR}data_dict_train_shoe{suffix}.pkl'
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
test_files = [
|
| 104 |
+
f'{ROOT_DIR}data_dict_test_teeth{suffix}.pkl'
|
| 105 |
+
# f'{ROOT_DIR}data_dict_test_elephant{suffix}.pkl',
|
| 106 |
+
# f'{ROOT_DIR}data_dict_test_house{suffix}.pkl',
|
| 107 |
+
# f'{ROOT_DIR}data_dict_test_shoe{suffix}.pkl'
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
output_train_file = f'{ROOT_DIR}data_dict_train_{suffix}.pkl'
|
| 111 |
+
output_test_file = f'{ROOT_DIR}data_dict_test_{suffix}.pkl'
|
| 112 |
+
|
| 113 |
+
combine_dataset_dict(train_files, test_files, output_train_file, output_test_file)
|
| 114 |
+
|
| 115 |
+
else:
|
| 116 |
+
logging.warning('No valid action selected.')
|
| 117 |
+
|
| 118 |
+
if __name__ == '__main__':
|
| 119 |
+
main()
|
dataloader/dataset_generator.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
import random
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import open3d as o3
|
| 10 |
+
|
| 11 |
+
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 14 |
+
|
| 15 |
+
from tools import augmentation, data, transformations
|
| 16 |
+
|
| 17 |
+
_SIM_DATA = _REPO_ROOT / "data" / "simulators"
|
| 18 |
+
'''
|
| 19 |
+
This module provides functions to generate a dataset of point clouds with random transformations, with options for noise, outliers, and occlusions.
|
| 20 |
+
It also includes functions to check the shape of the data and to generate a data dictionary for training and testing,
|
| 21 |
+
and a function to combine multiple dataset dictionaries.
|
| 22 |
+
'''
|
| 23 |
+
|
| 24 |
+
def generate_dataset(pcd, pcdPath, cadPath, num_transformation, angles, translation_range, index, noise_level = 0, outlier_level = 0, outlier_bounds = (-10, 10), occ_level = 0, save_dir=None):
|
| 25 |
+
'''
|
| 26 |
+
A function to generate a dataset of point clouds with random transformations.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
pcd (open3d.geometry.PointCloud): The source point cloud
|
| 30 |
+
pcdPath (str): The path to the source point cloud
|
| 31 |
+
cadPath (str): The path to the target point cloud
|
| 32 |
+
num_transformation (int): The number of transformations to generate
|
| 33 |
+
angles (numpy.ndarray): The range of angles for the random transformations
|
| 34 |
+
translation_range (tuple): The range of translations for the random transformations
|
| 35 |
+
index (int): The index to start saving the generated dataset
|
| 36 |
+
noise_level (float): The level of noise to add to the point clouds
|
| 37 |
+
outlier_level (float): The level of outliers to add to the point clouds
|
| 38 |
+
occ_level (float): The level of occlusions to add to the point clouds
|
| 39 |
+
save (bool): A flag to save the generated dataset
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
None
|
| 43 |
+
'''
|
| 44 |
+
np.random.seed(42)
|
| 45 |
+
target_list = []
|
| 46 |
+
gt_transformation_list = []
|
| 47 |
+
|
| 48 |
+
for i in range(num_transformation):
|
| 49 |
+
# Generate random gt transformation
|
| 50 |
+
x_angle= np.random.uniform(angles[0], angles[-1], size=1)
|
| 51 |
+
y_angle= np.random.uniform(angles[0], angles[-1], size=1)
|
| 52 |
+
z_angle= np.random.uniform(angles[0], angles[-1], size=1)
|
| 53 |
+
gt_transformation = transformations.create_transformation(x_angle, y_angle, z_angle, translation_range)
|
| 54 |
+
|
| 55 |
+
target = copy.deepcopy(pcd)
|
| 56 |
+
target.transform(gt_transformation)
|
| 57 |
+
|
| 58 |
+
if noise_level != 0:
|
| 59 |
+
target = augmentation.apply_noise(target, noise_level)
|
| 60 |
+
print('Noise applied')
|
| 61 |
+
|
| 62 |
+
if outlier_level != 0 or occ_level != 0:
|
| 63 |
+
_, another_cad = data.load_data(pcdPath, cadPath, every_k_points=1)
|
| 64 |
+
target = copy.deepcopy(another_cad).transform(gt_transformation)
|
| 65 |
+
if occ_level != 0:
|
| 66 |
+
target, _ = augmentation.apply_occlusion(target, occ_level)
|
| 67 |
+
print('Occlusion applied')
|
| 68 |
+
if outlier_level != 0:
|
| 69 |
+
target = augmentation.add_outliers(target, outlier_level, outlier_lowerbound=outlier_bounds[0], outlier_upperbound=outlier_bounds[1])
|
| 70 |
+
print('Outliers applied')
|
| 71 |
+
|
| 72 |
+
# randomly take points away from target to get to same length as source
|
| 73 |
+
if len(target.points) >= len(pcd.points):
|
| 74 |
+
np.random.seed(42)
|
| 75 |
+
target_points = np.asarray(target.points)
|
| 76 |
+
indices = np.random.choice(len(target_points), 1441, replace=False) # change len(source.points) to a specific num if you want to have a fixed number of points
|
| 77 |
+
sampled_points = target_points[indices]
|
| 78 |
+
target.points = o3.utility.Vector3dVector(sampled_points)
|
| 79 |
+
else:
|
| 80 |
+
print('Target has fewer points than source and can\'t be downsampled to the same length.')
|
| 81 |
+
|
| 82 |
+
print(f'size of source and target: {len(pcd.points)}, {len(target.points)}')
|
| 83 |
+
target_list.append(target)
|
| 84 |
+
gt_transformation_list.append(gt_transformation)
|
| 85 |
+
|
| 86 |
+
# Save the generated dataset
|
| 87 |
+
if save_dir is not None:
|
| 88 |
+
if not os.path.exists(save_dir):
|
| 89 |
+
os.makedirs(save_dir)
|
| 90 |
+
|
| 91 |
+
for i, (target, transformation) in enumerate(zip(target_list, gt_transformation_list)):
|
| 92 |
+
target_path = os.path.join(save_dir, f"target_{i+index}.pcd")
|
| 93 |
+
transformation_path = os.path.join(save_dir, f"transformation_{i+index}.npy")
|
| 94 |
+
o3.io.write_point_cloud(target_path, target)
|
| 95 |
+
np.save(transformation_path, transformation)
|
| 96 |
+
|
| 97 |
+
def check_shape(data, expected_shape_3d, expected_shape_6d):
|
| 98 |
+
return data.shape == expected_shape_3d or data.shape == expected_shape_6d
|
| 99 |
+
|
| 100 |
+
def generate_dataset_dict(source, dataset_size, index, output_train_file_path, output_test_file_path, source_normals = None):
|
| 101 |
+
'''
|
| 102 |
+
This function shuffles the dataset and generates a data_dict for the training and testing data following the pattern acceptable to Learning3D.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
source (open3d.geometry.PointCloud): The source point cloud
|
| 106 |
+
dataset_size (int): The size of the dataset
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
None
|
| 110 |
+
'''
|
| 111 |
+
np.random.seed(42)
|
| 112 |
+
transformed_pcds = []
|
| 113 |
+
gt_transformations = []
|
| 114 |
+
|
| 115 |
+
# Load the transformed point clouds and ground truth transformations
|
| 116 |
+
for i in range(index,index+dataset_size):
|
| 117 |
+
transformed_pcd = o3.io.read_point_cloud(str(_SIM_DATA / f"target_{i}.pcd"))
|
| 118 |
+
gt_transformation = np.load(str(_SIM_DATA / f"transformation_{i}.npy"))
|
| 119 |
+
|
| 120 |
+
if source_normals is not None: # we also need target normals
|
| 121 |
+
M = np.linalg.inv(gt_transformation).T
|
| 122 |
+
target_normals = np.dot(source_normals, M[:3,:3]) # transformed_normals = normals * (transformation)^-1.T
|
| 123 |
+
transformed_points = np.concatenate((np.asarray(transformed_pcd.points), target_normals), axis=1)
|
| 124 |
+
else:
|
| 125 |
+
transformed_points = np.asarray(transformed_pcd.points).astype(np.float32)
|
| 126 |
+
|
| 127 |
+
transformed_pcds.append(transformed_points)
|
| 128 |
+
gt_transformations.append(gt_transformation)
|
| 129 |
+
|
| 130 |
+
# Shuffle the transformed point clouds and ground truth transformations in the same way
|
| 131 |
+
temp = list(zip(transformed_pcds, gt_transformations))
|
| 132 |
+
random.shuffle(temp)
|
| 133 |
+
transformed_pcds, gt_transformations = zip(*temp)
|
| 134 |
+
|
| 135 |
+
# Convert lists to numpy arrays
|
| 136 |
+
transformed_pcds_np = np.array(transformed_pcds)
|
| 137 |
+
gt_transformations_np = np.array(gt_transformations)
|
| 138 |
+
|
| 139 |
+
if source_normals is not None:
|
| 140 |
+
source = np.concatenate((np.asarray(source.points), source_normals), axis=1)
|
| 141 |
+
else:
|
| 142 |
+
source = np.asarray(source.points).astype(np.float32)
|
| 143 |
+
|
| 144 |
+
data_dict = {
|
| 145 |
+
'template': np.tile(source, (dataset_size, 1, 1)),
|
| 146 |
+
'source': transformed_pcds_np,
|
| 147 |
+
'transformation': gt_transformations_np
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
# Split the data_dict into training and testing data_dict
|
| 151 |
+
train_size = int(0.8 * dataset_size)
|
| 152 |
+
test_size = dataset_size - train_size
|
| 153 |
+
num_points = len(source)
|
| 154 |
+
|
| 155 |
+
data_dict_train = {}
|
| 156 |
+
data_dict_test = {}
|
| 157 |
+
for key in data_dict.keys():
|
| 158 |
+
data_dict_train[key] = data_dict[key][0:train_size]
|
| 159 |
+
data_dict_test[key] = data_dict[key][train_size:]
|
| 160 |
+
|
| 161 |
+
assert set(data_dict_train.keys()) == {'template', 'source', 'transformation'}
|
| 162 |
+
assert set(data_dict_test.keys()) == {'template', 'source', 'transformation'}
|
| 163 |
+
|
| 164 |
+
expected_shape_3d_train = (train_size, num_points, 3)
|
| 165 |
+
expected_shape_6d_train = (train_size, num_points, 6)
|
| 166 |
+
|
| 167 |
+
assert check_shape(data_dict_train['template'], expected_shape_3d_train, expected_shape_6d_train), f"Expected shape: {expected_shape_3d_train} or {expected_shape_6d_train}, but got {data_dict_train['template'].shape}"
|
| 168 |
+
assert check_shape(data_dict_train['source'], expected_shape_3d_train, expected_shape_6d_train), f"Expected shape: {expected_shape_3d_train} or {expected_shape_6d_train}, but got {data_dict_train['source'].shape}"
|
| 169 |
+
assert data_dict_train['transformation'].shape == (train_size, 4, 4), f"Expected shape: {(train_size, 4, 4)}, but got {data_dict_train['transformation'].shape}"
|
| 170 |
+
|
| 171 |
+
expected_shape_3d_test = (test_size, num_points, 3)
|
| 172 |
+
expected_shape_6d_test = (test_size, num_points, 6)
|
| 173 |
+
|
| 174 |
+
assert check_shape(data_dict_test['template'], expected_shape_3d_test, expected_shape_6d_test), f"Expected shape: {expected_shape_3d_test} or {expected_shape_6d_test}, but got {data_dict_test['template'].shape}"
|
| 175 |
+
assert check_shape(data_dict_test['source'], expected_shape_3d_test, expected_shape_6d_test), f"Expected shape: {expected_shape_3d_test} or {expected_shape_6d_test}, but got {data_dict_test['source'].shape}"
|
| 176 |
+
assert data_dict_test['transformation'].shape == (test_size, 4, 4), f"Expected shape: {(test_size, 4, 4)}, but got {data_dict_test['transformation'].shape}"
|
| 177 |
+
|
| 178 |
+
with open(output_train_file_path, 'wb') as f:
|
| 179 |
+
pickle.dump(data_dict_train, f)
|
| 180 |
+
print(f"train_dict saved to {output_train_file_path}")
|
| 181 |
+
|
| 182 |
+
with open(output_test_file_path, 'wb') as f:
|
| 183 |
+
pickle.dump(data_dict_test, f)
|
| 184 |
+
print(f"test_dict saved to {output_test_file_path}")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def combine_dataset_dict(train_files, test_files, output_train_file_path, output_test_file_path):
|
| 188 |
+
'''
|
| 189 |
+
Combine and shuffle dictionaries from multiple files.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
train_files (list of str): List of file paths to training dictionaries.
|
| 193 |
+
test_files (list of str): List of file paths to testing dictionaries.
|
| 194 |
+
output_train_file (str): Output file path for the combined training dictionary.
|
| 195 |
+
output_test_file (str): Output file path for the combined testing dictionary.
|
| 196 |
+
'''
|
| 197 |
+
|
| 198 |
+
# Load the dictionaries from the .pkl files
|
| 199 |
+
train_dicts = [pickle.load(open(file, 'rb')) for file in train_files]
|
| 200 |
+
test_dicts = [pickle.load(open(file, 'rb')) for file in test_files]
|
| 201 |
+
|
| 202 |
+
# Combine the dictionaries
|
| 203 |
+
combined_train_dict = {}
|
| 204 |
+
combined_test_dict = {}
|
| 205 |
+
|
| 206 |
+
for key in train_dicts[0].keys():
|
| 207 |
+
combined_train_dict[key] = np.concatenate([d[key] for d in train_dicts], axis=0)
|
| 208 |
+
combined_test_dict[key] = np.concatenate([d[key] for d in test_dicts], axis=0)
|
| 209 |
+
|
| 210 |
+
# Shuffle
|
| 211 |
+
train_combined_list = list(zip(combined_train_dict['template'], combined_train_dict['source'], combined_train_dict['transformation']))
|
| 212 |
+
test_combined_list = list(zip(combined_test_dict['template'], combined_test_dict['source'], combined_test_dict['transformation']))
|
| 213 |
+
|
| 214 |
+
random.shuffle(train_combined_list)
|
| 215 |
+
random.shuffle(test_combined_list)
|
| 216 |
+
|
| 217 |
+
combined_train_dict['template'], combined_train_dict['source'], combined_train_dict['transformation'] = zip(*train_combined_list)
|
| 218 |
+
combined_test_dict['template'], combined_test_dict['source'], combined_test_dict['transformation'] = zip(*test_combined_list)
|
| 219 |
+
|
| 220 |
+
# Convert back to numpy arrays
|
| 221 |
+
combined_train_dict['template'] = np.array(combined_train_dict['template'])
|
| 222 |
+
combined_train_dict['source'] = np.array(combined_train_dict['source'])
|
| 223 |
+
combined_train_dict['transformation'] = np.array(combined_train_dict['transformation'])
|
| 224 |
+
|
| 225 |
+
combined_test_dict['template'] = np.array(combined_test_dict['template'])
|
| 226 |
+
combined_test_dict['source'] = np.array(combined_test_dict['source'])
|
| 227 |
+
combined_test_dict['transformation'] = np.array(combined_test_dict['transformation'])
|
| 228 |
+
|
| 229 |
+
# Checks
|
| 230 |
+
train_size = len(combined_train_dict['source'])
|
| 231 |
+
test_size = len(combined_test_dict['source'])
|
| 232 |
+
num_points = combined_train_dict['source'].shape[1]
|
| 233 |
+
|
| 234 |
+
assert set(combined_train_dict.keys()) == {'template', 'source', 'transformation'}
|
| 235 |
+
assert set(combined_test_dict.keys()) == {'template', 'source', 'transformation'}
|
| 236 |
+
|
| 237 |
+
expected_shape_3d_train = (train_size, num_points, 3)
|
| 238 |
+
expected_shape_6d_train = (train_size, num_points, 6)
|
| 239 |
+
|
| 240 |
+
assert check_shape(combined_train_dict['template'], expected_shape_3d_train, expected_shape_6d_train), f"Expected shape: {expected_shape_3d_train} or {expected_shape_6d_train}, but got {combined_train_dict['template'].shape}"
|
| 241 |
+
assert check_shape(combined_train_dict['source'], expected_shape_3d_train, expected_shape_6d_train), f"Expected shape: {expected_shape_3d_train} or {expected_shape_6d_train}, but got {combined_train_dict['source'].shape}"
|
| 242 |
+
assert combined_train_dict['transformation'].shape == (train_size, 4, 4), f"Expected shape: {(train_size, 4, 4)}, but got {combined_train_dict['transformation'].shape}"
|
| 243 |
+
|
| 244 |
+
expected_shape_3d_test = (test_size, num_points, 3)
|
| 245 |
+
expected_shape_6d_test = (test_size, num_points, 6)
|
| 246 |
+
|
| 247 |
+
assert check_shape(combined_test_dict['template'], expected_shape_3d_test, expected_shape_6d_test), f"Expected shape: {expected_shape_3d_test} or {expected_shape_6d_test}, but got {combined_test_dict['template'].shape}"
|
| 248 |
+
assert check_shape(combined_test_dict['source'], expected_shape_3d_test, expected_shape_6d_test), f"Expected shape: {expected_shape_3d_test} or {expected_shape_6d_test}, but got {combined_test_dict['source'].shape}"
|
| 249 |
+
assert combined_test_dict['transformation'].shape == (test_size, 4, 4), f"Expected shape: {(test_size, 4, 4)}, but got {combined_test_dict['transformation'].shape}"
|
| 250 |
+
|
| 251 |
+
# Save the dictionaries
|
| 252 |
+
with open(output_train_file_path, 'wb') as f:
|
| 253 |
+
pickle.dump(combined_train_dict, f)
|
| 254 |
+
print(f"combined_train_dict saved to {output_train_file_path}")
|
| 255 |
+
|
| 256 |
+
with open(output_test_file_path, 'wb') as f:
|
| 257 |
+
pickle.dump(combined_test_dict, f)
|
| 258 |
+
print(f"combined_test_dict saved to {output_train_file_path}")
|
dataloader/user_data.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class ClassificationData:
|
| 6 |
+
def __init__(self, data_dict):
|
| 7 |
+
self.data_dict = data_dict
|
| 8 |
+
self.pcs = self.find_attribute('pcs')
|
| 9 |
+
self.labels = self.find_attribute('labels')
|
| 10 |
+
self.check_data()
|
| 11 |
+
|
| 12 |
+
def find_attribute(self, attribute):
|
| 13 |
+
try:
|
| 14 |
+
attribute_data = self.data_dict[attribute]
|
| 15 |
+
except:
|
| 16 |
+
print("Given data directory has no key attribute \"{}\"".format(attribute))
|
| 17 |
+
return attribute_data
|
| 18 |
+
|
| 19 |
+
def check_data(self):
|
| 20 |
+
assert 1 < len(self.pcs.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.pcs.shape)
|
| 21 |
+
assert 0 < len(self.labels.shape) < 3, "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape)
|
| 22 |
+
|
| 23 |
+
if len(self.pcs.shape)==2: self.pcs = self.pcs.reshape(1, -1, 3)
|
| 24 |
+
if len(self.labels.shape) == 1: self.labels = self.labels.reshape(1, -1)
|
| 25 |
+
|
| 26 |
+
assert self.pcs.shape[0] == self.labels.shape[0], "Inconsistency in the number of point clouds and number of ground truth labels!"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def __len__(self):
|
| 30 |
+
return self.pcs.shape[0]
|
| 31 |
+
|
| 32 |
+
def __getitem__(self, index):
|
| 33 |
+
return torch.tensor(self.pcs[index]).float(), torch.from_numpy(self.labels[idx]).type(torch.LongTensor)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RegistrationData:
|
| 37 |
+
def __init__(self, data_dict):
|
| 38 |
+
self.data_dict = data_dict
|
| 39 |
+
self.template = self.find_attribute('template')
|
| 40 |
+
self.source = self.find_attribute('source')
|
| 41 |
+
self.transformation = self.find_attribute('transformation')
|
| 42 |
+
self.check_data()
|
| 43 |
+
|
| 44 |
+
# def find_attribute(self, attribute):
|
| 45 |
+
# try:
|
| 46 |
+
# attribute_data = self.data[attribute]
|
| 47 |
+
# except:
|
| 48 |
+
# print("Given data directory has no key attribute \"{}\"".format(attribute))
|
| 49 |
+
# return attribute_data
|
| 50 |
+
|
| 51 |
+
def find_attribute(self, attribute):
|
| 52 |
+
attribute_data = None
|
| 53 |
+
if attribute in self.data_dict:
|
| 54 |
+
attribute_data = self.data_dict[attribute]
|
| 55 |
+
else:
|
| 56 |
+
print("Given data directory has no key attribute \"{}\"".format(attribute))
|
| 57 |
+
return attribute_data
|
| 58 |
+
|
| 59 |
+
def check_data(self):
|
| 60 |
+
assert 1 < len(self.template.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.template.shape)
|
| 61 |
+
assert 1 < len(self.source.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.source.shape)
|
| 62 |
+
assert 1 < len(self.transformation.shape) < 4, "Error in dimension of transformations! Given data dimension: {}".format(self.transformation.shape)
|
| 63 |
+
|
| 64 |
+
if len(self.template.shape)==2: self.template = self.template.reshape(1, -1, 3)
|
| 65 |
+
if len(self.source.shape)==2: self.source = self.source.reshape(1, -1, 3)
|
| 66 |
+
if len(self.transformation.shape) == 2: self.transformation = self.transformation.reshape(1, 4, 4)
|
| 67 |
+
|
| 68 |
+
assert self.template.shape[0] == self.source.shape[0], "Inconsistency in the number of template and source point clouds!"
|
| 69 |
+
assert self.source.shape[0] == self.transformation.shape[0], "Inconsistency in the number of transformation and source point clouds!"
|
| 70 |
+
|
| 71 |
+
def __len__(self):
|
| 72 |
+
return self.template.shape[0]
|
| 73 |
+
|
| 74 |
+
def __getitem__(self, index):
|
| 75 |
+
return torch.tensor(self.template[index]).float(), torch.tensor(self.source[index]).float(), torch.tensor(self.transformation[index]).float()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class FlowData:
|
| 79 |
+
def __init__(self, data_dict):
|
| 80 |
+
self.data_dict = data_dict
|
| 81 |
+
self.frame1 = self.find_attribute('frame1')
|
| 82 |
+
self.frame2 = self.find_attribute('frame2')
|
| 83 |
+
self.flow = self.find_attribute('flow')
|
| 84 |
+
self.check_data()
|
| 85 |
+
|
| 86 |
+
def find_attribute(self, attribute):
|
| 87 |
+
try:
|
| 88 |
+
attribute_data = self.data[attribute]
|
| 89 |
+
except:
|
| 90 |
+
print("Given data directory has no key attribute \"{}\"".format(attribute))
|
| 91 |
+
return attribute_data
|
| 92 |
+
|
| 93 |
+
def check_data(self):
|
| 94 |
+
assert 1 < len(self.frame1.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame1.shape)
|
| 95 |
+
assert 1 < len(self.frame2.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame2.shape)
|
| 96 |
+
assert 1 < len(self.flow.shape) < 4, "Error in dimension of flow! Given data dimension: {}".format(self.flow.shape)
|
| 97 |
+
|
| 98 |
+
if len(self.frame1.shape)==2: self.frame1 = self.frame1.reshape(1, -1, 3)
|
| 99 |
+
if len(self.frame2.shape)==2: self.frame2 = self.frame2.reshape(1, -1, 3)
|
| 100 |
+
if len(self.flow.shape) == 2: self.flow = self.flow.reshape(1, -1, 3)
|
| 101 |
+
|
| 102 |
+
assert self.frame1.shape[0] == self.frame2.shape[0], "Inconsistency in the number of frame1 and frame2 point clouds!"
|
| 103 |
+
assert self.frame2.shape[0] == self.flow.shape[0], "Inconsistency in the number of flow and frame2 point clouds!"
|
| 104 |
+
|
| 105 |
+
def __len__(self):
|
| 106 |
+
return self.frame1.shape[0]
|
| 107 |
+
|
| 108 |
+
def __getitem__(self, index):
|
| 109 |
+
return torch.tensor(self.frame1[index]).float(), torch.tensor(self.frame2[index]).float(), torch.tensor(self.flow[index]).float()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class UserData:
|
| 113 |
+
def __init__(self, application, data_dict):
|
| 114 |
+
self.application = application
|
| 115 |
+
|
| 116 |
+
if self.application == 'classification':
|
| 117 |
+
self.data_class = ClassificationData(data_dict)
|
| 118 |
+
elif self.application == 'registration':
|
| 119 |
+
self.data_class = RegistrationData(data_dict)
|
| 120 |
+
elif self.application == 'flow_estimation':
|
| 121 |
+
self.data_class = FlowData(data_dict)
|
| 122 |
+
|
| 123 |
+
def __len__(self):
|
| 124 |
+
return len(self.data_class)
|
| 125 |
+
|
| 126 |
+
def __getitem__(self, index):
|
| 127 |
+
return self.data_class[index]
|
environment.yml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: r3pm_net
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- pytorch
|
| 5 |
+
dependencies:
|
| 6 |
+
- python=3.9
|
| 7 |
+
- pip
|
| 8 |
+
- open3d
|
| 9 |
+
- pytorch
|
| 10 |
+
- hatchling
|
| 11 |
+
- ipykernel
|
| 12 |
+
- pip:
|
| 13 |
+
- tabulate
|
| 14 |
+
- pyyaml
|
| 15 |
+
- -e .
|
| 16 |
+
|
pyproject.toml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "r3pm_net"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "R3PM-Net point cloud registration"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.9"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"numpy",
|
| 13 |
+
"torch",
|
| 14 |
+
"tensorboardX",
|
| 15 |
+
"tqdm",
|
| 16 |
+
"pyyaml",
|
| 17 |
+
"open3d",
|
| 18 |
+
"tabulate",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
[tool.hatch.build.targets.wheel]
|
| 22 |
+
packages = ["r3pm_net", "tools", "dataloader", "thirdparty"]
|
| 23 |
+
|
| 24 |
+
[tool.jupytext]
|
| 25 |
+
formats = "ipynb,py:light"
|
r3pm_net/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""R3PM-Net: point cloud registration with PointNet features."""
|
| 2 |
+
|
| 3 |
+
from .model import R3PMNet
|
| 4 |
+
|
| 5 |
+
__all__ = ["R3PMNet"]
|
r3pm_net/config_loader.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load YAML training config and merge with argparse."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Mapping
|
| 9 |
+
|
| 10 |
+
import yaml
|
| 11 |
+
|
| 12 |
+
from r3pm_net.paths import REPO_ROOT
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _resolve_maybe_relative(path_str: str | None) -> str | None:
|
| 16 |
+
if path_str is None or path_str == "":
|
| 17 |
+
return path_str
|
| 18 |
+
p = Path(path_str)
|
| 19 |
+
if p.is_absolute():
|
| 20 |
+
return str(p)
|
| 21 |
+
return str(REPO_ROOT / p)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_yaml_config(path: str | Path) -> dict[str, Any]:
|
| 25 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 26 |
+
data = yaml.safe_load(f)
|
| 27 |
+
if data is None:
|
| 28 |
+
return {}
|
| 29 |
+
if not isinstance(data, Mapping):
|
| 30 |
+
raise ValueError(f"Config must be a mapping, got {type(data)}")
|
| 31 |
+
return dict(data)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _extract_config_argv(argv: list[str], default_cfg: str) -> tuple[str, list[str]]:
|
| 35 |
+
"""Return (config path for YAML, argv without --config ...)."""
|
| 36 |
+
path = default_cfg
|
| 37 |
+
out: list[str] = []
|
| 38 |
+
i = 0
|
| 39 |
+
while i < len(argv):
|
| 40 |
+
if argv[i] == "--config" and i + 1 < len(argv):
|
| 41 |
+
path = argv[i + 1]
|
| 42 |
+
i += 2
|
| 43 |
+
continue
|
| 44 |
+
if argv[i].startswith("--config="):
|
| 45 |
+
path = argv[i].split("=", 1)[1]
|
| 46 |
+
i += 1
|
| 47 |
+
continue
|
| 48 |
+
out.append(argv[i])
|
| 49 |
+
i += 1
|
| 50 |
+
return path, out
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def parse_train_args(argv: list[str], build_parser) -> argparse.Namespace:
|
| 54 |
+
"""Load YAML from --config (default: config/default.yaml), merge as argparse defaults, then parse CLI."""
|
| 55 |
+
default_cfg = str(REPO_ROOT / "config" / "default.yaml")
|
| 56 |
+
cfg_path, argv_rest = _extract_config_argv(list(argv), default_cfg)
|
| 57 |
+
cfg = load_yaml_config(cfg_path) if Path(cfg_path).is_file() else {}
|
| 58 |
+
parser = build_parser(cfg_path)
|
| 59 |
+
if cfg:
|
| 60 |
+
known = {
|
| 61 |
+
a.dest
|
| 62 |
+
for a in parser._actions
|
| 63 |
+
if getattr(a, "dest", None) and a.dest not in ("help", argparse.SUPPRESS)
|
| 64 |
+
}
|
| 65 |
+
filtered = {k: v for k, v in cfg.items() if k in known}
|
| 66 |
+
parser.set_defaults(**filtered)
|
| 67 |
+
return parser.parse_args(argv_rest)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def resolve_path_args(ns: Any, path_keys: tuple[str, ...]) -> None:
|
| 71 |
+
"""Mutate namespace: resolve listed keys to absolute paths under REPO_ROOT when relative."""
|
| 72 |
+
for key in path_keys:
|
| 73 |
+
val = getattr(ns, key, None)
|
| 74 |
+
if isinstance(val, str) and val:
|
| 75 |
+
setattr(ns, key, _resolve_maybe_relative(val))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def load_eval_yaml() -> dict[str, Any]:
|
| 79 |
+
"""Load ``config/eval.yaml`` if present; otherwise return an empty dict."""
|
| 80 |
+
path = REPO_ROOT / "config" / "eval.yaml"
|
| 81 |
+
if not path.is_file():
|
| 82 |
+
return {}
|
| 83 |
+
return load_yaml_config(path)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_pretrained_rpmnet_dir() -> str:
|
| 87 |
+
"""Directory containing ``clean-trained.pth``, ``best_model_PointNet*.t7``, etc.
|
| 88 |
+
|
| 89 |
+
``R3PM_NET_PRETRAINED_ROOT`` overrides ``pretrained_rpmnet_dir`` in ``config/eval.yaml``.
|
| 90 |
+
"""
|
| 91 |
+
env = os.environ.get("R3PM_NET_PRETRAINED_ROOT")
|
| 92 |
+
if env:
|
| 93 |
+
return str(Path(env).expanduser().resolve())
|
| 94 |
+
cfg = load_eval_yaml()
|
| 95 |
+
rel = (cfg.get("pretrained_rpmnet_dir") or "checkpoints").strip()
|
| 96 |
+
if not rel:
|
| 97 |
+
rel = "checkpoints"
|
| 98 |
+
out = _resolve_maybe_relative(rel)
|
| 99 |
+
return out if out else str(REPO_ROOT / "checkpoints")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def get_sioux_data_root() -> str:
|
| 103 |
+
"""Base data directory for Sioux scripts (``data`` / ``sioux_cranfield``, etc.)."""
|
| 104 |
+
cfg = load_eval_yaml()
|
| 105 |
+
sioux = cfg.get("sioux") or {}
|
| 106 |
+
base = sioux.get("base_dir") or cfg.get("data_root") or "data"
|
| 107 |
+
out = _resolve_maybe_relative(str(base).strip())
|
| 108 |
+
return out if out else str(REPO_ROOT / "data")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_modelnet40_paths() -> tuple[str, str]:
|
| 112 |
+
"""Return ``(dataset_path, cache_dir)`` for ModelNet40 evaluation."""
|
| 113 |
+
cfg = load_eval_yaml()
|
| 114 |
+
m = cfg.get("modelnet40") or {}
|
| 115 |
+
ds = m.get("dataset_path", "data/ModelNet40")
|
| 116 |
+
cache = m.get("cache_dir", "data/down_sampled_modelnet40")
|
| 117 |
+
dsr = _resolve_maybe_relative(ds)
|
| 118 |
+
cr = _resolve_maybe_relative(cache)
|
| 119 |
+
return (
|
| 120 |
+
dsr if dsr else str(REPO_ROOT / "data" / "ModelNet40"),
|
| 121 |
+
cr if cr else str(REPO_ROOT / "data" / "down_sampled_modelnet40"),
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_method_paths() -> dict[str, Any]:
|
| 126 |
+
"""Return resolved path configuration for external registration methods."""
|
| 127 |
+
cfg = load_eval_yaml()
|
| 128 |
+
methods = cfg.get("methods") or {}
|
| 129 |
+
out: dict[str, Any] = {}
|
| 130 |
+
for method_name, method_cfg in methods.items():
|
| 131 |
+
if not isinstance(method_cfg, Mapping):
|
| 132 |
+
continue
|
| 133 |
+
method_out: dict[str, Any] = {}
|
| 134 |
+
for k, v in method_cfg.items():
|
| 135 |
+
if isinstance(v, str) and v.strip():
|
| 136 |
+
rv = _resolve_maybe_relative(v.strip())
|
| 137 |
+
method_out[k] = rv if rv else v
|
| 138 |
+
else:
|
| 139 |
+
method_out[k] = v
|
| 140 |
+
out[str(method_name)] = method_out
|
| 141 |
+
return out
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_sioux_paths() -> dict[str, Any]:
|
| 145 |
+
"""Return Sioux eval paths from config/eval.yaml with absolute paths."""
|
| 146 |
+
cfg = load_eval_yaml()
|
| 147 |
+
sioux = cfg.get("sioux") or {}
|
| 148 |
+
out: dict[str, Any] = {}
|
| 149 |
+
for k, v in sioux.items():
|
| 150 |
+
if isinstance(v, str) and v.strip():
|
| 151 |
+
rv = _resolve_maybe_relative(v.strip())
|
| 152 |
+
out[k] = rv if rv else v
|
| 153 |
+
elif isinstance(v, list):
|
| 154 |
+
vals = []
|
| 155 |
+
for item in v:
|
| 156 |
+
if isinstance(item, str) and item.strip():
|
| 157 |
+
rv = _resolve_maybe_relative(item.strip())
|
| 158 |
+
vals.append(rv if rv else item)
|
| 159 |
+
else:
|
| 160 |
+
vals.append(item)
|
| 161 |
+
out[k] = vals
|
| 162 |
+
else:
|
| 163 |
+
out[k] = v
|
| 164 |
+
return out
|
r3pm_net/feature_extractor.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Feature extractor for R3PM-Net
|
| 2 |
+
'''
|
| 3 |
+
Unfortunately, the feature extractor cannot be provided in the repository due to copyright issues.
|
| 4 |
+
Please implement the feature extractor for R3PM-Net as described in the paper and place it in this file.
|
| 5 |
+
Currently, the feature extractor is set to PPFNet (same as RPMNet).
|
| 6 |
+
'''
|
| 7 |
+
from thirdparty.learning3d.models import PPFNet
|
| 8 |
+
feature_extractor = PPFNet()
|
r3pm_net/model.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from thirdparty.learning3d.utils import square_distance, angle_difference
|
| 8 |
+
from thirdparty.learning3d.ops.transform_functions import convert2transformation
|
| 9 |
+
|
| 10 |
+
_EPS = 1e-5 # To prevent division by zero
|
| 11 |
+
|
| 12 |
+
class ParameterPredictionNet(nn.Module):
|
| 13 |
+
def __init__(self, weights_dim):
|
| 14 |
+
"""PointNet based Parameter prediction network
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
weights_dim: Number of weights to predict (excluding beta), should be something like
|
| 18 |
+
[3], or [64, 3], for 3 types of features
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
self._logger = logging.getLogger(self.__class__.__name__)
|
| 24 |
+
|
| 25 |
+
self.weights_dim = weights_dim
|
| 26 |
+
|
| 27 |
+
# Pointnet
|
| 28 |
+
self.prepool = nn.Sequential(
|
| 29 |
+
nn.Conv1d(4, 64, 1),
|
| 30 |
+
nn.GroupNorm(8, 64),
|
| 31 |
+
nn.ReLU(),
|
| 32 |
+
|
| 33 |
+
nn.Conv1d(64, 64, 1),
|
| 34 |
+
nn.GroupNorm(8, 64),
|
| 35 |
+
nn.ReLU(),
|
| 36 |
+
|
| 37 |
+
nn.Conv1d(64, 64, 1),
|
| 38 |
+
nn.GroupNorm(8, 64),
|
| 39 |
+
nn.ReLU(),
|
| 40 |
+
|
| 41 |
+
nn.Conv1d(64, 128, 1),
|
| 42 |
+
nn.GroupNorm(8, 128),
|
| 43 |
+
nn.ReLU(),
|
| 44 |
+
|
| 45 |
+
nn.Conv1d(128, 1024, 1),
|
| 46 |
+
nn.GroupNorm(16, 1024),
|
| 47 |
+
nn.ReLU(),
|
| 48 |
+
)
|
| 49 |
+
self.pooling = nn.AdaptiveMaxPool1d(1)
|
| 50 |
+
self.postpool = nn.Sequential(
|
| 51 |
+
nn.Linear(1024, 512),
|
| 52 |
+
nn.GroupNorm(16, 512),
|
| 53 |
+
nn.ReLU(),
|
| 54 |
+
|
| 55 |
+
nn.Linear(512, 256),
|
| 56 |
+
nn.GroupNorm(16, 256),
|
| 57 |
+
nn.ReLU(),
|
| 58 |
+
|
| 59 |
+
nn.Linear(256, 2 + np.prod(weights_dim)),
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
self._logger.info('Predicting weights with dim {}.'.format(self.weights_dim))
|
| 63 |
+
|
| 64 |
+
def forward(self, x):
|
| 65 |
+
""" Returns alpha, beta, and gating_weights (if needed)
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
x: List containing two point clouds, x[0] = src (B, J, 3), x[1] = ref (B, K, 3)
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
beta, alpha, weightings
|
| 72 |
+
"""
|
| 73 |
+
# X and Y concatenated
|
| 74 |
+
src_padded = F.pad(x[0], (0, 1), mode='constant', value=0)
|
| 75 |
+
ref_padded = F.pad(x[1], (0, 1), mode='constant', value=1)
|
| 76 |
+
concatenated = torch.cat([src_padded, ref_padded], dim=1)
|
| 77 |
+
|
| 78 |
+
prepool_feat = self.prepool(concatenated.permute(0, 2, 1))
|
| 79 |
+
pooled = torch.flatten(self.pooling(prepool_feat), start_dim=-2)
|
| 80 |
+
raw_weights = self.postpool(pooled)
|
| 81 |
+
|
| 82 |
+
# softplus to ensure positivity
|
| 83 |
+
beta = F.softplus(raw_weights[:, 0])
|
| 84 |
+
alpha = F.softplus(raw_weights[:, 1])
|
| 85 |
+
|
| 86 |
+
return beta, alpha
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def to_numpy(tensor):
|
| 91 |
+
"""Wrapper around .detach().cpu().numpy() """
|
| 92 |
+
if isinstance(tensor, torch.Tensor):
|
| 93 |
+
return tensor.detach().cpu().numpy()
|
| 94 |
+
elif isinstance(tensor, np.ndarray):
|
| 95 |
+
return tensor
|
| 96 |
+
else:
|
| 97 |
+
raise NotImplementedError
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def se3_transform(g, a, normals=None):
|
| 101 |
+
""" Applies the SE3 transform
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
g: SE3 transformation matrix of size ([1,] 3/4, 4) or (B, 3/4, 4)
|
| 105 |
+
a: Points to be transformed (N, 3) or (B, N, 3)
|
| 106 |
+
normals: (Optional). If provided, normals will be transformed
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
transformed points of size (N, 3) or (B, N, 3)
|
| 110 |
+
|
| 111 |
+
"""
|
| 112 |
+
R = g[..., :3, :3] # (B, 3, 3)
|
| 113 |
+
p = g[..., :3, 3] # (B, 3)
|
| 114 |
+
|
| 115 |
+
if len(g.size()) == len(a.size()):
|
| 116 |
+
b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :]
|
| 117 |
+
else:
|
| 118 |
+
raise NotImplementedError
|
| 119 |
+
b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p # No batch. Not checked
|
| 120 |
+
|
| 121 |
+
if normals is not None:
|
| 122 |
+
rotated_normals = normals @ R.transpose(-1, -2)
|
| 123 |
+
return b, rotated_normals
|
| 124 |
+
|
| 125 |
+
else:
|
| 126 |
+
return b
|
| 127 |
+
|
| 128 |
+
def match_features(feat_src, feat_ref, metric='l2'):
|
| 129 |
+
""" Compute pairwise distance between features
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
feat_src: (B, J, C)
|
| 133 |
+
feat_ref: (B, K, C)
|
| 134 |
+
metric: either 'angle' or 'l2' (squared euclidean)
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
Matching matrix (B, J, K). i'th row describes how well the i'th point
|
| 138 |
+
in the src agrees with every point in the ref.
|
| 139 |
+
"""
|
| 140 |
+
if feat_src.shape[-1] != feat_ref.shape[-1]:
|
| 141 |
+
if feat_src.shape[-1] > feat_ref.shape[-1]:
|
| 142 |
+
feat_src = feat_src[:,:,:feat_ref.shape[-1]]
|
| 143 |
+
elif feat_src.shape[-1] < feat_ref.shape[-1]:
|
| 144 |
+
feat_ref = feat_ref[:,:,:feat_src.shape[-1]]
|
| 145 |
+
|
| 146 |
+
assert feat_src.shape[-1] == feat_ref.shape[-1]
|
| 147 |
+
|
| 148 |
+
if metric == 'l2':
|
| 149 |
+
dist_matrix = square_distance(feat_src, feat_ref)
|
| 150 |
+
elif metric == 'angle':
|
| 151 |
+
feat_src_norm = feat_src / (torch.norm(feat_src, dim=-1, keepdim=True) + _EPS)
|
| 152 |
+
feat_ref_norm = feat_ref / (torch.norm(feat_ref, dim=-1, keepdim=True) + _EPS)
|
| 153 |
+
|
| 154 |
+
dist_matrix = angle_difference(feat_src_norm, feat_ref_norm)
|
| 155 |
+
else:
|
| 156 |
+
raise NotImplementedError
|
| 157 |
+
|
| 158 |
+
return dist_matrix
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def sinkhorn(log_alpha, n_iters: int = 5, slack: bool = True, eps: float = -1) -> torch.Tensor:
|
| 162 |
+
""" Run sinkhorn iterations to generate a near doubly stochastic matrix, where each row or column sum to <=1
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
log_alpha: log of positive matrix to apply sinkhorn normalization (B, J, K)
|
| 166 |
+
n_iters (int): Number of normalization iterations
|
| 167 |
+
slack (bool): Whether to include slack row and column
|
| 168 |
+
eps: eps for early termination (Used only for handcrafted RPM). Set to negative to disable.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
log(perm_matrix): Doubly stochastic matrix (B, J, K)
|
| 172 |
+
|
| 173 |
+
Modified from original source taken from:
|
| 174 |
+
Learning Latent Permutations with Gumbel-Sinkhorn Networks
|
| 175 |
+
https://github.com/HeddaCohenIndelman/Learning-Gumbel-Sinkhorn-Permutations-w-Pytorch
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
# Sinkhorn iterations
|
| 179 |
+
prev_alpha = None
|
| 180 |
+
if slack:
|
| 181 |
+
zero_pad = nn.ZeroPad2d((0, 1, 0, 1))
|
| 182 |
+
log_alpha_padded = zero_pad(log_alpha[:, None, :, :])
|
| 183 |
+
|
| 184 |
+
log_alpha_padded = torch.squeeze(log_alpha_padded, dim=1)
|
| 185 |
+
|
| 186 |
+
for i in range(n_iters):
|
| 187 |
+
# Row normalization
|
| 188 |
+
log_alpha_padded = torch.cat((
|
| 189 |
+
log_alpha_padded[:, :-1, :] - (torch.logsumexp(log_alpha_padded[:, :-1, :], dim=2, keepdim=True)),
|
| 190 |
+
log_alpha_padded[:, -1, None, :]), # Don't normalize last row
|
| 191 |
+
dim=1)
|
| 192 |
+
|
| 193 |
+
# Column normalization
|
| 194 |
+
log_alpha_padded = torch.cat((
|
| 195 |
+
log_alpha_padded[:, :, :-1] - (torch.logsumexp(log_alpha_padded[:, :, :-1], dim=1, keepdim=True)),
|
| 196 |
+
log_alpha_padded[:, :, -1, None]), # Don't normalize last column
|
| 197 |
+
dim=2)
|
| 198 |
+
|
| 199 |
+
if eps > 0:
|
| 200 |
+
if prev_alpha is not None:
|
| 201 |
+
abs_dev = torch.abs(torch.exp(log_alpha_padded[:, :-1, :-1]) - prev_alpha)
|
| 202 |
+
if torch.max(torch.sum(abs_dev, dim=[1, 2])) < eps:
|
| 203 |
+
break
|
| 204 |
+
prev_alpha = torch.exp(log_alpha_padded[:, :-1, :-1]).clone()
|
| 205 |
+
|
| 206 |
+
log_alpha = log_alpha_padded[:, :-1, :-1]
|
| 207 |
+
else:
|
| 208 |
+
for i in range(n_iters):
|
| 209 |
+
# Row normalization (i.e. each row sum to 1)
|
| 210 |
+
log_alpha = log_alpha - (torch.logsumexp(log_alpha, dim=2, keepdim=True))
|
| 211 |
+
|
| 212 |
+
# Column normalization (i.e. each column sum to 1)
|
| 213 |
+
log_alpha = log_alpha - (torch.logsumexp(log_alpha, dim=1, keepdim=True))
|
| 214 |
+
|
| 215 |
+
if eps > 0:
|
| 216 |
+
if prev_alpha is not None:
|
| 217 |
+
abs_dev = torch.abs(torch.exp(log_alpha) - prev_alpha)
|
| 218 |
+
if torch.max(torch.sum(abs_dev, dim=[1, 2])) < eps:
|
| 219 |
+
break
|
| 220 |
+
prev_alpha = torch.exp(log_alpha).clone()
|
| 221 |
+
|
| 222 |
+
return log_alpha
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def compute_rigid_transform(a: torch.Tensor, b: torch.Tensor, weights: torch.Tensor):
|
| 226 |
+
"""Compute rigid transforms between two point sets
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
a (torch.Tensor): (B, M, 3) points
|
| 230 |
+
b (torch.Tensor): (B, N, 3) points
|
| 231 |
+
weights (torch.Tensor): (B, M)
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
Transform T (B, 3, 4) to get from a to b, i.e. T*a = b
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
weights_normalized = weights[..., None] / (torch.sum(weights[..., None], dim=1, keepdim=True) + _EPS)
|
| 238 |
+
centroid_a = torch.sum(a * weights_normalized, dim=1)
|
| 239 |
+
centroid_b = torch.sum(b * weights_normalized, dim=1)
|
| 240 |
+
a_centered = a - centroid_a[:, None, :]
|
| 241 |
+
b_centered = b - centroid_b[:, None, :]
|
| 242 |
+
cov = a_centered.transpose(-2, -1) @ (b_centered * weights_normalized)
|
| 243 |
+
|
| 244 |
+
# Compute rotation using Kabsch algorithm. Will compute two copies with +/-V[:,:3]
|
| 245 |
+
# and choose based on determinant to avoid flips
|
| 246 |
+
u, s, v = torch.svd(cov, some=False, compute_uv=True)
|
| 247 |
+
rot_mat_pos = v @ u.transpose(-1, -2)
|
| 248 |
+
v_neg = v.clone()
|
| 249 |
+
v_neg[:, :, 2] *= -1
|
| 250 |
+
rot_mat_neg = v_neg @ u.transpose(-1, -2)
|
| 251 |
+
rot_mat = torch.where(torch.det(rot_mat_pos)[:, None, None] > 0, rot_mat_pos, rot_mat_neg)
|
| 252 |
+
assert torch.all(torch.det(rot_mat) > 0)
|
| 253 |
+
|
| 254 |
+
# Compute translation (uncenter centroid)
|
| 255 |
+
translation = -rot_mat @ centroid_a[:, :, None] + centroid_b[:, :, None]
|
| 256 |
+
|
| 257 |
+
transform = torch.cat((rot_mat, translation), dim=2)
|
| 258 |
+
return transform
|
| 259 |
+
|
| 260 |
+
class R3PMNet(nn.Module):
|
| 261 |
+
def __init__(self, feature_model):
|
| 262 |
+
|
| 263 |
+
super().__init__()
|
| 264 |
+
|
| 265 |
+
self.add_slack = True
|
| 266 |
+
self.num_sk_iter = 5
|
| 267 |
+
|
| 268 |
+
self.weights_net = ParameterPredictionNet(weights_dim=[0])
|
| 269 |
+
self.feat_extractor = feature_model
|
| 270 |
+
|
| 271 |
+
def compute_affinity(self, beta, feat_distance, alpha=0.5):
|
| 272 |
+
"""Compute logarithm of Initial match matrix values, i.e. log(m_jk)"""
|
| 273 |
+
if isinstance(alpha, float):
|
| 274 |
+
hybrid_affinity = -beta[:, None, None] * (feat_distance - alpha)
|
| 275 |
+
else:
|
| 276 |
+
hybrid_affinity = -beta[:, None, None] * (feat_distance - alpha[:, None, None])
|
| 277 |
+
return hybrid_affinity
|
| 278 |
+
|
| 279 |
+
@staticmethod
|
| 280 |
+
def split_normals(data):
|
| 281 |
+
if data.shape[2] == 6:
|
| 282 |
+
xyz, normals = data[:, :, :3], data[:, :, 3:6]
|
| 283 |
+
elif data.shape[2] == 3:
|
| 284 |
+
xyz, normals = data, torch.zeros(data.shape).to(data.device)
|
| 285 |
+
return xyz, normals
|
| 286 |
+
|
| 287 |
+
def spam(self, xyz_template, norm_template, xyz_source, norm_source):
|
| 288 |
+
self.beta, self.alpha = self.weights_net([xyz_source, xyz_template])
|
| 289 |
+
|
| 290 |
+
try: # R3PMNET feature extractor
|
| 291 |
+
self.feat_source = self.feat_extractor(xyz_source)
|
| 292 |
+
self.feat_template = self.feat_extractor(xyz_template)
|
| 293 |
+
except:
|
| 294 |
+
self.feat_source = self.feat_extractor(xyz_source, norm_source)
|
| 295 |
+
self.feat_template = self.feat_extractor(xyz_template, norm_template)
|
| 296 |
+
|
| 297 |
+
feat_distance = match_features(self.feat_source, self.feat_template)
|
| 298 |
+
self.affinity = self.compute_affinity(self.beta, feat_distance, alpha=self.alpha)
|
| 299 |
+
|
| 300 |
+
# Compute weighted coordinates
|
| 301 |
+
log_perm_matrix = sinkhorn(self.affinity, n_iters=self.num_sk_iter, slack=self.add_slack)
|
| 302 |
+
self.perm_matrix = torch.exp(log_perm_matrix)
|
| 303 |
+
|
| 304 |
+
try: # R3PMNET features
|
| 305 |
+
weighted_template = self.perm_matrix @ xyz_template[:,:self.perm_matrix.shape[1]] / (torch.sum(self.perm_matrix, dim=2, keepdim=True) + _EPS)
|
| 306 |
+
except:
|
| 307 |
+
weighted_template = self.perm_matrix @ xyz_template / (torch.sum(self.perm_matrix, dim=2, keepdim=True) + _EPS)
|
| 308 |
+
return weighted_template
|
| 309 |
+
|
| 310 |
+
def forward(self, template, source, max_iterations: int = 1):
|
| 311 |
+
"""Forward pass for R3PM-Net
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
data: Dict containing the following fields:
|
| 315 |
+
'points_src': Source points (B, J, 6)
|
| 316 |
+
'points_ref': Reference points (B, K, 6)
|
| 317 |
+
num_iter (int): Number of iterations. Recommended to be 2 for training
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
transform: Transform to apply to source points such that they align to reference
|
| 321 |
+
src_transformed: Transformed source points
|
| 322 |
+
"""
|
| 323 |
+
|
| 324 |
+
xyz_template, norm_template = self.split_normals(template)
|
| 325 |
+
xyz_source, norm_source = self.split_normals(source)
|
| 326 |
+
|
| 327 |
+
xyz_source_t, norm_source_t = xyz_source, norm_source # a copy of source to apply transformation to
|
| 328 |
+
|
| 329 |
+
transforms = []
|
| 330 |
+
all_gamma, all_perm_matrices, all_weighted_template = [], [], []
|
| 331 |
+
all_beta, all_alpha = [], []
|
| 332 |
+
|
| 333 |
+
for i in range(max_iterations):
|
| 334 |
+
weighted_template = self.spam(xyz_template, norm_template, xyz_source_t, norm_source_t) # Finding better correspondences after each iteration.
|
| 335 |
+
|
| 336 |
+
# Compute transform and transform points
|
| 337 |
+
try: # R3PMNET features
|
| 338 |
+
transform = compute_rigid_transform(xyz_source[:,:weighted_template.shape[1]], weighted_template, weights=torch.sum(self.perm_matrix, dim=2))
|
| 339 |
+
xyz_source_t, norm_source_t = se3_transform(transform.detach(), xyz_source[:,:weighted_template.shape[1]], norm_source) # Apply transformation to original source.
|
| 340 |
+
except:
|
| 341 |
+
transform = compute_rigid_transform(xyz_source_t, weighted_template, weights=torch.sum(self.perm_matrix, dim=2))
|
| 342 |
+
xyz_source_t, norm_source_t = se3_transform(transform.detach(), xyz_source, norm_source) # Apply transformation to original source.
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
transforms.append(transform)
|
| 346 |
+
all_gamma.append(torch.exp(self.affinity))
|
| 347 |
+
all_perm_matrices.append(self.perm_matrix)
|
| 348 |
+
all_weighted_template.append(weighted_template)
|
| 349 |
+
all_beta.append(to_numpy(self.beta))
|
| 350 |
+
all_alpha.append(to_numpy(self.alpha))
|
| 351 |
+
|
| 352 |
+
est_T = convert2transformation(transforms[max_iterations-1][:, :3, :3], transforms[max_iterations-1][:, :3, 3])
|
| 353 |
+
transformed_source = torch.bmm(est_T[:, :3, :3], source[:,:,:3].permute(0, 2, 1)).permute(0, 2, 1) + est_T[:, :3, 3].unsqueeze(1)
|
| 354 |
+
|
| 355 |
+
try: # for training
|
| 356 |
+
result = {'est_R': est_T[:, :3, :3], # source -> template
|
| 357 |
+
'est_t': est_T[:, :3, 3], # source -> template
|
| 358 |
+
'est_T': est_T, # source -> template
|
| 359 |
+
'r': self.feat_template - self.feat_source,
|
| 360 |
+
'transformed_source': transformed_source}
|
| 361 |
+
except RuntimeError:
|
| 362 |
+
result = {'est_R': est_T[:, :3, :3], # source -> template
|
| 363 |
+
'est_t': est_T[:, :3, 3], # source -> template
|
| 364 |
+
'est_T': est_T, # source -> template
|
| 365 |
+
'transformed_source': transformed_source}
|
| 366 |
+
|
| 367 |
+
result['perm_matrices_init'] = all_gamma
|
| 368 |
+
result['perm_matrices'] = all_perm_matrices
|
| 369 |
+
result['weighted_template'] = all_weighted_template
|
| 370 |
+
result['beta'] = np.stack(all_beta, axis=0)
|
| 371 |
+
result['alpha'] = np.stack(all_alpha, axis=0)
|
| 372 |
+
result['transforms'] = transforms
|
| 373 |
+
|
| 374 |
+
return result
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
if __name__ == '__main__':
|
| 378 |
+
template, source = torch.rand(10,1024,6), torch.rand(10,1024,6)
|
| 379 |
+
|
| 380 |
+
net = R3PMNet()
|
| 381 |
+
result = net(template, source)
|
| 382 |
+
import ipdb; ipdb.set_trace()
|
r3pm_net/paths.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Repository-root resolution for portable paths."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# r3pm_net/paths.py -> parents[1] is the repository root
|
| 6 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def repo_path(*parts: str) -> str:
|
| 10 |
+
"""Join path segments relative to the repository root."""
|
| 11 |
+
return str(REPO_ROOT.joinpath(*parts))
|
scripts/eval_modelnet40.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import copy
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
# Repository root on PYTHONPATH (run: python scripts/test_modelnet40.py from repo root).
|
| 8 |
+
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 9 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 10 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import random
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import open3d as o3d
|
| 17 |
+
import torch
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from tools import augmentation, data, l3d_helper, print_results, transformations
|
| 21 |
+
from tools import l3d_registration_and_evaluation, predator_registration_and_evaluation, geotransformer_registration_and_evaluation, logdesc_registration_and_evaluation, regtr_registration_and_evaluation
|
| 22 |
+
from r3pm_net.config_loader import get_method_paths,get_modelnet40_paths, get_pretrained_rpmnet_dir
|
| 23 |
+
|
| 24 |
+
'''
|
| 25 |
+
This script evaluates the performance on the ModelNet40 test dataset.
|
| 26 |
+
The results are averaged ovet the dataset with 2468 samples.
|
| 27 |
+
All the point clouds are normalized to a sphere of radius 1.
|
| 28 |
+
|
| 29 |
+
Augmentations:
|
| 30 |
+
- Transformation = Random rotation (0 - 45) and translation (-0.5 to 0.5)
|
| 31 |
+
- Noise = Gaussian noise with mean 0 and std deviation of 0.01 [optional]
|
| 32 |
+
- Outliers = with level 1 which means 2% of the points are outliers (PC size = 2040) [optional]
|
| 33 |
+
- Occlusion = 90000 radius which means 0.7% of the points are occluded (PC size = 1986) [optional]
|
| 34 |
+
'''
|
| 35 |
+
def set_seed(seed: int) -> None:
|
| 36 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
| 37 |
+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
| 38 |
+
|
| 39 |
+
random.seed(seed)
|
| 40 |
+
np.random.seed(seed)
|
| 41 |
+
torch.manual_seed(seed)
|
| 42 |
+
torch.cuda.manual_seed_all(seed)
|
| 43 |
+
|
| 44 |
+
torch.backends.cudnn.benchmark = False
|
| 45 |
+
torch.backends.cudnn.deterministic = True
|
| 46 |
+
torch.use_deterministic_algorithms(True)
|
| 47 |
+
|
| 48 |
+
# arguments
|
| 49 |
+
parser = argparse.ArgumentParser(description="ModelNet40 R3PM-Net evaluation")
|
| 50 |
+
parser.add_argument("--seed", type=int, default=42, help="random seed (default: 42)")
|
| 51 |
+
|
| 52 |
+
args = parser.parse_args()
|
| 53 |
+
set_seed(args.seed)
|
| 54 |
+
method_paths = get_method_paths()
|
| 55 |
+
|
| 56 |
+
pretrained_base_dir = get_pretrained_rpmnet_dir()
|
| 57 |
+
_path_zs = os.path.join(pretrained_base_dir, "clean-trained.pth")
|
| 58 |
+
_path_ft = os.path.join(pretrained_base_dir, "best_model_PointNet.t7") #TODO: CHANGE
|
| 59 |
+
|
| 60 |
+
def fix_off_file(file_path):
|
| 61 |
+
with open(file_path, 'r') as f:
|
| 62 |
+
lines = f.readlines()
|
| 63 |
+
|
| 64 |
+
if lines[0].startswith("OFF") and len(lines[0].strip().split()) > 1:
|
| 65 |
+
header = lines[0].strip()
|
| 66 |
+
new_header = "OFF\n" + header[3:] + "\n"
|
| 67 |
+
lines = [new_header] + lines[1:]
|
| 68 |
+
|
| 69 |
+
with open(file_path, 'w') as f:
|
| 70 |
+
f.writelines(lines)
|
| 71 |
+
print(f"Fixed: {file_path}")
|
| 72 |
+
|
| 73 |
+
def load_modelnet40_test_data(dataset_path, num_points=2000):
|
| 74 |
+
test_data = []
|
| 75 |
+
test_labels = []
|
| 76 |
+
categories = os.listdir(dataset_path)
|
| 77 |
+
for label, category in enumerate(tqdm(categories, desc="Loading Data")):
|
| 78 |
+
test_dir = os.path.join(dataset_path, category, 'test')
|
| 79 |
+
if not os.path.exists(test_dir):
|
| 80 |
+
continue
|
| 81 |
+
for file in tqdm(os.listdir(test_dir), desc=f"Processing {category} Category", leave=False):
|
| 82 |
+
if file.endswith('.off'):
|
| 83 |
+
file_path = os.path.join(test_dir, file)
|
| 84 |
+
mesh = o3d.io.read_triangle_mesh(file_path)
|
| 85 |
+
point_cloud = mesh.sample_points_poisson_disk(number_of_points=num_points)
|
| 86 |
+
test_data.append(point_cloud)
|
| 87 |
+
test_labels.append(label)
|
| 88 |
+
|
| 89 |
+
return test_data, test_labels, categories
|
| 90 |
+
|
| 91 |
+
# download from http://modelnet.cs.princeton.edu/ModelNet40.zip unzip and put the path in the config/eval.yaml
|
| 92 |
+
dataset_path, save_dir = get_modelnet40_paths()
|
| 93 |
+
test_data_path = os.path.join(save_dir, "test_data.npy")
|
| 94 |
+
test_labels_path = os.path.join(save_dir, "test_labels.npy")
|
| 95 |
+
categories_path = os.path.join(save_dir, "categories.npy")
|
| 96 |
+
|
| 97 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 98 |
+
|
| 99 |
+
# Check if data already exists
|
| 100 |
+
if os.path.exists(test_data_path) and os.path.exists(test_labels_path) and os.path.exists(categories_path):
|
| 101 |
+
print("Loading existing test data...")
|
| 102 |
+
test_data_np = np.load(test_data_path, allow_pickle=True)
|
| 103 |
+
test_labels = np.load(test_labels_path)
|
| 104 |
+
categories = np.load(categories_path)
|
| 105 |
+
print("Done! Testing the models...")
|
| 106 |
+
else:
|
| 107 |
+
print("Loading and processing ModelNet40 test data...")
|
| 108 |
+
# Fix all .OFF files in the dataset
|
| 109 |
+
for root, _, files in os.walk(dataset_path):
|
| 110 |
+
for file in files:
|
| 111 |
+
if file.endswith(".off"):
|
| 112 |
+
fix_off_file(os.path.join(root, file))
|
| 113 |
+
|
| 114 |
+
test_data, test_labels, categories = load_modelnet40_test_data(dataset_path)
|
| 115 |
+
|
| 116 |
+
test_data_np = [data.normalize_pc(pc, return_as_np = True) for pc in test_data]
|
| 117 |
+
|
| 118 |
+
np.save(test_data_path, test_data_np)
|
| 119 |
+
np.save(test_labels_path, test_labels)
|
| 120 |
+
np.save(categories_path, categories)
|
| 121 |
+
print("Test data saved!")
|
| 122 |
+
|
| 123 |
+
# Initialize arrays to store results
|
| 124 |
+
rpm_results_all = []
|
| 125 |
+
predator_results_all = []
|
| 126 |
+
geotransformer_results_all = []
|
| 127 |
+
logdesc_results_all = []
|
| 128 |
+
regtr_results_all = []
|
| 129 |
+
r3pm_net_results_all = []
|
| 130 |
+
tuned_r3pm_net_results_all = []
|
| 131 |
+
|
| 132 |
+
rpm_reg_results_all = []
|
| 133 |
+
predator_reg_results_all = []
|
| 134 |
+
geotransformer_reg_results_all = []
|
| 135 |
+
logdesc_reg_results_all = []
|
| 136 |
+
regtr_reg_results_all = []
|
| 137 |
+
r3pm_net_reg_results_all = []
|
| 138 |
+
tuned_r3pm_net_reg_results_all = []
|
| 139 |
+
|
| 140 |
+
all_sources = []
|
| 141 |
+
all_targets = []
|
| 142 |
+
all_angles ={}
|
| 143 |
+
|
| 144 |
+
# Reconstruct Open3D PointCloud objects from saved npy arrays
|
| 145 |
+
test_data = [o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points)) for points in test_data_np]
|
| 146 |
+
|
| 147 |
+
noise_level = 0
|
| 148 |
+
outlier_level = 0
|
| 149 |
+
outlier_lowerbound = -0.5
|
| 150 |
+
outlier_upperbound = 0.5
|
| 151 |
+
# occlusion_level = 90000 # Higher value means less occlusion
|
| 152 |
+
occlusion_level = 0 # Higher value means less occlusion
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# set arguments for models
|
| 156 |
+
rpm_args = l3d_helper.options(modelName="RPMNet")
|
| 157 |
+
rpm_args.pretrained = _path_zs
|
| 158 |
+
|
| 159 |
+
# OverlapPredator (used by Predator runner)
|
| 160 |
+
predator_cfg = method_paths.get("predator", {})
|
| 161 |
+
predator_root = predator_cfg.get("root")
|
| 162 |
+
predator_config_path = predator_cfg.get("config_path")
|
| 163 |
+
predator_weights_path = predator_cfg.get("weights_path")
|
| 164 |
+
|
| 165 |
+
# GeoTransformer
|
| 166 |
+
geo_cfg = method_paths.get("geotransformer", {})
|
| 167 |
+
geotransformer_root = geo_cfg.get("root")
|
| 168 |
+
geotransformer_exp_subdir = geo_cfg.get("exp_subdir")
|
| 169 |
+
geotransformer_weights_path = geo_cfg.get("weights_path")
|
| 170 |
+
|
| 171 |
+
# LoGDesc
|
| 172 |
+
logdesc_cfg = method_paths.get("logdesc", {})
|
| 173 |
+
logdesc_root = logdesc_cfg.get("root")
|
| 174 |
+
logdesc_weights_path = logdesc_cfg.get("weights_path")
|
| 175 |
+
|
| 176 |
+
# RegTR
|
| 177 |
+
regtr_cfg = method_paths.get("regtr", {})
|
| 178 |
+
regtr_root = regtr_cfg.get("root")
|
| 179 |
+
regtr_ckpt_path = regtr_cfg.get("ckpt_path")
|
| 180 |
+
regtr_config_path = regtr_cfg.get("config_path")
|
| 181 |
+
|
| 182 |
+
# R3PM-Net (ours) - ZS - no training
|
| 183 |
+
r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
|
| 184 |
+
r3pm_net_args.pretrained = _path_zs
|
| 185 |
+
|
| 186 |
+
# R3PM-Net (ours) - FT
|
| 187 |
+
tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
|
| 188 |
+
tuned_r3pm_net_args.pretrained = _path_ft
|
| 189 |
+
|
| 190 |
+
for i, item in enumerate(tqdm(test_data, desc="Testing methods")):
|
| 191 |
+
|
| 192 |
+
# Simulate data
|
| 193 |
+
x_angle = int(random.uniform(0, 45))
|
| 194 |
+
y_angle = int(random.uniform(0, 45))
|
| 195 |
+
z_angle = int(random.uniform(0, 45))
|
| 196 |
+
translation_range = (-0.5, 0.5)
|
| 197 |
+
gt_transformation = transformations.create_transformation(x_angle, y_angle, z_angle, translation_range)
|
| 198 |
+
source = copy.deepcopy(item)
|
| 199 |
+
|
| 200 |
+
target = copy.deepcopy(item).transform(gt_transformation)
|
| 201 |
+
|
| 202 |
+
# Apply augmentations
|
| 203 |
+
noisy_source = copy.deepcopy(source)
|
| 204 |
+
if noise_level != 0:
|
| 205 |
+
noisy_source = augmentation.apply_noise(noisy_source, noise_level)
|
| 206 |
+
if outlier_level != 0:
|
| 207 |
+
noisy_source = augmentation.add_outliers(noisy_source, outlier_level, outlier_lowerbound, outlier_upperbound)
|
| 208 |
+
if occlusion_level != 0:
|
| 209 |
+
noisy_source, _ = augmentation.apply_occlusion(noisy_source, occlusion_level)
|
| 210 |
+
if len(noisy_source.points) < 1024: # cannot be smaller than embedding dims in config/default.yaml
|
| 211 |
+
noisy_source = copy.deepcopy(source)
|
| 212 |
+
noisy_source = augmentation.apply_noise(noisy_source, noise_level)
|
| 213 |
+
noisy_source, _ = augmentation.apply_occlusion(noisy_source, occlusion_level * 100)
|
| 214 |
+
assert len(noisy_source.points) >= 1024, "Noisy source point cloud has less than 1024 points."
|
| 215 |
+
|
| 216 |
+
# RPMNet
|
| 217 |
+
rpm_results_pc, rpm_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
|
| 218 |
+
noisy_source, target, 'rpmnet', gt_transformation, rpm_args)
|
| 219 |
+
rpm_results_all.append(rpm_results)
|
| 220 |
+
rpm_reg_results_all.append(rpm_results_pc)
|
| 221 |
+
|
| 222 |
+
# OverlapPredator
|
| 223 |
+
predator_results_pc, predator_results = predator_registration_and_evaluation.predator_reg_and_eval(
|
| 224 |
+
noisy_source,
|
| 225 |
+
target,
|
| 226 |
+
gt_transformation=gt_transformation,
|
| 227 |
+
predator_root=predator_root,
|
| 228 |
+
config_path=predator_config_path,
|
| 229 |
+
weights_path=predator_weights_path,
|
| 230 |
+
ransac_n_points=1000,
|
| 231 |
+
ransac_distance_threshold=0.05,
|
| 232 |
+
ransac_n=3,
|
| 233 |
+
sampling="prob",
|
| 234 |
+
mutual=False,
|
| 235 |
+
input_num_points=1024,
|
| 236 |
+
)
|
| 237 |
+
predator_results_all.append(predator_results)
|
| 238 |
+
predator_reg_results_all.append(predator_results_pc)
|
| 239 |
+
|
| 240 |
+
# GeoTransformer (ModelNet)
|
| 241 |
+
geotransformer_results_pc, geotransformer_results = geotransformer_registration_and_evaluation.geotransformer_reg_and_eval(
|
| 242 |
+
noisy_source,
|
| 243 |
+
target,
|
| 244 |
+
gt_transformation=gt_transformation,
|
| 245 |
+
geotransformer_root=geotransformer_root,
|
| 246 |
+
exp_subdir=geotransformer_exp_subdir,
|
| 247 |
+
weights_path=geotransformer_weights_path,
|
| 248 |
+
)
|
| 249 |
+
geotransformer_results_all.append(geotransformer_results)
|
| 250 |
+
geotransformer_reg_results_all.append(geotransformer_results_pc)
|
| 251 |
+
|
| 252 |
+
# LoGDesc
|
| 253 |
+
logdesc_results_pc, logdesc_results = logdesc_registration_and_evaluation.logdesc_reg_and_eval(
|
| 254 |
+
noisy_source,
|
| 255 |
+
target,
|
| 256 |
+
gt_transformation=gt_transformation,
|
| 257 |
+
logdesc_root=logdesc_root,
|
| 258 |
+
weights_path=logdesc_weights_path,
|
| 259 |
+
max_keypoints=768,
|
| 260 |
+
num_points_per_sample=128,
|
| 261 |
+
sample_radius=0.3,
|
| 262 |
+
topk_matches=128,
|
| 263 |
+
use_kpt=False,
|
| 264 |
+
)
|
| 265 |
+
logdesc_results_all.append(logdesc_results)
|
| 266 |
+
logdesc_reg_results_all.append(logdesc_results_pc)
|
| 267 |
+
|
| 268 |
+
# RegTR (ModelNet)
|
| 269 |
+
regtr_results_pc, regtr_results = regtr_registration_and_evaluation.regtr_reg_and_eval(
|
| 270 |
+
noisy_source,
|
| 271 |
+
target,
|
| 272 |
+
gt_transformation=gt_transformation,
|
| 273 |
+
regtr_root=regtr_root,
|
| 274 |
+
ckpt_path=regtr_ckpt_path,
|
| 275 |
+
config_path=regtr_config_path,
|
| 276 |
+
)
|
| 277 |
+
regtr_results_all.append(regtr_results)
|
| 278 |
+
regtr_reg_results_all.append(regtr_results_pc)
|
| 279 |
+
|
| 280 |
+
# R3PM-Net (ours) - no training
|
| 281 |
+
r3pm_net_results_pc, r3pm_net_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
|
| 282 |
+
noisy_source, target, 'r3pmnet', gt_transformation, r3pm_net_args)
|
| 283 |
+
r3pm_net_results_all.append(r3pm_net_results)
|
| 284 |
+
r3pm_net_reg_results_all.append(r3pm_net_results_pc)
|
| 285 |
+
|
| 286 |
+
# R3PM-Net (ours) (Tuned on 4 sioux data)
|
| 287 |
+
tuned_r3pm_net_results_pc, tuned_r3pm_net_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
|
| 288 |
+
noisy_source, target, 'r3pmnet', gt_transformation, tuned_r3pm_net_args)
|
| 289 |
+
tuned_r3pm_net_results_all.append(tuned_r3pm_net_results)
|
| 290 |
+
tuned_r3pm_net_reg_results_all.append(tuned_r3pm_net_results_pc)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
all_sources.append(noisy_source)
|
| 294 |
+
all_targets.append(target)
|
| 295 |
+
all_angles[i] = {
|
| 296 |
+
"x_angle": x_angle,
|
| 297 |
+
"y_angle": y_angle,
|
| 298 |
+
"z_angle": z_angle,
|
| 299 |
+
"translation": gt_transformation[:3, 3]
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
# Convert results to numpy arrays for easier manipulation
|
| 303 |
+
rpm_results_all = np.array(rpm_results_all)
|
| 304 |
+
predator_results_all = np.array(predator_results_all)
|
| 305 |
+
geotransformer_results_all = np.array(geotransformer_results_all)
|
| 306 |
+
logdesc_results_all = np.array(logdesc_results_all)
|
| 307 |
+
regtr_results_all = np.array(regtr_results_all)
|
| 308 |
+
r3pm_net_results_all = np.array(r3pm_net_results_all)
|
| 309 |
+
tuned_r3pm_net_results_all = np.array(tuned_r3pm_net_results_all)
|
| 310 |
+
|
| 311 |
+
rpm_mean_results = np.mean(rpm_results_all, axis=0)
|
| 312 |
+
predator_mean_results = np.mean(predator_results_all, axis=0)
|
| 313 |
+
geotransformer_mean_results = np.mean(geotransformer_results_all, axis=0)
|
| 314 |
+
logdesc_mean_results = np.mean(logdesc_results_all, axis=0)
|
| 315 |
+
regtr_mean_results = np.mean(regtr_results_all, axis=0)
|
| 316 |
+
r3pm_net_mean_results = np.mean(r3pm_net_results_all, axis=0)
|
| 317 |
+
tuned_r3pm_net_mean_results = np.mean(tuned_r3pm_net_results_all, axis=0)
|
| 318 |
+
|
| 319 |
+
# Print the results
|
| 320 |
+
metric_names = ['mean_rmse', 'mean_rotation_error', 'mean_translation_error',
|
| 321 |
+
'mean_computation_time', 'mean_cd', 'mean_error',
|
| 322 |
+
'mean_fitness', 'mean_inlier_rmse']
|
| 323 |
+
|
| 324 |
+
reports = {
|
| 325 |
+
"RPMNet": dict(zip(metric_names, rpm_mean_results)),
|
| 326 |
+
"Predator": dict(zip(metric_names, predator_mean_results)),
|
| 327 |
+
"GeoTransformer": dict(zip(metric_names, geotransformer_mean_results)),
|
| 328 |
+
"LoGDesc": dict(zip(metric_names, logdesc_mean_results)),
|
| 329 |
+
"RegTR": dict(zip(metric_names, regtr_mean_results)),
|
| 330 |
+
"R3PM-Net (ours) (ZS)": dict(zip(metric_names, r3pm_net_mean_results)),
|
| 331 |
+
"R3PM-Net (ours) (FT)": dict(zip(metric_names, tuned_r3pm_net_mean_results)),
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
# Print the table
|
| 335 |
+
print_results.print_table(reports)
|
scripts/eval_sioux_cranfield.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import copy
|
| 3 |
+
import open3d as o3d
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
import argparse
|
| 11 |
+
|
| 12 |
+
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 13 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 14 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 15 |
+
|
| 16 |
+
from tools import augmentation, data, l3d_helper, print_results, transformations
|
| 17 |
+
from tools import l3d_registration_and_evaluation, predator_registration_and_evaluation, geotransformer_registration_and_evaluation, logdesc_registration_and_evaluation, regtr_registration_and_evaluation
|
| 18 |
+
from r3pm_net.config_loader import get_method_paths, get_pretrained_rpmnet_dir, get_sioux_data_root, get_sioux_paths
|
| 19 |
+
'''
|
| 20 |
+
This script evaluates the performance on a Sioux-Cranfield dataset
|
| 21 |
+
Cranfield dataset from: https://github.com/Menthy-Denayer/PCR_CAD_Model_Alignment_Comparison/tree/main/datasets
|
| 22 |
+
'''
|
| 23 |
+
def set_seed(seed: int) -> None:
|
| 24 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
| 25 |
+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
| 26 |
+
|
| 27 |
+
random.seed(seed)
|
| 28 |
+
np.random.seed(seed)
|
| 29 |
+
torch.manual_seed(seed)
|
| 30 |
+
torch.cuda.manual_seed_all(seed)
|
| 31 |
+
|
| 32 |
+
torch.backends.cudnn.benchmark = False
|
| 33 |
+
torch.backends.cudnn.deterministic = True
|
| 34 |
+
torch.use_deterministic_algorithms(True)
|
| 35 |
+
|
| 36 |
+
# arguments
|
| 37 |
+
parser = argparse.ArgumentParser(description="Sioux-Cranfield R3PM-Net evaluation")
|
| 38 |
+
parser.add_argument("--seed", type=int, default=42, help="random seed (default: 42)")
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
set_seed(args.seed)
|
| 41 |
+
|
| 42 |
+
base_dir = get_sioux_data_root()
|
| 43 |
+
sioux_cfg = get_sioux_paths()
|
| 44 |
+
method_paths = get_method_paths()
|
| 45 |
+
|
| 46 |
+
pretrained_base_dir = get_pretrained_rpmnet_dir()
|
| 47 |
+
_path_zs = os.path.join(pretrained_base_dir, "clean-trained.pth")
|
| 48 |
+
_path_ft = os.path.join(pretrained_base_dir, "best_model_PointNet.t7") #TODO: CHANGE
|
| 49 |
+
|
| 50 |
+
# Paths to the CAD models
|
| 51 |
+
cad_dir_made = os.path.join(base_dir, 'sioux_cranfield')
|
| 52 |
+
|
| 53 |
+
cad_paths = [os.path.join(cad_dir_made, 'Base-Top_Plate.stl'),
|
| 54 |
+
os.path.join(cad_dir_made, 'Pendulum.stl'),
|
| 55 |
+
os.path.join(cad_dir_made, 'Round-Peg.stl'),
|
| 56 |
+
os.path.join(cad_dir_made, 'Separator.stl'),
|
| 57 |
+
os.path.join(cad_dir_made, 'Shaft-New.stl'),
|
| 58 |
+
os.path.join(cad_dir_made, 'Square-Peg.stl'),
|
| 59 |
+
os.path.join(cad_dir_made, 'elephant.stl'),
|
| 60 |
+
os.path.join(cad_dir_made, 'house.stl'),
|
| 61 |
+
os.path.join(cad_dir_made, 'shoe.stl')]
|
| 62 |
+
|
| 63 |
+
# Test parameters
|
| 64 |
+
num_tests = 25
|
| 65 |
+
angles = list(range(0, 45))
|
| 66 |
+
translation_range = (-0.5, 0.5)
|
| 67 |
+
np.random.seed(42)
|
| 68 |
+
|
| 69 |
+
# Augmentation parameters
|
| 70 |
+
noise_level = 0
|
| 71 |
+
outlier_level = 0
|
| 72 |
+
outlier_lowerbound = -0.5
|
| 73 |
+
outlier_upperbound = 0.5
|
| 74 |
+
# occlusion_level = 9000 # Higher value means less occlusion
|
| 75 |
+
occ_level = 0
|
| 76 |
+
|
| 77 |
+
# Make dataset
|
| 78 |
+
sources = []
|
| 79 |
+
targets = []
|
| 80 |
+
x_angles = []
|
| 81 |
+
y_angles = []
|
| 82 |
+
z_angles = []
|
| 83 |
+
gt_transformations = []
|
| 84 |
+
|
| 85 |
+
for cadPath in tqdm (cad_paths, desc="Preparing Sioux-Cranfield Dataset", total=len(cad_paths)):
|
| 86 |
+
|
| 87 |
+
num_points = 2000
|
| 88 |
+
# Load the data
|
| 89 |
+
mesh = o3d.io.read_triangle_mesh(cadPath)
|
| 90 |
+
cad = mesh.sample_points_poisson_disk(number_of_points=num_points) # modify to a suitable number of points
|
| 91 |
+
normalized_point_cloud = data.normalize_pc(cad)
|
| 92 |
+
source = copy.deepcopy(normalized_point_cloud)
|
| 93 |
+
|
| 94 |
+
for test in range(num_tests):
|
| 95 |
+
# Data simulation
|
| 96 |
+
x_angle= np.random.uniform(angles[0], angles[-1], size=1)
|
| 97 |
+
y_angle= np.random.uniform(angles[0], angles[-1], size=1)
|
| 98 |
+
z_angle= np.random.uniform(angles[0], angles[-1], size=1)
|
| 99 |
+
gt_transformation = transformations.create_transformation(x_angle, y_angle, z_angle, translation_range)
|
| 100 |
+
target = copy.deepcopy(normalized_point_cloud).transform(gt_transformation)
|
| 101 |
+
|
| 102 |
+
# Data augmentation
|
| 103 |
+
if occ_level == 0 and noise_level == 0 and outlier_level == 0:
|
| 104 |
+
noisy_source = copy.deepcopy(source)
|
| 105 |
+
|
| 106 |
+
# Noise + Occlusion
|
| 107 |
+
elif occ_level != 0 and noise_level != 0:
|
| 108 |
+
noisy_source_noise = augmentation.apply_noise(source, noise_level)
|
| 109 |
+
noisy_source, _ = augmentation.apply_occlusion(noisy_source_noise, occ_level)
|
| 110 |
+
if len(noisy_source.points) < 1024: # Handle excessive occlusion
|
| 111 |
+
source = copy.deepcopy(target).transform(gt_transformation)
|
| 112 |
+
noisy_source_noise = augmentation.apply_noise(source, noise_level)
|
| 113 |
+
noisy_source, _ = augmentation.apply_occlusion(noisy_source_noise, occ_level * 1.5)
|
| 114 |
+
|
| 115 |
+
# Noise + Outlier
|
| 116 |
+
elif noise_level != 0 and outlier_level != 0:
|
| 117 |
+
noisy_source_noise = augmentation.apply_noise(source, noise_level)
|
| 118 |
+
noisy_source = augmentation.add_outliers(noisy_source_noise, outlier_level, outlier_lowerbound=-0.5, outlier_upperbound=0.5)
|
| 119 |
+
|
| 120 |
+
# Noise + Outlier + Occlusion
|
| 121 |
+
elif occ_level != 0 and noise_level != 0 and outlier_level != 0:
|
| 122 |
+
noisy_source_noise = augmentation.apply_noise(source, noise_level)
|
| 123 |
+
noisy_source, _ = augmentation.apply_occlusion(noisy_source_noise, occ_level)
|
| 124 |
+
if len(noisy_source.points) < 1024: # Handle excessive occlusion
|
| 125 |
+
source = copy.deepcopy(target).transform(gt_transformation)
|
| 126 |
+
noisy_source_noise = augmentation.apply_noise(source, noise_level)
|
| 127 |
+
noisy_source, _ = augmentation.apply_occlusion(noisy_source_noise, occ_level * 1.5)
|
| 128 |
+
noisy_source = augmentation.add_outliers(noisy_source, outlier_level, outlier_lowerbound=-0.5, outlier_upperbound=0.5)
|
| 129 |
+
|
| 130 |
+
# collect dataset in lists
|
| 131 |
+
sources.append(noisy_source)
|
| 132 |
+
targets.append(target)
|
| 133 |
+
x_angles.append(x_angle)
|
| 134 |
+
y_angles.append(y_angle)
|
| 135 |
+
z_angles.append(z_angle)
|
| 136 |
+
gt_transformations.append(gt_transformation)
|
| 137 |
+
|
| 138 |
+
# Initialize arrays to store results
|
| 139 |
+
rpm_results_all = []
|
| 140 |
+
predator_results_all = []
|
| 141 |
+
geotransformer_results_all = []
|
| 142 |
+
logdesc_results_all = []
|
| 143 |
+
regtr_results_all = []
|
| 144 |
+
r3pm_net_results_all = []
|
| 145 |
+
tuned_r3pm_net_results_all = []
|
| 146 |
+
|
| 147 |
+
rpm_reg_results_all = []
|
| 148 |
+
predator_reg_results_all = []
|
| 149 |
+
geotransformer_reg_results_all = []
|
| 150 |
+
logdesc_reg_results_all = []
|
| 151 |
+
regtr_reg_results_all = []
|
| 152 |
+
r3pm_net_reg_results_all = []
|
| 153 |
+
tuned_r3pm_net_reg_results_all = []
|
| 154 |
+
|
| 155 |
+
# set arguments for models
|
| 156 |
+
rpm_args = l3d_helper.options(modelName="RPMNet")
|
| 157 |
+
rpm_args.pretrained = _path_zs
|
| 158 |
+
|
| 159 |
+
# OverlapPredator (used by Predator runner)
|
| 160 |
+
predator_cfg = method_paths.get("predator", {})
|
| 161 |
+
predator_root = predator_cfg.get("root")
|
| 162 |
+
predator_config_path = predator_cfg.get("config_path")
|
| 163 |
+
predator_weights_path = predator_cfg.get("weights_path")
|
| 164 |
+
|
| 165 |
+
# GeoTransformer
|
| 166 |
+
geo_cfg = method_paths.get("geotransformer", {})
|
| 167 |
+
geotransformer_root = geo_cfg.get("root")
|
| 168 |
+
geotransformer_exp_subdir = geo_cfg.get("exp_subdir")
|
| 169 |
+
geotransformer_weights_path = geo_cfg.get("weights_path")
|
| 170 |
+
|
| 171 |
+
# LoGDesc
|
| 172 |
+
logdesc_cfg = method_paths.get("logdesc", {})
|
| 173 |
+
logdesc_root = logdesc_cfg.get("root")
|
| 174 |
+
logdesc_weights_path = logdesc_cfg.get("weights_path")
|
| 175 |
+
|
| 176 |
+
# RegTR
|
| 177 |
+
regtr_cfg = method_paths.get("regtr", {})
|
| 178 |
+
regtr_root = regtr_cfg.get("root")
|
| 179 |
+
regtr_ckpt_path = regtr_cfg.get("ckpt_path")
|
| 180 |
+
regtr_config_path = regtr_cfg.get("config_path")
|
| 181 |
+
|
| 182 |
+
# R3PM-Net (ours) - ZS - no training
|
| 183 |
+
r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
|
| 184 |
+
r3pm_net_args.pretrained = _path_zs
|
| 185 |
+
|
| 186 |
+
# R3PM-Net (ours) - FT
|
| 187 |
+
tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
|
| 188 |
+
tuned_r3pm_net_args.pretrained = _path_ft
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
for i, item in enumerate(tqdm(zip(sources, targets, gt_transformations), desc="Testing methods", total=len(sources))):
|
| 192 |
+
|
| 193 |
+
# RPMNet
|
| 194 |
+
rpm_results_pc, rpm_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
|
| 195 |
+
sources[i], targets[i], 'rpmnet', gt_transformations[i], rpm_args)
|
| 196 |
+
rpm_results_all.append(rpm_results)
|
| 197 |
+
rpm_reg_results_all.append(rpm_results_pc)
|
| 198 |
+
|
| 199 |
+
# OverlapPredator
|
| 200 |
+
predator_results_pc, predator_results = predator_registration_and_evaluation.predator_reg_and_eval(
|
| 201 |
+
sources[i],
|
| 202 |
+
targets[i],
|
| 203 |
+
gt_transformation=gt_transformations[i],
|
| 204 |
+
predator_root=predator_root,
|
| 205 |
+
config_path=predator_config_path,
|
| 206 |
+
weights_path=predator_weights_path,
|
| 207 |
+
ransac_n_points=1000,
|
| 208 |
+
ransac_distance_threshold=0.05,
|
| 209 |
+
ransac_n=3,
|
| 210 |
+
sampling="prob",
|
| 211 |
+
mutual=False,
|
| 212 |
+
input_num_points=1024,
|
| 213 |
+
)
|
| 214 |
+
predator_results_all.append(predator_results)
|
| 215 |
+
predator_reg_results_all.append(predator_results_pc)
|
| 216 |
+
|
| 217 |
+
# GeoTransformer (ModelNet)
|
| 218 |
+
geotransformer_results_pc, geotransformer_results = geotransformer_registration_and_evaluation.geotransformer_reg_and_eval(
|
| 219 |
+
sources[i],
|
| 220 |
+
targets[i],
|
| 221 |
+
gt_transformation=gt_transformations[i],
|
| 222 |
+
geotransformer_root=geotransformer_root,
|
| 223 |
+
exp_subdir=geotransformer_exp_subdir,
|
| 224 |
+
weights_path=geotransformer_weights_path,
|
| 225 |
+
)
|
| 226 |
+
geotransformer_results_all.append(geotransformer_results)
|
| 227 |
+
geotransformer_reg_results_all.append(geotransformer_results_pc)
|
| 228 |
+
|
| 229 |
+
# LoGDesc
|
| 230 |
+
logdesc_results_pc, logdesc_results = logdesc_registration_and_evaluation.logdesc_reg_and_eval(
|
| 231 |
+
sources[i],
|
| 232 |
+
targets[i],
|
| 233 |
+
gt_transformation=gt_transformations[i],
|
| 234 |
+
logdesc_root=logdesc_root,
|
| 235 |
+
weights_path=logdesc_weights_path,
|
| 236 |
+
max_keypoints=768,
|
| 237 |
+
num_points_per_sample=128,
|
| 238 |
+
sample_radius=0.3,
|
| 239 |
+
topk_matches=128,
|
| 240 |
+
use_kpt=False,
|
| 241 |
+
)
|
| 242 |
+
logdesc_results_all.append(logdesc_results)
|
| 243 |
+
logdesc_reg_results_all.append(logdesc_results_pc)
|
| 244 |
+
|
| 245 |
+
# RegTR (ModelNet)
|
| 246 |
+
regtr_results_pc, regtr_results = regtr_registration_and_evaluation.regtr_reg_and_eval(
|
| 247 |
+
sources[i],
|
| 248 |
+
targets[i],
|
| 249 |
+
gt_transformation=gt_transformations[i],
|
| 250 |
+
regtr_root=regtr_root,
|
| 251 |
+
ckpt_path=regtr_ckpt_path,
|
| 252 |
+
config_path=regtr_config_path,
|
| 253 |
+
)
|
| 254 |
+
regtr_results_all.append(regtr_results)
|
| 255 |
+
regtr_reg_results_all.append(regtr_results_pc)
|
| 256 |
+
|
| 257 |
+
# R3PM-Net (ours) - ZS - no training
|
| 258 |
+
r3pm_net_results_pc, r3pm_net_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
|
| 259 |
+
sources[i], targets[i], 'r3pmnet', gt_transformations[i], r3pm_net_args)
|
| 260 |
+
r3pm_net_results_all.append(r3pm_net_results)
|
| 261 |
+
r3pm_net_reg_results_all.append(r3pm_net_results_pc)
|
| 262 |
+
|
| 263 |
+
# R3PM-Net (ours) - FT
|
| 264 |
+
tuned_r3pm_net_results_pc, tuned_r3pm_net_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
|
| 265 |
+
sources[i], targets[i], 'r3pmnet', gt_transformations[i], tuned_r3pm_net_args)
|
| 266 |
+
tuned_r3pm_net_results_all.append(tuned_r3pm_net_results)
|
| 267 |
+
tuned_r3pm_net_reg_results_all.append(tuned_r3pm_net_results_pc)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# Convert results to numpy arrays for easier manipulation
|
| 271 |
+
rpm_results_all = np.array(rpm_results_all)
|
| 272 |
+
predator_results_all = np.array(predator_results_all)
|
| 273 |
+
geotransformer_results_all = np.array(geotransformer_results_all)
|
| 274 |
+
logdesc_results_all = np.array(logdesc_results_all)
|
| 275 |
+
regtr_results_all = np.array(regtr_results_all)
|
| 276 |
+
r3pm_net_results_all = np.array(r3pm_net_results_all)
|
| 277 |
+
tuned_r3pm_net_results_all = np.array(tuned_r3pm_net_results_all)
|
| 278 |
+
|
| 279 |
+
rpm_mean_results = np.mean(rpm_results_all, axis=0)
|
| 280 |
+
predator_mean_results = np.mean(predator_results_all, axis=0)
|
| 281 |
+
geotransformer_mean_results = np.mean(geotransformer_results_all, axis=0)
|
| 282 |
+
logdesc_mean_results = np.mean(logdesc_results_all, axis=0)
|
| 283 |
+
regtr_mean_results = np.mean(regtr_results_all, axis=0)
|
| 284 |
+
r3pm_net_mean_results = np.mean(r3pm_net_results_all, axis=0)
|
| 285 |
+
tuned_r3pm_net_mean_results = np.mean(tuned_r3pm_net_results_all, axis=0)
|
| 286 |
+
|
| 287 |
+
# Print the results
|
| 288 |
+
metric_names = ['mean_rmse', 'mean_rotation_error', 'mean_translation_error',
|
| 289 |
+
'mean_computation_time', 'mean_cd', 'mean_error',
|
| 290 |
+
'mean_fitness', 'mean_inlier_rmse']
|
| 291 |
+
|
| 292 |
+
reports = {
|
| 293 |
+
"RPMNet": dict(zip(metric_names, rpm_mean_results)),
|
| 294 |
+
"Predator": dict(zip(metric_names, predator_mean_results)),
|
| 295 |
+
"GeoTransformer": dict(zip(metric_names, geotransformer_mean_results)),
|
| 296 |
+
"LoGDesc": dict(zip(metric_names, logdesc_mean_results)),
|
| 297 |
+
"RegTR": dict(zip(metric_names, regtr_mean_results)),
|
| 298 |
+
"R3PM-Net (ours) (ZS)": dict(zip(metric_names, r3pm_net_mean_results)),
|
| 299 |
+
"R3PM-Net (ours) (FT)": dict(zip(metric_names, tuned_r3pm_net_mean_results)),}
|
| 300 |
+
|
| 301 |
+
# Print the table
|
| 302 |
+
print_results.print_table(reports)
|
scripts/eval_sioux_scans.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import copy
|
| 3 |
+
import argparse
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
import torch
|
| 7 |
+
from tabulate import tabulate
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 13 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 14 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 15 |
+
|
| 16 |
+
from tools import data, l3d_helper, visualization
|
| 17 |
+
from tools import icp_registration_and_evaluation, l3d_registration_and_evaluation, predator_registration_and_evaluation, geotransformer_registration_and_evaluation, logdesc_registration_and_evaluation, regtr_registration_and_evaluation
|
| 18 |
+
from r3pm_net.config_loader import get_pretrained_rpmnet_dir, get_sioux_data_root, get_method_paths
|
| 19 |
+
|
| 20 |
+
'''
|
| 21 |
+
This script is used to evaluate the performance of the pipeline with R3PM-Net as global and GICP as local registeration.
|
| 22 |
+
|
| 23 |
+
The script takes the following arguments:
|
| 24 |
+
--local_reg: the local registration method to be used.
|
| 25 |
+
--seed: random seed for python/numpy/torch. The default is 42.
|
| 26 |
+
--verbose: if set to True, the results will be printed in a table format. The default is False.
|
| 27 |
+
'''
|
| 28 |
+
def set_seed(seed: int) -> None:
|
| 29 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
| 30 |
+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
| 31 |
+
|
| 32 |
+
random.seed(seed)
|
| 33 |
+
np.random.seed(seed)
|
| 34 |
+
torch.manual_seed(seed)
|
| 35 |
+
torch.cuda.manual_seed_all(seed)
|
| 36 |
+
|
| 37 |
+
torch.backends.cudnn.benchmark = False
|
| 38 |
+
torch.backends.cudnn.deterministic = True
|
| 39 |
+
torch.use_deterministic_algorithms(True)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# arguments
|
| 43 |
+
parser = argparse.ArgumentParser(description="Choosing local registration method")
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--local_reg", type=str, default="gicp", help="local registration: gicp or freg"
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument("--seed", type=int, default=42, help="random seed (default: 42)")
|
| 48 |
+
|
| 49 |
+
args = parser.parse_args()
|
| 50 |
+
set_seed(args.seed)
|
| 51 |
+
print(f"Using {args.local_reg} for local registration")
|
| 52 |
+
|
| 53 |
+
def analyze_results(results: dict, recall_threshold = 1, rmse_threshold = 0.053, verbose = False): # change the default values to your needs
|
| 54 |
+
table = []
|
| 55 |
+
fail_count = 0
|
| 56 |
+
success_count = 0
|
| 57 |
+
for object, values in results.items():
|
| 58 |
+
row = [object] + list(values)
|
| 59 |
+
if round(row[2], 3) < recall_threshold or round(row[3], 3) > rmse_threshold:
|
| 60 |
+
status = 'failed'
|
| 61 |
+
fail_count += 1
|
| 62 |
+
print(f'No match for {object}! Try a different method. If the issue persists, please check the data.')
|
| 63 |
+
else:
|
| 64 |
+
status = 'success'
|
| 65 |
+
success_count += 1
|
| 66 |
+
print(f'Found match for {object}!')
|
| 67 |
+
row.append(status)
|
| 68 |
+
table.append(row)
|
| 69 |
+
|
| 70 |
+
if verbose:
|
| 71 |
+
print(tabulate(table, headers=['Object', 'Chamfer Distance', 'Reg. Recall', 'Inlier RMSE', 'Computation Time', 'Status'], tablefmt='grid'))
|
| 72 |
+
print(f"Success rate: {success_count / (success_count + fail_count) * 100:.2f}%")
|
| 73 |
+
|
| 74 |
+
return table
|
| 75 |
+
|
| 76 |
+
def show_successful_resutls(table, sources, targets, pc_results, method_name = None):
|
| 77 |
+
for i in range (len(table)):
|
| 78 |
+
if table[i][-1] == 'success':
|
| 79 |
+
# visualization.plot_point_cloud(sources[i], targets[i], list(pc_results.values())[i]) # uncomment if below visualization does not work
|
| 80 |
+
visualization.draw_registration_result(targets[i], list(pc_results.values())[i], np.eye(4), method_name)
|
| 81 |
+
|
| 82 |
+
def main():
|
| 83 |
+
base_dir = get_sioux_data_root()
|
| 84 |
+
scan_dir = os.path.join(base_dir, 'sioux_scans')
|
| 85 |
+
cad_dir = os.path.join(base_dir, 'sioux_cranfield')
|
| 86 |
+
|
| 87 |
+
pcd_paths = [ os.path.join(scan_dir,'teeth_clean.ply'),
|
| 88 |
+
os.path.join(scan_dir,'lime_clean.ply'),
|
| 89 |
+
os.path.join(scan_dir,'cube_clean.ply'),
|
| 90 |
+
os.path.join(scan_dir,'lego_clean.ply'),
|
| 91 |
+
os.path.join(scan_dir,'elephant_clean.ply'),
|
| 92 |
+
os.path.join(scan_dir,'house_clean.ply'),
|
| 93 |
+
os.path.join(scan_dir,'shoe_clean.ply')]
|
| 94 |
+
|
| 95 |
+
cad_paths = [ os.path.join(cad_dir,'teeth.stl'),
|
| 96 |
+
os.path.join(cad_dir,'lime.stl'),
|
| 97 |
+
os.path.join(cad_dir,'cube.stl'),
|
| 98 |
+
os.path.join(cad_dir,'lego.stl'),
|
| 99 |
+
os.path.join(cad_dir,'elephant.stl'),
|
| 100 |
+
os.path.join(cad_dir,'house.stl'),
|
| 101 |
+
os.path.join(cad_dir,'shoe.stl')]
|
| 102 |
+
|
| 103 |
+
# Initialize lists and dictionaries to store results
|
| 104 |
+
rpm_net_results = {}
|
| 105 |
+
rpm_net_pc_results = {}
|
| 106 |
+
predator_results = {}
|
| 107 |
+
predator_pc_results = {}
|
| 108 |
+
geotransformer_results = {}
|
| 109 |
+
geotransformer_pc_results = {}
|
| 110 |
+
logdesc_results = {}
|
| 111 |
+
logdesc_pc_results = {}
|
| 112 |
+
regtr_results = {}
|
| 113 |
+
regtr_pc_results = {}
|
| 114 |
+
r3pm_net_results = {}
|
| 115 |
+
r3pm_net_pc_results ={}
|
| 116 |
+
tuned_r3pm_net_results = {}
|
| 117 |
+
tuned_r3pm_net_pc_results = {}
|
| 118 |
+
subset_tuned_r3pm_net_results = {}
|
| 119 |
+
subset_tuned_r3pm_net_pc_results = {}
|
| 120 |
+
|
| 121 |
+
sources = []
|
| 122 |
+
targets = []
|
| 123 |
+
|
| 124 |
+
pretrained_base_dir = get_pretrained_rpmnet_dir()
|
| 125 |
+
method_paths = get_method_paths()
|
| 126 |
+
_path_zs = os.path.join(pretrained_base_dir, "clean-trained.pth")
|
| 127 |
+
_path_ft = os.path.join(pretrained_base_dir, "best_model_PointNet2.t7") #TODO: CHANGE
|
| 128 |
+
_path_ft_sub = os.path.join(pretrained_base_dir, "best_model_PointNet_subset.t7") #TODO: CHANGE
|
| 129 |
+
|
| 130 |
+
# set arguments for models
|
| 131 |
+
rpm_args = l3d_helper.options(modelName="RPMNet")
|
| 132 |
+
rpm_args.pretrained = _path_zs
|
| 133 |
+
|
| 134 |
+
# OverlapPredator (used by Predator runner)
|
| 135 |
+
predator_cfg = method_paths.get("predator", {})
|
| 136 |
+
predator_root = predator_cfg.get("root")
|
| 137 |
+
predator_config_path = predator_cfg.get("config_path")
|
| 138 |
+
predator_weights_path = predator_cfg.get("weights_path")
|
| 139 |
+
|
| 140 |
+
# GeoTransformer
|
| 141 |
+
geo_cfg = method_paths.get("geotransformer", {})
|
| 142 |
+
geotransformer_root = geo_cfg.get("root")
|
| 143 |
+
geotransformer_exp_subdir = geo_cfg.get("exp_subdir")
|
| 144 |
+
geotransformer_weights_path = geo_cfg.get("weights_path")
|
| 145 |
+
|
| 146 |
+
# LoGDesc
|
| 147 |
+
logdesc_cfg = method_paths.get("logdesc", {})
|
| 148 |
+
logdesc_root = logdesc_cfg.get("root")
|
| 149 |
+
logdesc_weights_path = logdesc_cfg.get("weights_path")
|
| 150 |
+
|
| 151 |
+
# RegTR
|
| 152 |
+
regtr_cfg = method_paths.get("regtr", {})
|
| 153 |
+
regtr_root = regtr_cfg.get("root")
|
| 154 |
+
regtr_ckpt_path = regtr_cfg.get("ckpt_path")
|
| 155 |
+
regtr_config_path = regtr_cfg.get("config_path")
|
| 156 |
+
|
| 157 |
+
# R3PM-Net (ours) - no training
|
| 158 |
+
r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
|
| 159 |
+
r3pm_net_args.pretrained = _path_zs
|
| 160 |
+
|
| 161 |
+
# R3PM-Net (ours) (FT)
|
| 162 |
+
tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
|
| 163 |
+
tuned_r3pm_net_args.pretrained = _path_ft
|
| 164 |
+
|
| 165 |
+
# R3PM-Net (ours) (FT) (Subset)
|
| 166 |
+
subset_tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
|
| 167 |
+
subset_tuned_r3pm_net_args.pretrained = _path_ft_sub
|
| 168 |
+
|
| 169 |
+
for pcdPath, cadPath in tqdm(zip(pcd_paths, cad_paths), desc="Registering objects", total=len(pcd_paths)):
|
| 170 |
+
# Define the number of points to sample from the CAD model (change this based on your data)
|
| 171 |
+
if 'teeth' in pcdPath:
|
| 172 |
+
every_k_points = 100
|
| 173 |
+
key = 'teeth'
|
| 174 |
+
elif'lime' in pcdPath:
|
| 175 |
+
every_k_points = 100
|
| 176 |
+
key = 'lime'
|
| 177 |
+
elif 'cube' in pcdPath:
|
| 178 |
+
every_k_points = 1
|
| 179 |
+
key = 'cube'
|
| 180 |
+
elif 'lego' in pcdPath:
|
| 181 |
+
every_k_points = 10
|
| 182 |
+
key = 'lego'
|
| 183 |
+
elif 'elephant' in pcdPath:
|
| 184 |
+
every_k_points = 30
|
| 185 |
+
key = 'elephant'
|
| 186 |
+
elif 'house' in pcdPath:
|
| 187 |
+
every_k_points = 25
|
| 188 |
+
key = 'house'
|
| 189 |
+
elif 'shoe' in pcdPath:
|
| 190 |
+
every_k_points = 15
|
| 191 |
+
key = 'shoe'
|
| 192 |
+
else:
|
| 193 |
+
print("Unknown object type, using default every_k_points = 1")
|
| 194 |
+
every_k_points = 1
|
| 195 |
+
|
| 196 |
+
# Load the data
|
| 197 |
+
pcd, cad = data.load_data(pcdPath, cadPath, every_k_points=every_k_points)
|
| 198 |
+
source = copy.deepcopy(pcd)
|
| 199 |
+
target = copy.deepcopy(cad)
|
| 200 |
+
|
| 201 |
+
# Normalize the point clouds
|
| 202 |
+
source = data.normalize_pc(source)
|
| 203 |
+
target = data.normalize_pc(target)
|
| 204 |
+
|
| 205 |
+
sources.append(source)
|
| 206 |
+
targets.append(target)
|
| 207 |
+
|
| 208 |
+
gt_transformation = None
|
| 209 |
+
|
| 210 |
+
# Perform the registration
|
| 211 |
+
|
| 212 |
+
# RPMNet
|
| 213 |
+
rpm_pc_result, _ = l3d_registration_and_evaluation.l3d_reg_and_eval(
|
| 214 |
+
source, target, 'rpmnet', gt_transformation, rpm_args)
|
| 215 |
+
if args.local_reg == 'gicp':
|
| 216 |
+
final_rpm_net_pc_result, final_rpm_net_results = icp_registration_and_evaluation.icp_reg_and_eval(rpm_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
|
| 217 |
+
rpm_net_results[key] = final_rpm_net_results
|
| 218 |
+
rpm_net_pc_results[key] = final_rpm_net_pc_result
|
| 219 |
+
|
| 220 |
+
# OverlapPredator
|
| 221 |
+
predator_results_pc, _ = predator_registration_and_evaluation.predator_reg_and_eval(
|
| 222 |
+
source,
|
| 223 |
+
target,
|
| 224 |
+
gt_transformation=gt_transformation,
|
| 225 |
+
predator_root=predator_root,
|
| 226 |
+
config_path=predator_config_path,
|
| 227 |
+
weights_path=predator_weights_path,
|
| 228 |
+
ransac_n_points=1000,
|
| 229 |
+
ransac_distance_threshold=0.05,
|
| 230 |
+
ransac_n=3,
|
| 231 |
+
sampling="prob",
|
| 232 |
+
mutual=False,
|
| 233 |
+
input_num_points=1024,
|
| 234 |
+
)
|
| 235 |
+
if args.local_reg == 'gicp':
|
| 236 |
+
final_predator_pc_result, final_predator_results = icp_registration_and_evaluation.icp_reg_and_eval(predator_results_pc, target, 'gicp', 1, np.identity(4), gt_transformation)
|
| 237 |
+
predator_results[key] = final_predator_results
|
| 238 |
+
predator_pc_results[key] = final_predator_pc_result
|
| 239 |
+
|
| 240 |
+
# GeoTransformer (ModelNet)
|
| 241 |
+
geotransformer_pc_result, _ = geotransformer_registration_and_evaluation.geotransformer_reg_and_eval(
|
| 242 |
+
source,
|
| 243 |
+
target,
|
| 244 |
+
gt_transformation=gt_transformation,
|
| 245 |
+
geotransformer_root=geotransformer_root,
|
| 246 |
+
exp_subdir=geotransformer_exp_subdir,
|
| 247 |
+
weights_path=geotransformer_weights_path,
|
| 248 |
+
)
|
| 249 |
+
if args.local_reg == 'gicp':
|
| 250 |
+
final_geotransformer_pc_result, final_geotransformer_results = icp_registration_and_evaluation.icp_reg_and_eval(geotransformer_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
|
| 251 |
+
geotransformer_results[key] = final_geotransformer_results
|
| 252 |
+
geotransformer_pc_results[key] = final_geotransformer_pc_result
|
| 253 |
+
|
| 254 |
+
# LoGDesc
|
| 255 |
+
logdesc_pc_result, _ = logdesc_registration_and_evaluation.logdesc_reg_and_eval(
|
| 256 |
+
source,
|
| 257 |
+
target,
|
| 258 |
+
gt_transformation=gt_transformation,
|
| 259 |
+
logdesc_root=logdesc_root,
|
| 260 |
+
weights_path=logdesc_weights_path,
|
| 261 |
+
max_keypoints=768,
|
| 262 |
+
num_points_per_sample=128,
|
| 263 |
+
sample_radius=0.3,
|
| 264 |
+
topk_matches=128,
|
| 265 |
+
use_kpt=False,
|
| 266 |
+
)
|
| 267 |
+
if args.local_reg == 'gicp':
|
| 268 |
+
final_logdesc_pc_result, final_logdesc_results = icp_registration_and_evaluation.icp_reg_and_eval(logdesc_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
|
| 269 |
+
logdesc_results[key] = final_logdesc_results
|
| 270 |
+
logdesc_pc_results[key] = final_logdesc_pc_result
|
| 271 |
+
|
| 272 |
+
# RegTR (ModelNet)
|
| 273 |
+
regtr_pc_result, _ = regtr_registration_and_evaluation.regtr_reg_and_eval(
|
| 274 |
+
source,
|
| 275 |
+
target,
|
| 276 |
+
gt_transformation=gt_transformation,
|
| 277 |
+
regtr_root=regtr_root,
|
| 278 |
+
ckpt_path=regtr_ckpt_path,
|
| 279 |
+
config_path=regtr_config_path,
|
| 280 |
+
)
|
| 281 |
+
if args.local_reg == 'gicp':
|
| 282 |
+
final_regtr_pc_result, final_regtr_results = icp_registration_and_evaluation.icp_reg_and_eval(regtr_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
|
| 283 |
+
regtr_results[key] = final_regtr_results
|
| 284 |
+
regtr_pc_results[key] = final_regtr_pc_result
|
| 285 |
+
|
| 286 |
+
# R3PM-Net (ours) (ZS)
|
| 287 |
+
r3pm_net_pc_result, _ = l3d_registration_and_evaluation.l3d_reg_and_eval(source, target, 'r3pmnet', gt_transformation, r3pm_net_args)
|
| 288 |
+
if args.local_reg == 'gicp':
|
| 289 |
+
final_r3pm_net_pc_result, final_r3pm_net_results = icp_registration_and_evaluation.icp_reg_and_eval(r3pm_net_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
|
| 290 |
+
r3pm_net_results[key] = final_r3pm_net_results
|
| 291 |
+
r3pm_net_pc_results[key] = final_r3pm_net_pc_result
|
| 292 |
+
|
| 293 |
+
# R3PM-Net (ours) (FT)
|
| 294 |
+
tuned_r3pm_net_pc_result, _ = l3d_registration_and_evaluation.l3d_reg_and_eval(source, target, 'r3pmnet', gt_transformation, tuned_r3pm_net_args)
|
| 295 |
+
if args.local_reg == 'gicp':
|
| 296 |
+
final_tuned_r3pm_net_pc_result, final_tuned_r3pm_net_results = icp_registration_and_evaluation.icp_reg_and_eval(tuned_r3pm_net_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
|
| 297 |
+
tuned_r3pm_net_results[key] = final_tuned_r3pm_net_results
|
| 298 |
+
tuned_r3pm_net_pc_results[key] = final_tuned_r3pm_net_pc_result
|
| 299 |
+
|
| 300 |
+
# R3PM-Net (ours) (FT) (Subset)
|
| 301 |
+
subset_tuned_r3pm_net_pc_result, _ = l3d_registration_and_evaluation.l3d_reg_and_eval(source, target, 'r3pmnet', gt_transformation, subset_tuned_r3pm_net_args)
|
| 302 |
+
if args.local_reg == 'gicp':
|
| 303 |
+
final_subset_tuned_r3pm_net_pc_result, final_subset_tuned_r3pm_net_results = icp_registration_and_evaluation.icp_reg_and_eval(subset_tuned_r3pm_net_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
|
| 304 |
+
subset_tuned_r3pm_net_results[key] = final_subset_tuned_r3pm_net_results
|
| 305 |
+
subset_tuned_r3pm_net_pc_results[key] = final_subset_tuned_r3pm_net_pc_result
|
| 306 |
+
|
| 307 |
+
# Print the results
|
| 308 |
+
print("----- RPMNet: -----")
|
| 309 |
+
rpm_net_table = analyze_results(rpm_net_results, verbose=True)
|
| 310 |
+
show_successful_resutls(rpm_net_table, sources, targets, rpm_net_pc_results, 'RPMNet')
|
| 311 |
+
|
| 312 |
+
print("----- Predator: -----")
|
| 313 |
+
predator_table = analyze_results(predator_results, verbose=True)
|
| 314 |
+
show_successful_resutls(predator_table, sources, targets, predator_pc_results, 'Predator')
|
| 315 |
+
|
| 316 |
+
print("----- GeoTransformer: -----")
|
| 317 |
+
geotransformer_table = analyze_results(geotransformer_results, verbose=True)
|
| 318 |
+
show_successful_resutls(geotransformer_table, sources, targets, geotransformer_pc_results, 'GeoTransformer')
|
| 319 |
+
|
| 320 |
+
print("----- LoGDesc: -----")
|
| 321 |
+
logdesc_table = analyze_results(logdesc_results, verbose=True)
|
| 322 |
+
show_successful_resutls(logdesc_table, sources, targets, logdesc_pc_results, 'LoGDesc')
|
| 323 |
+
|
| 324 |
+
print("----- RegTR: -----")
|
| 325 |
+
regtr_table = analyze_results(regtr_results, verbose=True)
|
| 326 |
+
show_successful_resutls(regtr_table, sources, targets, regtr_pc_results, 'RegTR')
|
| 327 |
+
|
| 328 |
+
print("----- R3PM-Net (ours) (ZS): -----")
|
| 329 |
+
r3pm_net_table = analyze_results(r3pm_net_results, verbose=True)
|
| 330 |
+
show_successful_resutls(r3pm_net_table, sources, targets, r3pm_net_pc_results, 'R3PM-Net (ours) (ZS)')
|
| 331 |
+
|
| 332 |
+
print("----- R3PM-Net (ours) (FT): ----- ")
|
| 333 |
+
tuned_r3pm_net_table = analyze_results(tuned_r3pm_net_results, verbose=True)
|
| 334 |
+
show_successful_resutls(tuned_r3pm_net_table, sources, targets, tuned_r3pm_net_pc_results, 'R3PM-Net (ours) (FT)')
|
| 335 |
+
|
| 336 |
+
print("----- R3PM-Net (ours) (FT) (Subset): ----- ")
|
| 337 |
+
subset_tuned_r3pm_net_table = analyze_results(subset_tuned_r3pm_net_results, verbose=True)
|
| 338 |
+
show_successful_resutls(subset_tuned_r3pm_net_table, sources, targets, subset_tuned_r3pm_net_pc_results, 'R3PM-Net (ours) (FT) (Subset)')
|
| 339 |
+
|
| 340 |
+
if __name__ == "__main__":
|
| 341 |
+
main()
|
scripts/modelnet40.sh
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=gpu_h100
|
| 3 |
+
#SBATCH --gpus=1
|
| 4 |
+
#SBATCH --job-name=modelnet40
|
| 5 |
+
#SBATCH --ntasks=1
|
| 6 |
+
#SBATCH --time=09:00:00
|
| 7 |
+
#SBATCH --output=modelnet40_output_%A.txt
|
| 8 |
+
#SBATCH --error=modelnet40_error_%A.txt
|
| 9 |
+
|
| 10 |
+
# Load necessary modules (adjust based on your environment)
|
| 11 |
+
module purge
|
| 12 |
+
module load 2023
|
| 13 |
+
module load CUDA/12.1.1
|
| 14 |
+
|
| 15 |
+
# my miniconda3 path
|
| 16 |
+
export PATH="$HOME/miniconda3/bin:$PATH"
|
| 17 |
+
unset -f conda 2>/dev/null
|
| 18 |
+
source "$HOME/miniconda3/etc/profile.d/conda.sh"
|
| 19 |
+
|
| 20 |
+
# Activate the conda environment
|
| 21 |
+
conda activate r3pm_net
|
| 22 |
+
|
| 23 |
+
if [[ -n "${SLURM_SUBMIT_DIR:-}" ]]; then
|
| 24 |
+
REPO_ROOT="$(cd "${SLURM_SUBMIT_DIR}" && pwd)"
|
| 25 |
+
else
|
| 26 |
+
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 27 |
+
fi
|
| 28 |
+
cd "$REPO_ROOT" || { echo "ERROR: cannot cd to REPO_ROOT=${REPO_ROOT}" >&2; exit 1; }
|
| 29 |
+
if [[ ! -f "${REPO_ROOT}/pyproject.toml" ]]; then
|
| 30 |
+
echo "ERROR: REPO_ROOT=${REPO_ROOT} is not the r3pm_net tree (missing pyproject.toml)." >&2
|
| 31 |
+
echo "Run: cd /path/to/r3pm_net && sbatch scripts/modelnet40.sh" >&2
|
| 32 |
+
exit 1
|
| 33 |
+
fi
|
| 34 |
+
|
| 35 |
+
LOGDIR="${REPO_ROOT}/logs/slurm"
|
| 36 |
+
mkdir -p "$LOGDIR"
|
| 37 |
+
JOB_ID="${SLURM_JOB_ID:-local}"
|
| 38 |
+
|
| 39 |
+
# seeds=(42 61 92 114 123 456 789)
|
| 40 |
+
seeds=(42)
|
| 41 |
+
|
| 42 |
+
for seed in "${seeds[@]}"; do
|
| 43 |
+
srun python scripts/eval_modelnet40.py --seed "${seed}" \
|
| 44 |
+
>"${LOGDIR}/modelnet40_job${JOB_ID}_seed${seed}.log" 2>&1
|
| 45 |
+
done
|
scripts/sioux_cranfield.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=gpu_h100
|
| 3 |
+
#SBATCH --gpus=1
|
| 4 |
+
#SBATCH --job-name=sioux_cranfield
|
| 5 |
+
#SBATCH --ntasks=1
|
| 6 |
+
#SBATCH --time=04:00:00
|
| 7 |
+
#SBATCH --output=sioux_cranfield_output_%A.txt
|
| 8 |
+
#SBATCH --error=sioux_cranfield_error_%A.txt
|
| 9 |
+
|
| 10 |
+
# Load necessary modules (adjust based on your environment)
|
| 11 |
+
module purge
|
| 12 |
+
module load 2023
|
| 13 |
+
module load CUDA/12.1.1
|
| 14 |
+
|
| 15 |
+
# my miniconda3 path
|
| 16 |
+
export PATH="$HOME/miniconda3/bin:$PATH"
|
| 17 |
+
unset -f conda 2>/dev/null
|
| 18 |
+
source "$HOME/miniconda3/etc/profile.d/conda.sh"
|
| 19 |
+
|
| 20 |
+
# Activate the conda environment
|
| 21 |
+
conda activate r3pm_net
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if [[ -n "${SLURM_SUBMIT_DIR:-}" ]]; then
|
| 25 |
+
REPO_ROOT="$(cd "${SLURM_SUBMIT_DIR}" && pwd)"
|
| 26 |
+
else
|
| 27 |
+
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 28 |
+
fi
|
| 29 |
+
cd "$REPO_ROOT" || { echo "ERROR: cannot cd to REPO_ROOT=${REPO_ROOT}" >&2; exit 1; }
|
| 30 |
+
if [[ ! -f "${REPO_ROOT}/pyproject.toml" ]]; then
|
| 31 |
+
echo "ERROR: REPO_ROOT=${REPO_ROOT} is not the r3pm_net tree (missing pyproject.toml)." >&2
|
| 32 |
+
echo "Run: cd /path/to/r3pm_net && sbatch scripts/sioux_cranfield.sh" >&2
|
| 33 |
+
exit 1
|
| 34 |
+
fi
|
| 35 |
+
|
| 36 |
+
LOGDIR="${REPO_ROOT}/logs/slurm"
|
| 37 |
+
mkdir -p "$LOGDIR"
|
| 38 |
+
JOB_ID="${SLURM_JOB_ID:-local}"
|
| 39 |
+
|
| 40 |
+
# seeds=(42 61 92 114 123 456 789)
|
| 41 |
+
seeds=(42)
|
| 42 |
+
|
| 43 |
+
for seed in "${seeds[@]}"; do
|
| 44 |
+
srun python scripts/eval_sioux_cranfield.py --seed "${seed}" \
|
| 45 |
+
>"${LOGDIR}/sioux_cranfield_job${JOB_ID}_seed${seed}.log" 2>&1
|
| 46 |
+
done
|
scripts/sioux_scans.sh
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=gpu_h100
|
| 3 |
+
#SBATCH --gpus=1
|
| 4 |
+
#SBATCH --job-name=sioux_scans
|
| 5 |
+
#SBATCH --ntasks=1
|
| 6 |
+
#SBATCH --time=01:00:00
|
| 7 |
+
#SBATCH --output=sioux_scans_output_%A.txt
|
| 8 |
+
#SBATCH --error=sioux_scans_error_%A.txt
|
| 9 |
+
|
| 10 |
+
# Load necessary modules (adjust based on your environment)
|
| 11 |
+
module purge
|
| 12 |
+
module load 2023
|
| 13 |
+
module load CUDA/12.1.1
|
| 14 |
+
|
| 15 |
+
# my miniconda3 path
|
| 16 |
+
export PATH="$HOME/miniconda3/bin:$PATH"
|
| 17 |
+
unset -f conda 2>/dev/null
|
| 18 |
+
source "$HOME/miniconda3/etc/profile.d/conda.sh"
|
| 19 |
+
|
| 20 |
+
# Activate the conda environment
|
| 21 |
+
conda activate r3pm_net
|
| 22 |
+
|
| 23 |
+
if [[ -n "${SLURM_SUBMIT_DIR:-}" ]]; then
|
| 24 |
+
REPO_ROOT="$(cd "${SLURM_SUBMIT_DIR}" && pwd)"
|
| 25 |
+
else
|
| 26 |
+
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 27 |
+
fi
|
| 28 |
+
cd "$REPO_ROOT" || { echo "ERROR: cannot cd to REPO_ROOT=${REPO_ROOT}" >&2; exit 1; }
|
| 29 |
+
if [[ ! -f "${REPO_ROOT}/pyproject.toml" ]]; then
|
| 30 |
+
echo "ERROR: REPO_ROOT=${REPO_ROOT} is not the r3pm_net tree (missing pyproject.toml)." >&2
|
| 31 |
+
echo "Run: cd /path/to/r3pm_net && sbatch scripts/sioux_scans.sh" >&2
|
| 32 |
+
exit 1
|
| 33 |
+
fi
|
| 34 |
+
|
| 35 |
+
LOGDIR="${REPO_ROOT}/logs/slurm"
|
| 36 |
+
mkdir -p "$LOGDIR"
|
| 37 |
+
JOB_ID="${SLURM_JOB_ID:-local}"
|
| 38 |
+
|
| 39 |
+
# seeds=(42 61 92 114 123 456 789)
|
| 40 |
+
seeds=(42)
|
| 41 |
+
|
| 42 |
+
for seed in "${seeds[@]}"; do
|
| 43 |
+
srun python scripts/eval_sioux_scans.py --seed "${seed}" \
|
| 44 |
+
>"${LOGDIR}/sioux_scans_job${JOB_ID}_seed${seed}.log" 2>&1
|
| 45 |
+
done
|
src/train.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from tensorboardX import SummaryWriter
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
# Repository root on PYTHONPATH (for `python src/train.py` or srun).
|
| 14 |
+
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 15 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 16 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 17 |
+
|
| 18 |
+
from r3pm_net.model import R3PMNet
|
| 19 |
+
from r3pm_net.config_loader import parse_train_args, resolve_path_args
|
| 20 |
+
from r3pm_net.paths import REPO_ROOT
|
| 21 |
+
from thirdparty.learning3d.losses import FrobeniusNormLoss, RMSEFeaturesLoss
|
| 22 |
+
from dataloader.user_data import UserData
|
| 23 |
+
from r3pm_net.feature_extractor import feature_extractor # import your feature extractor here
|
| 24 |
+
|
| 25 |
+
def _init_(args):
|
| 26 |
+
Path(args.save_dir).mkdir(parents=True, exist_ok=True)
|
| 27 |
+
(REPO_ROOT / "checkpoints" / args.exp_name).mkdir(parents=True, exist_ok=True)
|
| 28 |
+
|
| 29 |
+
if os.path.isfile("main.py"):
|
| 30 |
+
os.system("cp main.py checkpoints" + "/" + args.exp_name + "/" + "main.py.backup")
|
| 31 |
+
if os.path.isfile("model.py"):
|
| 32 |
+
os.system("cp model.py checkpoints" + "/" + args.exp_name + "/" + "model.py.backup")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class IOStream:
|
| 36 |
+
def __init__(self, path):
|
| 37 |
+
self.f = open(path, "a")
|
| 38 |
+
|
| 39 |
+
def cprint(self, text):
|
| 40 |
+
print(text)
|
| 41 |
+
self.f.write(text + "\n")
|
| 42 |
+
self.f.flush()
|
| 43 |
+
|
| 44 |
+
def close(self):
|
| 45 |
+
self.f.close()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_one_epoch(device, model, test_loader):
|
| 49 |
+
model.eval()
|
| 50 |
+
test_loss = 0.0
|
| 51 |
+
count = 0
|
| 52 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 53 |
+
template, source, igt = data
|
| 54 |
+
|
| 55 |
+
template = template.to(device)
|
| 56 |
+
source = source.to(device)
|
| 57 |
+
igt = igt.to(device)
|
| 58 |
+
|
| 59 |
+
output = model(template, source)
|
| 60 |
+
loss_val = FrobeniusNormLoss()(output["est_T"], igt) + RMSEFeaturesLoss()(output["r"])
|
| 61 |
+
|
| 62 |
+
test_loss += loss_val.item()
|
| 63 |
+
count += 1
|
| 64 |
+
|
| 65 |
+
test_loss = float(test_loss) / count
|
| 66 |
+
return test_loss
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test(args, model, test_loader, textio):
|
| 70 |
+
test_loss = test_one_epoch(args.device, model, test_loader)
|
| 71 |
+
textio.cprint("Validation Loss: %f" % (test_loss))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def train_one_epoch(device, model, train_loader, optimizer):
|
| 75 |
+
model.train()
|
| 76 |
+
train_loss = 0.0
|
| 77 |
+
count = 0
|
| 78 |
+
for i, data in enumerate(tqdm(train_loader)):
|
| 79 |
+
template, source, igt = data
|
| 80 |
+
|
| 81 |
+
template = template.to(device)
|
| 82 |
+
source = source.to(device)
|
| 83 |
+
igt = igt.to(device)
|
| 84 |
+
|
| 85 |
+
output = model(template, source)
|
| 86 |
+
loss_val = FrobeniusNormLoss()(output["est_T"], igt) + RMSEFeaturesLoss()(output["r"])
|
| 87 |
+
|
| 88 |
+
optimizer.zero_grad()
|
| 89 |
+
loss_val.backward()
|
| 90 |
+
optimizer.step()
|
| 91 |
+
|
| 92 |
+
train_loss += loss_val.item()
|
| 93 |
+
count += 1
|
| 94 |
+
|
| 95 |
+
train_loss = float(train_loss) / count
|
| 96 |
+
return train_loss
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
|
| 100 |
+
Path(args.save_dir).mkdir(parents=True, exist_ok=True)
|
| 101 |
+
|
| 102 |
+
learnable_params = filter(lambda p: p.requires_grad, model.parameters())
|
| 103 |
+
if args.optimizer == "Adam":
|
| 104 |
+
optimizer = torch.optim.Adam(learnable_params)
|
| 105 |
+
else:
|
| 106 |
+
optimizer = torch.optim.SGD(learnable_params, lr=0.1)
|
| 107 |
+
|
| 108 |
+
if checkpoint is not None:
|
| 109 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
| 110 |
+
|
| 111 |
+
best_test_loss = np.inf
|
| 112 |
+
|
| 113 |
+
for epoch in range(args.start_epoch, args.epochs):
|
| 114 |
+
train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
|
| 115 |
+
test_loss = test_one_epoch(args.device, model, test_loader)
|
| 116 |
+
|
| 117 |
+
snap = {
|
| 118 |
+
"epoch": epoch + 1,
|
| 119 |
+
"model": model.state_dict(),
|
| 120 |
+
"min_loss": test_loss,
|
| 121 |
+
"optimizer": optimizer.state_dict(),
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
if test_loss < best_test_loss:
|
| 125 |
+
best_test_loss = test_loss
|
| 126 |
+
best_snap_path = os.path.join(
|
| 127 |
+
args.save_dir, "best_model_snap.t7")
|
| 128 |
+
best_model_path = os.path.join(
|
| 129 |
+
args.save_dir, "best_model.t7")
|
| 130 |
+
|
| 131 |
+
torch.save(snap, best_snap_path)
|
| 132 |
+
torch.save(model.state_dict(), best_model_path)
|
| 133 |
+
|
| 134 |
+
torch.save(snap, os.path.join(args.save_dir, "model_snap.t7"))
|
| 135 |
+
torch.save(model.state_dict(), os.path.join(args.save_dir, "model.t7"))
|
| 136 |
+
|
| 137 |
+
boardio.add_scalar("Train Loss", train_loss, epoch + 1)
|
| 138 |
+
boardio.add_scalar("Test Loss", test_loss, epoch + 1)
|
| 139 |
+
boardio.add_scalar("Best Test Loss", best_test_loss, epoch + 1)
|
| 140 |
+
|
| 141 |
+
textio.cprint(
|
| 142 |
+
"EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f"
|
| 143 |
+
% (epoch + 1, train_loss, test_loss, best_test_loss)
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def build_parser(default_config_path: str):
|
| 148 |
+
parser = argparse.ArgumentParser(description="Point Cloud Registration")
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--config",
|
| 151 |
+
type=str,
|
| 152 |
+
default=default_config_path,
|
| 153 |
+
help="YAML file with defaults (see config/default.yaml); can be overridden on the command line",
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--exp_name",
|
| 157 |
+
type=str,
|
| 158 |
+
default="exp_r3pmnet",
|
| 159 |
+
metavar="N",
|
| 160 |
+
help="Name of the experiment",
|
| 161 |
+
)
|
| 162 |
+
parser.add_argument("--eval", action="store_true", help="Run evaluation only (no training).")
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
"--save_dir",
|
| 165 |
+
type=str,
|
| 166 |
+
default="",
|
| 167 |
+
help="Directory to save model checkpoints (default: checkpoints/<exp_name>/models)",
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--num_points",
|
| 172 |
+
default=1024,
|
| 173 |
+
type=int,
|
| 174 |
+
metavar="N",
|
| 175 |
+
help="points in point-cloud (default: 1024)",
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
parser.add_argument(
|
| 179 |
+
"--fine_tune_feature_extractor",
|
| 180 |
+
default="tune",
|
| 181 |
+
type=str,
|
| 182 |
+
choices=["fixed", "tune"],
|
| 183 |
+
help="train feature extractor (default: tune)",
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--transfer_weights",
|
| 187 |
+
default="",
|
| 188 |
+
type=str,
|
| 189 |
+
metavar="PATH",
|
| 190 |
+
help="optional path to feature extractor checkpoint",
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"--symfn",
|
| 194 |
+
default="max",
|
| 195 |
+
choices=["max", "avg"],
|
| 196 |
+
help="symmetric function (default: max)",
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
parser.add_argument("--seed", type=int, default=1234)
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"-j",
|
| 202 |
+
"--workers",
|
| 203 |
+
default=4,
|
| 204 |
+
type=int,
|
| 205 |
+
metavar="N",
|
| 206 |
+
help="number of data loading workers (default: 4)",
|
| 207 |
+
)
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"-b",
|
| 210 |
+
"--batch_size",
|
| 211 |
+
default=5,
|
| 212 |
+
type=int,
|
| 213 |
+
metavar="N",
|
| 214 |
+
help="mini-batch size (default: 5)",
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--epochs",
|
| 218 |
+
default=50,
|
| 219 |
+
type=int,
|
| 220 |
+
metavar="N",
|
| 221 |
+
help="number of total epochs to run",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--start_epoch",
|
| 225 |
+
default=0,
|
| 226 |
+
type=int,
|
| 227 |
+
metavar="N",
|
| 228 |
+
help="manual epoch number (useful on restarts)",
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--optimizer",
|
| 232 |
+
default="Adam",
|
| 233 |
+
choices=["Adam", "SGD"],
|
| 234 |
+
metavar="METHOD",
|
| 235 |
+
help="name of an optimizer (default: Adam)",
|
| 236 |
+
)
|
| 237 |
+
parser.add_argument(
|
| 238 |
+
"--resume",
|
| 239 |
+
default="",
|
| 240 |
+
type=str,
|
| 241 |
+
metavar="PATH",
|
| 242 |
+
help="path to latest checkpoint (default: none)",
|
| 243 |
+
)
|
| 244 |
+
parser.add_argument(
|
| 245 |
+
"--pretrained",
|
| 246 |
+
default="",
|
| 247 |
+
type=str,
|
| 248 |
+
metavar="PATH",
|
| 249 |
+
help="path to pretrained full model (default: none)",
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--device",
|
| 253 |
+
default="cuda:0",
|
| 254 |
+
type=str,
|
| 255 |
+
metavar="DEVICE",
|
| 256 |
+
help="use CUDA if available",
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
parser.add_argument(
|
| 260 |
+
"--train_dict_path",
|
| 261 |
+
type=str,
|
| 262 |
+
default="data/simulators/data_dict_train.pkl",
|
| 263 |
+
help="Pickled training data_dict",
|
| 264 |
+
)
|
| 265 |
+
parser.add_argument(
|
| 266 |
+
"--test_dict_path",
|
| 267 |
+
type=str,
|
| 268 |
+
default="data/simulators/data_dict_test.pkl",
|
| 269 |
+
help="Pickled test data_dict",
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
return parser
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _torch_load(path, map_location):
|
| 276 |
+
try:
|
| 277 |
+
return torch.load(path, map_location=map_location, weights_only=False)
|
| 278 |
+
except TypeError:
|
| 279 |
+
return torch.load(path, map_location=map_location)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def main():
|
| 283 |
+
args = parse_train_args(sys.argv[1:], build_parser)
|
| 284 |
+
|
| 285 |
+
resolve_path_args(
|
| 286 |
+
args,
|
| 287 |
+
(
|
| 288 |
+
"save_dir",
|
| 289 |
+
"train_dict_path",
|
| 290 |
+
"test_dict_path",
|
| 291 |
+
"resume",
|
| 292 |
+
"pretrained",
|
| 293 |
+
"transfer_weights",
|
| 294 |
+
),
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
if not args.save_dir:
|
| 298 |
+
args.save_dir = str(REPO_ROOT / "checkpoints" / args.exp_name / "models")
|
| 299 |
+
|
| 300 |
+
torch.backends.cudnn.deterministic = True
|
| 301 |
+
torch.manual_seed(args.seed)
|
| 302 |
+
torch.cuda.manual_seed_all(args.seed)
|
| 303 |
+
np.random.seed(args.seed)
|
| 304 |
+
|
| 305 |
+
ckpt_dir = REPO_ROOT / "checkpoints" / args.exp_name
|
| 306 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 307 |
+
boardio = SummaryWriter(log_dir=str(ckpt_dir))
|
| 308 |
+
_init_(args)
|
| 309 |
+
|
| 310 |
+
textio = IOStream(str(ckpt_dir / "run.log"))
|
| 311 |
+
textio.cprint(str(args))
|
| 312 |
+
|
| 313 |
+
if not os.path.isfile(args.train_dict_path):
|
| 314 |
+
raise FileNotFoundError(f"Training dict not found: {args.train_dict_path}")
|
| 315 |
+
if not os.path.isfile(args.test_dict_path):
|
| 316 |
+
raise FileNotFoundError(f"Test dict not found: {args.test_dict_path}")
|
| 317 |
+
|
| 318 |
+
with open(args.train_dict_path, "rb") as f:
|
| 319 |
+
data_dict_train = pickle.load(f)
|
| 320 |
+
with open(args.test_dict_path, "rb") as f:
|
| 321 |
+
data_dict_test = pickle.load(f)
|
| 322 |
+
|
| 323 |
+
trainset = UserData("registration", data_dict_train)
|
| 324 |
+
testset = UserData("registration", data_dict_test)
|
| 325 |
+
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=False, drop_last=True, num_workers=args.workers)
|
| 326 |
+
test_loader = DataLoader(testset, batch_size=5, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 327 |
+
|
| 328 |
+
if not torch.cuda.is_available():
|
| 329 |
+
args.device = "cpu"
|
| 330 |
+
args.device = torch.device(args.device)
|
| 331 |
+
|
| 332 |
+
# feature extractor model
|
| 333 |
+
FEATURE_MODEL = feature_extractor
|
| 334 |
+
model = R3PMNet(feature_model=FEATURE_MODEL)
|
| 335 |
+
model = model.to(args.device)
|
| 336 |
+
|
| 337 |
+
if args.transfer_weights and os.path.isfile(args.transfer_weights):
|
| 338 |
+
feat_model_dict = _torch_load(args.transfer_weights, args.device)
|
| 339 |
+
model.feat_extractor.load_state_dict(feat_model_dict)
|
| 340 |
+
|
| 341 |
+
checkpoint = None
|
| 342 |
+
if args.resume:
|
| 343 |
+
assert os.path.isfile(args.resume)
|
| 344 |
+
checkpoint = _torch_load(args.resume, args.device)
|
| 345 |
+
args.start_epoch = checkpoint["epoch"]
|
| 346 |
+
model.load_state_dict(checkpoint["model"])
|
| 347 |
+
|
| 348 |
+
if args.pretrained:
|
| 349 |
+
assert os.path.isfile(args.pretrained)
|
| 350 |
+
try:
|
| 351 |
+
model.load_state_dict(_torch_load(args.pretrained, "cpu"))
|
| 352 |
+
except RuntimeError:
|
| 353 |
+
model_data = _torch_load(args.pretrained, "cpu")
|
| 354 |
+
state_dict = model_data["state_dict"]
|
| 355 |
+
model.load_state_dict(state_dict)
|
| 356 |
+
model.to(args.device)
|
| 357 |
+
|
| 358 |
+
Path(args.save_dir).mkdir(parents=True, exist_ok=True)
|
| 359 |
+
|
| 360 |
+
if args.eval:
|
| 361 |
+
test(args, model, test_loader, textio)
|
| 362 |
+
else:
|
| 363 |
+
train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
|
| 364 |
+
|
| 365 |
+
if __name__ == "__main__":
|
| 366 |
+
main()
|
thirdparty/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Namespace for vendored thirdparty.learning3d
|
thirdparty/learning3d/data_utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .dataloaders import ModelNet40Data
|
| 2 |
+
from .dataloaders import ClassificationData, RegistrationData, SegmentationData, FlowData, SceneflowDataset
|
| 3 |
+
from .dataloaders import download_modelnet40, deg_to_rad, create_random_transform
|
| 4 |
+
from .user_data import UserData
|
thirdparty/learning3d/data_utils/dataloaders.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
import h5py
|
| 9 |
+
import subprocess
|
| 10 |
+
import shlex
|
| 11 |
+
import json
|
| 12 |
+
import glob
|
| 13 |
+
from .. ops import transform_functions, se3
|
| 14 |
+
from sklearn.neighbors import NearestNeighbors
|
| 15 |
+
from scipy.spatial.distance import minkowski
|
| 16 |
+
from scipy.spatial import cKDTree
|
| 17 |
+
from torch.utils.data import Dataset
|
| 18 |
+
|
| 19 |
+
def download_modelnet40():
|
| 20 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 21 |
+
DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
|
| 22 |
+
if not os.path.exists(DATA_DIR):
|
| 23 |
+
os.mkdir(DATA_DIR)
|
| 24 |
+
if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
|
| 25 |
+
www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
|
| 26 |
+
zipfile = os.path.basename(www)
|
| 27 |
+
os.system('wget --no-check-certificate %s; unzip %s' % (www, zipfile))
|
| 28 |
+
os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
|
| 29 |
+
os.system('rm %s' % (zipfile))
|
| 30 |
+
|
| 31 |
+
def load_data(train, use_normals):
|
| 32 |
+
if train: partition = 'train'
|
| 33 |
+
else: partition = 'test'
|
| 34 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 35 |
+
DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
|
| 36 |
+
all_data = []
|
| 37 |
+
all_label = []
|
| 38 |
+
for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)):
|
| 39 |
+
f = h5py.File(h5_name)
|
| 40 |
+
if use_normals: data = np.concatenate([f['data'][:], f['normal'][:]], axis=-1).astype('float32')
|
| 41 |
+
else: data = f['data'][:].astype('float32')
|
| 42 |
+
label = f['label'][:].astype('int64')
|
| 43 |
+
f.close()
|
| 44 |
+
all_data.append(data)
|
| 45 |
+
all_label.append(label)
|
| 46 |
+
all_data = np.concatenate(all_data, axis=0)
|
| 47 |
+
all_label = np.concatenate(all_label, axis=0)
|
| 48 |
+
return all_data, all_label
|
| 49 |
+
|
| 50 |
+
def deg_to_rad(deg):
|
| 51 |
+
return np.pi / 180 * deg
|
| 52 |
+
|
| 53 |
+
def create_random_transform(dtype, max_rotation_deg, max_translation):
|
| 54 |
+
max_rotation = deg_to_rad(max_rotation_deg)
|
| 55 |
+
rot = np.random.uniform(-max_rotation, max_rotation, [1, 3])
|
| 56 |
+
trans = np.random.uniform(-max_translation, max_translation, [1, 3])
|
| 57 |
+
quat = transform_functions.euler_to_quaternion(rot, "xyz")
|
| 58 |
+
|
| 59 |
+
vec = np.concatenate([quat, trans], axis=1)
|
| 60 |
+
vec = torch.tensor(vec, dtype=dtype)
|
| 61 |
+
return vec
|
| 62 |
+
|
| 63 |
+
def jitter_pointcloud(pointcloud, sigma=0.04, clip=0.05):
|
| 64 |
+
# N, C = pointcloud.shape
|
| 65 |
+
sigma = 0.04*np.random.random_sample()
|
| 66 |
+
pointcloud += torch.empty(pointcloud.shape).normal_(mean=0, std=sigma).clamp(-clip, clip)
|
| 67 |
+
return pointcloud
|
| 68 |
+
|
| 69 |
+
def farthest_subsample_points(pointcloud1, num_subsampled_points=768):
|
| 70 |
+
pointcloud1 = pointcloud1
|
| 71 |
+
num_points = pointcloud1.shape[0]
|
| 72 |
+
nbrs1 = NearestNeighbors(n_neighbors=num_subsampled_points, algorithm='auto',
|
| 73 |
+
metric=lambda x, y: minkowski(x, y)).fit(pointcloud1[:, :3])
|
| 74 |
+
random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1])
|
| 75 |
+
idx1 = nbrs1.kneighbors(random_p1, return_distance=False).reshape((num_subsampled_points,))
|
| 76 |
+
gt_mask = torch.zeros(num_points).scatter_(0, torch.tensor(idx1), 1)
|
| 77 |
+
return pointcloud1[idx1, :], gt_mask
|
| 78 |
+
|
| 79 |
+
def uniform_2_sphere(num: int = None):
|
| 80 |
+
"""Uniform sampling on a 2-sphere
|
| 81 |
+
|
| 82 |
+
Source: https://gist.github.com/andrewbolster/10274979
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
num: Number of vectors to sample (or None if single)
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Random Vector (np.ndarray) of size (num, 3) with norm 1.
|
| 89 |
+
If num is None returned value will have size (3,)
|
| 90 |
+
|
| 91 |
+
"""
|
| 92 |
+
if num is not None:
|
| 93 |
+
phi = np.random.uniform(0.0, 2 * np.pi, num)
|
| 94 |
+
cos_theta = np.random.uniform(-1.0, 1.0, num)
|
| 95 |
+
else:
|
| 96 |
+
phi = np.random.uniform(0.0, 2 * np.pi)
|
| 97 |
+
cos_theta = np.random.uniform(-1.0, 1.0)
|
| 98 |
+
|
| 99 |
+
theta = np.arccos(cos_theta)
|
| 100 |
+
x = np.sin(theta) * np.cos(phi)
|
| 101 |
+
y = np.sin(theta) * np.sin(phi)
|
| 102 |
+
z = np.cos(theta)
|
| 103 |
+
|
| 104 |
+
return np.stack((x, y, z), axis=-1)
|
| 105 |
+
|
| 106 |
+
def planar_crop(points, p_keep= 0.7):
|
| 107 |
+
p_keep = np.array(p_keep, dtype=np.float32)
|
| 108 |
+
|
| 109 |
+
rand_xyz = uniform_2_sphere()
|
| 110 |
+
pts = points.numpy()
|
| 111 |
+
centroid = np.mean(pts[:, :3], axis=0)
|
| 112 |
+
points_centered = pts[:, :3] - centroid
|
| 113 |
+
|
| 114 |
+
dist_from_plane = np.dot(points_centered, rand_xyz)
|
| 115 |
+
|
| 116 |
+
mask = dist_from_plane > np.percentile(dist_from_plane, (1.0 - p_keep) * 100)
|
| 117 |
+
idx_x = torch.Tensor(np.nonzero(mask))
|
| 118 |
+
|
| 119 |
+
return torch.Tensor(pts[mask, :3]), idx_x
|
| 120 |
+
|
| 121 |
+
def knn_idx(pts, k):
|
| 122 |
+
kdt = cKDTree(pts)
|
| 123 |
+
_, idx = kdt.query(pts, k=k+1)
|
| 124 |
+
return idx[:, 1:]
|
| 125 |
+
|
| 126 |
+
def get_rri(pts, k):
|
| 127 |
+
# pts: N x 3, original points
|
| 128 |
+
# q: N x K x 3, nearest neighbors
|
| 129 |
+
q = pts[knn_idx(pts, k)]
|
| 130 |
+
p = np.repeat(pts[:, None], k, axis=1)
|
| 131 |
+
# rp, rq: N x K x 1, norms
|
| 132 |
+
rp = np.linalg.norm(p, axis=-1, keepdims=True)
|
| 133 |
+
rq = np.linalg.norm(q, axis=-1, keepdims=True)
|
| 134 |
+
pn = p / rp
|
| 135 |
+
qn = q / rq
|
| 136 |
+
dot = np.sum(pn * qn, -1, keepdims=True)
|
| 137 |
+
# theta: N x K x 1, angles
|
| 138 |
+
theta = np.arccos(np.clip(dot, -1, 1))
|
| 139 |
+
T_q = q - dot * p
|
| 140 |
+
sin_psi = np.sum(np.cross(T_q[:, None], T_q[:, :, None]) * pn[:, None], -1)
|
| 141 |
+
cos_psi = np.sum(T_q[:, None] * T_q[:, :, None], -1)
|
| 142 |
+
psi = np.arctan2(sin_psi, cos_psi) % (2*np.pi)
|
| 143 |
+
idx = np.argpartition(psi, 1)[:, :, 1:2]
|
| 144 |
+
# phi: N x K x 1, projection angles
|
| 145 |
+
phi = np.take_along_axis(psi, idx, axis=-1)
|
| 146 |
+
feat = np.concatenate([rp, rq, theta, phi], axis=-1)
|
| 147 |
+
return feat.reshape(-1, k * 4)
|
| 148 |
+
|
| 149 |
+
def get_rri_cuda(pts, k, npts_per_block=1):
|
| 150 |
+
try:
|
| 151 |
+
import pycuda.autoinit
|
| 152 |
+
from pycuda import gpuarray
|
| 153 |
+
from pycuda.compiler import SourceModule
|
| 154 |
+
except Exception as e:
|
| 155 |
+
print("Error raised in pycuda modules! pycuda only works with GPU, ", e)
|
| 156 |
+
raise
|
| 157 |
+
|
| 158 |
+
mod_rri = SourceModule(open('rri.cu').read() % (k, npts_per_block))
|
| 159 |
+
rri_cuda = mod_rri.get_function('get_rri_feature')
|
| 160 |
+
|
| 161 |
+
N = len(pts)
|
| 162 |
+
pts_gpu = gpuarray.to_gpu(pts.astype(np.float32).ravel())
|
| 163 |
+
k_idx = knn_idx(pts, k)
|
| 164 |
+
k_idx_gpu = gpuarray.to_gpu(k_idx.astype(np.int32).ravel())
|
| 165 |
+
feat_gpu = gpuarray.GPUArray((N * k * 4,), np.float32)
|
| 166 |
+
|
| 167 |
+
rri_cuda(pts_gpu, np.int32(N), k_idx_gpu, feat_gpu,
|
| 168 |
+
grid=(((N-1) // npts_per_block)+1, 1),
|
| 169 |
+
block=(npts_per_block, k, 1))
|
| 170 |
+
|
| 171 |
+
feat = feat_gpu.get().reshape(N, k * 4).astype(np.float32)
|
| 172 |
+
return feat
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class UnknownDataTypeError(Exception):
|
| 176 |
+
def __init__(self, *args):
|
| 177 |
+
if args: self.message = args[0]
|
| 178 |
+
else: self.message = 'Datatype not understood for dataset.'
|
| 179 |
+
|
| 180 |
+
def __str__(self):
|
| 181 |
+
return self.message
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class ModelNet40Data(Dataset):
|
| 185 |
+
def __init__(
|
| 186 |
+
self,
|
| 187 |
+
train=True,
|
| 188 |
+
num_points=1024,
|
| 189 |
+
download=True,
|
| 190 |
+
randomize_data=False,
|
| 191 |
+
use_normals=False
|
| 192 |
+
):
|
| 193 |
+
super(ModelNet40Data, self).__init__()
|
| 194 |
+
if download: download_modelnet40()
|
| 195 |
+
self.data, self.labels = load_data(train, use_normals)
|
| 196 |
+
if not train: self.shapes = self.read_classes_ModelNet40()
|
| 197 |
+
self.num_points = num_points
|
| 198 |
+
self.randomize_data = randomize_data
|
| 199 |
+
|
| 200 |
+
def __getitem__(self, idx):
|
| 201 |
+
if self.randomize_data: current_points = self.randomize(idx)
|
| 202 |
+
else: current_points = self.data[idx].copy()
|
| 203 |
+
|
| 204 |
+
current_points = torch.from_numpy(current_points[:self.num_points, :]).float()
|
| 205 |
+
label = torch.from_numpy(self.labels[idx]).type(torch.LongTensor)
|
| 206 |
+
|
| 207 |
+
return current_points, label
|
| 208 |
+
|
| 209 |
+
def __len__(self):
|
| 210 |
+
return self.data.shape[0]
|
| 211 |
+
|
| 212 |
+
def randomize(self, idx):
|
| 213 |
+
pt_idxs = np.arange(0, self.num_points)
|
| 214 |
+
np.random.shuffle(pt_idxs)
|
| 215 |
+
return self.data[idx, pt_idxs].copy()
|
| 216 |
+
|
| 217 |
+
def get_shape(self, label):
|
| 218 |
+
return self.shapes[label]
|
| 219 |
+
|
| 220 |
+
def read_classes_ModelNet40(self):
|
| 221 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 222 |
+
DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
|
| 223 |
+
file = open(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'shape_names.txt'), 'r')
|
| 224 |
+
shape_names = file.read()
|
| 225 |
+
shape_names = np.array(shape_names.split('\n')[:-1])
|
| 226 |
+
return shape_names
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class ClassificationData(Dataset):
|
| 230 |
+
def __init__(self, data_class=ModelNet40Data()):
|
| 231 |
+
super(ClassificationData, self).__init__()
|
| 232 |
+
self.set_class(data_class)
|
| 233 |
+
|
| 234 |
+
def __len__(self):
|
| 235 |
+
return len(self.data_class)
|
| 236 |
+
|
| 237 |
+
def set_class(self, data_class):
|
| 238 |
+
self.data_class = data_class
|
| 239 |
+
|
| 240 |
+
def get_shape(self, label):
|
| 241 |
+
try:
|
| 242 |
+
return self.data_class.get_shape(label)
|
| 243 |
+
except:
|
| 244 |
+
return -1
|
| 245 |
+
|
| 246 |
+
def __getitem__(self, index):
|
| 247 |
+
return self.data_class[index]
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class RegistrationData(Dataset):
|
| 251 |
+
def __init__(self, algorithm, data_class=ModelNet40Data(), partial_source=False, partial_template=False, noise=False, additional_params={}):
|
| 252 |
+
super(RegistrationData, self).__init__()
|
| 253 |
+
available_algorithms = ['PCRNet', 'PointNetLK', 'DCP', 'PRNet', 'iPCRNet', 'RPMNet', 'DeepGMR']
|
| 254 |
+
if algorithm in available_algorithms: self.algorithm = algorithm
|
| 255 |
+
else: raise Exception("Algorithm not available for registration.")
|
| 256 |
+
|
| 257 |
+
self.set_class(data_class)
|
| 258 |
+
self.partial_template = partial_template
|
| 259 |
+
self.partial_source = partial_source
|
| 260 |
+
self.noise = noise
|
| 261 |
+
self.additional_params = additional_params
|
| 262 |
+
self.use_rri = False
|
| 263 |
+
|
| 264 |
+
if self.algorithm == 'PCRNet' or self.algorithm == 'iPCRNet':
|
| 265 |
+
from .. ops.transform_functions import PCRNetTransform
|
| 266 |
+
self.transforms = PCRNetTransform(len(data_class), angle_range=45, translation_range=1)
|
| 267 |
+
if self.algorithm == 'PointNetLK':
|
| 268 |
+
from .. ops.transform_functions import PNLKTransform
|
| 269 |
+
self.transforms = PNLKTransform(0.8, True)
|
| 270 |
+
if self.algorithm == 'RPMNet':
|
| 271 |
+
from .. ops.transform_functions import RPMNetTransform
|
| 272 |
+
self.transforms = RPMNetTransform(0.8, True)
|
| 273 |
+
if self.algorithm == 'DCP' or self.algorithm == 'PRNet':
|
| 274 |
+
from .. ops.transform_functions import DCPTransform
|
| 275 |
+
self.transforms = DCPTransform(angle_range=45, translation_range=1)
|
| 276 |
+
if self.algorithm == 'DeepGMR':
|
| 277 |
+
self.get_rri = get_rri_cuda if torch.cuda.is_available() else get_rri
|
| 278 |
+
from .. ops.transform_functions import DeepGMRTransform
|
| 279 |
+
self.transforms = DeepGMRTransform(angle_range=90, translation_range=1)
|
| 280 |
+
if 'nearest_neighbors' in self.additional_params.keys() and self.additional_params['nearest_neighbors'] > 0:
|
| 281 |
+
self.use_rri = True
|
| 282 |
+
self.nearest_neighbors = self.additional_params['nearest_neighbors']
|
| 283 |
+
|
| 284 |
+
def __len__(self):
|
| 285 |
+
return len(self.data_class)
|
| 286 |
+
|
| 287 |
+
def set_class(self, data_class):
|
| 288 |
+
self.data_class = data_class
|
| 289 |
+
|
| 290 |
+
def __getitem__(self, index):
|
| 291 |
+
template, label = self.data_class[index]
|
| 292 |
+
self.transforms.index = index # for fixed transformations in PCRNet.
|
| 293 |
+
source = self.transforms(template)
|
| 294 |
+
|
| 295 |
+
# Check for Partial Data.
|
| 296 |
+
if self.additional_params.get('partial_point_cloud_method', None) == 'planar_crop':
|
| 297 |
+
source, gt_idx_source = planar_crop(source)
|
| 298 |
+
template, gt_idx_template = planar_crop(template)
|
| 299 |
+
intersect_mask, intersect_x, intersect_y = np.intersect1d(gt_idx_source, gt_idx_template, return_indices=True)
|
| 300 |
+
|
| 301 |
+
self.template_mask = torch.zeros(template.shape[0])
|
| 302 |
+
self.source_mask = torch.zeros(source.shape[0])
|
| 303 |
+
self.template_mask[intersect_y] = 1
|
| 304 |
+
self.source_mask[intersect_x] = 1
|
| 305 |
+
else:
|
| 306 |
+
if self.partial_source: source, self.source_mask = farthest_subsample_points(source)
|
| 307 |
+
if self.partial_template: template, self.template_mask = farthest_subsample_points(template)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# Check for Noise in Source Data.
|
| 312 |
+
if self.noise: source = jitter_pointcloud(source)
|
| 313 |
+
|
| 314 |
+
if self.use_rri:
|
| 315 |
+
template, source = template.numpy(), source.numpy()
|
| 316 |
+
template = np.concatenate([template, self.get_rri(template - template.mean(axis=0), self.nearest_neighbors)], axis=1)
|
| 317 |
+
source = np.concatenate([source, self.get_rri(source - source.mean(axis=0), self.nearest_neighbors)], axis=1)
|
| 318 |
+
template, source = torch.tensor(template).float(), torch.tensor(source).float()
|
| 319 |
+
|
| 320 |
+
igt = self.transforms.igt
|
| 321 |
+
|
| 322 |
+
if self.additional_params.get('use_masknet', False):
|
| 323 |
+
if self.partial_source and self.partial_template:
|
| 324 |
+
return template, source, igt, self.template_mask, self.source_mask
|
| 325 |
+
elif self.partial_source:
|
| 326 |
+
return template, source, igt, self.source_mask
|
| 327 |
+
elif self.partial_template:
|
| 328 |
+
return template, source, igt, self.template_mask
|
| 329 |
+
else:
|
| 330 |
+
return template, source, igt
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class SegmentationData(Dataset):
|
| 334 |
+
def __init__(self):
|
| 335 |
+
super(SegmentationData, self).__init__()
|
| 336 |
+
|
| 337 |
+
def __len__(self):
|
| 338 |
+
pass
|
| 339 |
+
|
| 340 |
+
def __getitem__(self, index):
|
| 341 |
+
pass
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class FlowData(Dataset):
|
| 345 |
+
def __init__(self):
|
| 346 |
+
super(FlowData, self).__init__()
|
| 347 |
+
self.pc1, self.pc2, self.flow = self.read_data()
|
| 348 |
+
|
| 349 |
+
def __len__(self):
|
| 350 |
+
if isinstance(self.pc1, np.ndarray):
|
| 351 |
+
return self.pc1.shape[0]
|
| 352 |
+
elif isinstance(self.pc1, list):
|
| 353 |
+
return len(self.pc1)
|
| 354 |
+
else:
|
| 355 |
+
raise UnknownDataTypeError
|
| 356 |
+
|
| 357 |
+
def read_data(self):
|
| 358 |
+
pass
|
| 359 |
+
|
| 360 |
+
def __getitem__(self, index):
|
| 361 |
+
return self.pc1[index], self.pc2[index], self.flow[index]
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class SceneflowDataset(Dataset):
|
| 365 |
+
def __init__(self, npoints=1024, root='', partition='train'):
|
| 366 |
+
if root == '':
|
| 367 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 368 |
+
DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
|
| 369 |
+
root = os.path.join(DATA_DIR, 'data_processed_maxcut_35_20k_2k_8192')
|
| 370 |
+
if not os.path.exists(root):
|
| 371 |
+
print("To download dataset, click here: https://drive.google.com/file/d/1CMaxdt-Tg1Wct8v8eGNwuT7qRSIyJPY-/view")
|
| 372 |
+
exit()
|
| 373 |
+
else:
|
| 374 |
+
print("SceneflowDataset Found Successfully!")
|
| 375 |
+
|
| 376 |
+
self.npoints = npoints
|
| 377 |
+
self.partition = partition
|
| 378 |
+
self.root = root
|
| 379 |
+
if self.partition=='train':
|
| 380 |
+
self.datapath = glob.glob(os.path.join(self.root, 'TRAIN*.npz'))
|
| 381 |
+
else:
|
| 382 |
+
self.datapath = glob.glob(os.path.join(self.root, 'TEST*.npz'))
|
| 383 |
+
self.cache = {}
|
| 384 |
+
self.cache_size = 30000
|
| 385 |
+
|
| 386 |
+
###### deal with one bad datapoint with nan value
|
| 387 |
+
self.datapath = [d for d in self.datapath if 'TRAIN_C_0140_left_0006-0' not in d]
|
| 388 |
+
######
|
| 389 |
+
print(self.partition, ': ',len(self.datapath))
|
| 390 |
+
|
| 391 |
+
def __getitem__(self, index):
|
| 392 |
+
if index in self.cache:
|
| 393 |
+
pos1, pos2, color1, color2, flow, mask1 = self.cache[index]
|
| 394 |
+
else:
|
| 395 |
+
fn = self.datapath[index]
|
| 396 |
+
with open(fn, 'rb') as fp:
|
| 397 |
+
data = np.load(fp)
|
| 398 |
+
pos1 = data['points1'].astype('float32')
|
| 399 |
+
pos2 = data['points2'].astype('float32')
|
| 400 |
+
color1 = data['color1'].astype('float32')
|
| 401 |
+
color2 = data['color2'].astype('float32')
|
| 402 |
+
flow = data['flow'].astype('float32')
|
| 403 |
+
mask1 = data['valid_mask1']
|
| 404 |
+
|
| 405 |
+
if len(self.cache) < self.cache_size:
|
| 406 |
+
self.cache[index] = (pos1, pos2, color1, color2, flow, mask1)
|
| 407 |
+
|
| 408 |
+
if self.partition == 'train':
|
| 409 |
+
n1 = pos1.shape[0]
|
| 410 |
+
sample_idx1 = np.random.choice(n1, self.npoints, replace=False)
|
| 411 |
+
n2 = pos2.shape[0]
|
| 412 |
+
sample_idx2 = np.random.choice(n2, self.npoints, replace=False)
|
| 413 |
+
|
| 414 |
+
pos1 = pos1[sample_idx1, :]
|
| 415 |
+
pos2 = pos2[sample_idx2, :]
|
| 416 |
+
color1 = color1[sample_idx1, :]
|
| 417 |
+
color2 = color2[sample_idx2, :]
|
| 418 |
+
flow = flow[sample_idx1, :]
|
| 419 |
+
mask1 = mask1[sample_idx1]
|
| 420 |
+
else:
|
| 421 |
+
pos1 = pos1[:self.npoints, :]
|
| 422 |
+
pos2 = pos2[:self.npoints, :]
|
| 423 |
+
color1 = color1[:self.npoints, :]
|
| 424 |
+
color2 = color2[:self.npoints, :]
|
| 425 |
+
flow = flow[:self.npoints, :]
|
| 426 |
+
mask1 = mask1[:self.npoints]
|
| 427 |
+
|
| 428 |
+
pos1_center = np.mean(pos1, 0)
|
| 429 |
+
pos1 -= pos1_center
|
| 430 |
+
pos2 -= pos1_center
|
| 431 |
+
|
| 432 |
+
return pos1, pos2, color1, color2, flow, mask1
|
| 433 |
+
|
| 434 |
+
def __len__(self):
|
| 435 |
+
return len(self.datapath)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
if __name__ == '__main__':
|
| 439 |
+
class Data():
|
| 440 |
+
def __init__(self):
|
| 441 |
+
super(Data, self).__init__()
|
| 442 |
+
self.data, self.label = self.read_data()
|
| 443 |
+
|
| 444 |
+
def read_data(self):
|
| 445 |
+
return [4,5,6], [4,5,6]
|
| 446 |
+
|
| 447 |
+
def __len__(self):
|
| 448 |
+
return len(self.data)
|
| 449 |
+
|
| 450 |
+
def __getitem__(self, idx):
|
| 451 |
+
return self.data[idx], self.label[idx]
|
| 452 |
+
|
| 453 |
+
cd = RegistrationData('abc')
|
| 454 |
+
import ipdb; ipdb.set_trace()
|
thirdparty/learning3d/data_utils/user_data.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class ClassificationData:
|
| 6 |
+
def __init__(self, data_dict):
|
| 7 |
+
self.data_dict = data_dict
|
| 8 |
+
self.pcs = self.find_attribute('pcs')
|
| 9 |
+
self.labels = self.find_attribute('labels')
|
| 10 |
+
self.check_data()
|
| 11 |
+
|
| 12 |
+
def find_attribute(self, attribute):
|
| 13 |
+
try:
|
| 14 |
+
attribute_data = self.data_dict[attribute]
|
| 15 |
+
except:
|
| 16 |
+
print("Given data directory has no key attribute \"{}\"".format(attribute))
|
| 17 |
+
return attribute_data
|
| 18 |
+
|
| 19 |
+
def check_data(self):
|
| 20 |
+
assert 1 < len(self.pcs.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.pcs.shape)
|
| 21 |
+
assert 0 < len(self.labels.shape) < 3, "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape)
|
| 22 |
+
|
| 23 |
+
if len(self.pcs.shape)==2: self.pcs = self.pcs.reshape(1, -1, 3)
|
| 24 |
+
if len(self.labels.shape) == 1: self.labels = self.labels.reshape(1, -1)
|
| 25 |
+
|
| 26 |
+
assert self.pcs.shape[0] == self.labels.shape[0], "Inconsistency in the number of point clouds and number of ground truth labels!"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def __len__(self):
|
| 30 |
+
return self.pcs.shape[0]
|
| 31 |
+
|
| 32 |
+
def __getitem__(self, index):
|
| 33 |
+
return torch.tensor(self.pcs[index]).float(), torch.from_numpy(self.labels[idx]).type(torch.LongTensor)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RegistrationData:
|
| 37 |
+
def __init__(self, data_dict):
|
| 38 |
+
self.data_dict = data_dict
|
| 39 |
+
self.template = self.find_attribute('template')
|
| 40 |
+
self.source = self.find_attribute('source')
|
| 41 |
+
self.transformation = self.find_attribute('transformation')
|
| 42 |
+
self.check_data()
|
| 43 |
+
|
| 44 |
+
def find_attribute(self, attribute):
|
| 45 |
+
try:
|
| 46 |
+
attribute_data = self.data[attribute]
|
| 47 |
+
except:
|
| 48 |
+
print("Given data directory has no key attribute \"{}\"".format(attribute))
|
| 49 |
+
return attribute_data
|
| 50 |
+
|
| 51 |
+
def check_data(self):
|
| 52 |
+
assert 1 < len(self.template.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.template.shape)
|
| 53 |
+
assert 1 < len(self.source.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.source.shape)
|
| 54 |
+
assert 1 < len(self.transformation.shape) < 4, "Error in dimension of transformations! Given data dimension: {}".format(self.transformation.shape)
|
| 55 |
+
|
| 56 |
+
if len(self.template.shape)==2: self.template = self.template.reshape(1, -1, 3)
|
| 57 |
+
if len(self.source.shape)==2: self.source = self.source.reshape(1, -1, 3)
|
| 58 |
+
if len(self.transformation.shape) == 2: self.transformation = self.transformation.reshape(1, 4, 4)
|
| 59 |
+
|
| 60 |
+
assert self.template.shape[0] == self.source.shape[0], "Inconsistency in the number of template and source point clouds!"
|
| 61 |
+
assert self.source.shape[0] == self.transformation.shape[0], "Inconsistency in the number of transformation and source point clouds!"
|
| 62 |
+
|
| 63 |
+
def __len__(self):
|
| 64 |
+
return self.template.shape[0]
|
| 65 |
+
|
| 66 |
+
def __getitem__(self, index):
|
| 67 |
+
return torch.tensor(self.template[index]).float(), torch.tensor(self.source[index]).float(), torch.tensor(self.transformation[index]).float()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class FlowData:
|
| 71 |
+
def __init__(self, data_dict):
|
| 72 |
+
self.data_dict = data_dict
|
| 73 |
+
self.frame1 = self.find_attribute('frame1')
|
| 74 |
+
self.frame2 = self.find_attribute('frame2')
|
| 75 |
+
self.flow = self.find_attribute('flow')
|
| 76 |
+
self.check_data()
|
| 77 |
+
|
| 78 |
+
def find_attribute(self, attribute):
|
| 79 |
+
try:
|
| 80 |
+
attribute_data = self.data[attribute]
|
| 81 |
+
except:
|
| 82 |
+
print("Given data directory has no key attribute \"{}\"".format(attribute))
|
| 83 |
+
return attribute_data
|
| 84 |
+
|
| 85 |
+
def check_data(self):
|
| 86 |
+
assert 1 < len(self.frame1.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame1.shape)
|
| 87 |
+
assert 1 < len(self.frame2.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame2.shape)
|
| 88 |
+
assert 1 < len(self.flow.shape) < 4, "Error in dimension of flow! Given data dimension: {}".format(self.flow.shape)
|
| 89 |
+
|
| 90 |
+
if len(self.frame1.shape)==2: self.frame1 = self.frame1.reshape(1, -1, 3)
|
| 91 |
+
if len(self.frame2.shape)==2: self.frame2 = self.frame2.reshape(1, -1, 3)
|
| 92 |
+
if len(self.flow.shape) == 2: self.flow = self.flow.reshape(1, -1, 3)
|
| 93 |
+
|
| 94 |
+
assert self.frame1.shape[0] == self.frame2.shape[0], "Inconsistency in the number of frame1 and frame2 point clouds!"
|
| 95 |
+
assert self.frame2.shape[0] == self.flow.shape[0], "Inconsistency in the number of flow and frame2 point clouds!"
|
| 96 |
+
|
| 97 |
+
def __len__(self):
|
| 98 |
+
return self.frame1.shape[0]
|
| 99 |
+
|
| 100 |
+
def __getitem__(self, index):
|
| 101 |
+
return torch.tensor(self.frame1[index]).float(), torch.tensor(self.frame2[index]).float(), torch.tensor(self.flow[index]).float()
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class UserData:
|
| 105 |
+
def __init__(self, application, data_dict):
|
| 106 |
+
self.application = application
|
| 107 |
+
|
| 108 |
+
if self.application == 'classification':
|
| 109 |
+
self.data_class = ClassificationData(data_dict)
|
| 110 |
+
elif self.application == 'registration':
|
| 111 |
+
self.data_class = RegistrationData(data_dict)
|
| 112 |
+
elif self.application == 'flow_estimation':
|
| 113 |
+
self.data_class = FlowData(data_dict)
|
| 114 |
+
|
| 115 |
+
def __len__(self):
|
| 116 |
+
return len(self.data_class)
|
| 117 |
+
|
| 118 |
+
def __getitem__(self, index):
|
| 119 |
+
return self.data_class[index]
|
thirdparty/learning3d/examples/test_curvenet.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import CurveNet
|
| 22 |
+
from learning3d.data_utils import ClassificationData, ModelNet40Data
|
| 23 |
+
|
| 24 |
+
def display_open3d(template):
|
| 25 |
+
template_ = o3d.geometry.PointCloud()
|
| 26 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 27 |
+
# template_.paint_uniform_color([1, 0, 0])
|
| 28 |
+
o3d.visualization.draw_geometries([template_])
|
| 29 |
+
|
| 30 |
+
def test_one_epoch(device, model, test_loader, testset):
|
| 31 |
+
model.eval()
|
| 32 |
+
test_loss = 0.0
|
| 33 |
+
pred = 0.0
|
| 34 |
+
count = 0
|
| 35 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 36 |
+
points, target = data
|
| 37 |
+
target = target[:,0]
|
| 38 |
+
|
| 39 |
+
points = points.to(device)
|
| 40 |
+
target = target.to(device)
|
| 41 |
+
|
| 42 |
+
output = model(points)
|
| 43 |
+
loss_val = torch.nn.functional.nll_loss(
|
| 44 |
+
torch.nn.functional.log_softmax(output, dim=1), target, size_average=False)
|
| 45 |
+
print("Ground Truth Label: ", testset.get_shape(target[0].item()))
|
| 46 |
+
print("Predicted Label: ", testset.get_shape(torch.argmax(output[0]).item()))
|
| 47 |
+
display_open3d(points.detach().cpu().numpy()[0])
|
| 48 |
+
|
| 49 |
+
test_loss += loss_val.item()
|
| 50 |
+
count += output.size(0)
|
| 51 |
+
|
| 52 |
+
_, pred1 = output.max(dim=1)
|
| 53 |
+
ag = (pred1 == target)
|
| 54 |
+
am = ag.sum()
|
| 55 |
+
pred += am.item()
|
| 56 |
+
|
| 57 |
+
test_loss = float(test_loss)/count
|
| 58 |
+
accuracy = float(pred)/count
|
| 59 |
+
return test_loss, accuracy
|
| 60 |
+
|
| 61 |
+
def test(args, model, test_loader, testset):
|
| 62 |
+
test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader, testset)
|
| 63 |
+
print("Accuracy: ", test_accuracy*100)
|
| 64 |
+
|
| 65 |
+
def options():
|
| 66 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 67 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 68 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 69 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 70 |
+
|
| 71 |
+
# settings for input data
|
| 72 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 73 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 74 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 75 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 76 |
+
|
| 77 |
+
# settings for CurveNet
|
| 78 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 79 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 80 |
+
parser.add_argument('-b', '--batch_size', default=32, type=int,
|
| 81 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 82 |
+
parser.add_argument('--num_classes', default=40, type=int,
|
| 83 |
+
metavar='K', help='number of classes to be predicted')
|
| 84 |
+
|
| 85 |
+
# settings for on training
|
| 86 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_curvenet/models/model.t7', type=str,
|
| 87 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 88 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 89 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 90 |
+
|
| 91 |
+
args = parser.parse_args()
|
| 92 |
+
return args
|
| 93 |
+
|
| 94 |
+
def main():
|
| 95 |
+
args = options()
|
| 96 |
+
args.dataset_path = os.path.join(os.getcwd(), os.pardir, os.pardir, 'ModelNet40', 'ModelNet40')
|
| 97 |
+
|
| 98 |
+
testset = ClassificationData(ModelNet40Data(train=False))
|
| 99 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 100 |
+
|
| 101 |
+
if not torch.cuda.is_available():
|
| 102 |
+
args.device = 'cpu'
|
| 103 |
+
args.device = torch.device(args.device)
|
| 104 |
+
|
| 105 |
+
# Create PointNet Model.
|
| 106 |
+
model = CurveNet(num_classes=args.num_classes, k=20)
|
| 107 |
+
|
| 108 |
+
if args.pretrained:
|
| 109 |
+
assert os.path.isfile(args.pretrained)
|
| 110 |
+
weights = torch.load(args.pretrained, map_location='cpu')
|
| 111 |
+
weights = {k[7:]: v for k, v in weights.items()}
|
| 112 |
+
model.load_state_dict(weights)
|
| 113 |
+
model.to(args.device)
|
| 114 |
+
|
| 115 |
+
test(args, model, test_loader, testset)
|
| 116 |
+
|
| 117 |
+
if __name__ == '__main__':
|
| 118 |
+
main()
|
thirdparty/learning3d/examples/test_dcp.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import DGCNN, DCP
|
| 22 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 23 |
+
|
| 24 |
+
def get_transformations(igt):
|
| 25 |
+
R_ba = igt[:, 0:3, 0:3] # Ps = R_ba * Pt
|
| 26 |
+
translation_ba = igt[:, 0:3, 3].unsqueeze(2) # Ps = Pt + t_ba
|
| 27 |
+
R_ab = R_ba.permute(0, 2, 1) # Pt = R_ab * Ps
|
| 28 |
+
translation_ab = -torch.bmm(R_ab, translation_ba) # Pt = Ps + t_ab
|
| 29 |
+
return R_ab, translation_ab, R_ba, translation_ba
|
| 30 |
+
|
| 31 |
+
def display_open3d(template, source, transformed_source):
|
| 32 |
+
template_ = o3d.geometry.PointCloud()
|
| 33 |
+
source_ = o3d.geometry.PointCloud()
|
| 34 |
+
transformed_source_ = o3d.geometry.PointCloud()
|
| 35 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 36 |
+
source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
|
| 37 |
+
transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
|
| 38 |
+
template_.paint_uniform_color([1, 0, 0])
|
| 39 |
+
source_.paint_uniform_color([0, 1, 0])
|
| 40 |
+
transformed_source_.paint_uniform_color([0, 0, 1])
|
| 41 |
+
o3d.visualization.draw_geometries([template_, source_, transformed_source_])
|
| 42 |
+
|
| 43 |
+
def test_one_epoch(device, model, test_loader):
|
| 44 |
+
model.eval()
|
| 45 |
+
test_loss = 0.0
|
| 46 |
+
pred = 0.0
|
| 47 |
+
count = 0
|
| 48 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 49 |
+
template, source, igt = data
|
| 50 |
+
transformations = get_transformations(igt)
|
| 51 |
+
transformations = [t.to(device) for t in transformations]
|
| 52 |
+
R_ab, translation_ab, R_ba, translation_ba = transformations
|
| 53 |
+
|
| 54 |
+
template = template.to(device)
|
| 55 |
+
source = source.to(device)
|
| 56 |
+
igt = igt.to(device)
|
| 57 |
+
|
| 58 |
+
output = model(template, source)
|
| 59 |
+
display_open3d(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], output['transformed_source'].detach().cpu().numpy()[0])
|
| 60 |
+
|
| 61 |
+
identity = torch.eye(3).cuda().unsqueeze(0).repeat(template.shape[0], 1, 1)
|
| 62 |
+
loss_val = torch.nn.functional.mse_loss(torch.matmul(output['est_R'].transpose(2, 1), R_ab), identity) \
|
| 63 |
+
+ torch.nn.functional.mse_loss(output['est_t'], translation_ab[:,:,0])
|
| 64 |
+
|
| 65 |
+
cycle_loss = torch.nn.functional.mse_loss(torch.matmul(output['est_R_'].transpose(2, 1), R_ba), identity) \
|
| 66 |
+
+ torch.nn.functional.mse_loss(output['est_t_'], translation_ba[:,:,0])
|
| 67 |
+
loss_val = loss_val + cycle_loss * 0.1
|
| 68 |
+
|
| 69 |
+
test_loss += loss_val.item()
|
| 70 |
+
count += 1
|
| 71 |
+
|
| 72 |
+
test_loss = float(test_loss)/count
|
| 73 |
+
return test_loss
|
| 74 |
+
|
| 75 |
+
def test(args, model, test_loader):
|
| 76 |
+
test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
|
| 77 |
+
|
| 78 |
+
def options():
|
| 79 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 80 |
+
parser.add_argument('--exp_name', type=str, default='exp_ipcrnet', metavar='N',
|
| 81 |
+
help='Name of the experiment')
|
| 82 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 83 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 84 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 85 |
+
|
| 86 |
+
# settings for input data
|
| 87 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 88 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 89 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 90 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 91 |
+
|
| 92 |
+
# settings for PointNet
|
| 93 |
+
parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
|
| 94 |
+
help='train pointnet (default: tune)')
|
| 95 |
+
parser.add_argument('--emb_dims', default=512, type=int,
|
| 96 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 97 |
+
parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
|
| 98 |
+
help='symmetric function (default: max)')
|
| 99 |
+
|
| 100 |
+
# settings for on training
|
| 101 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 102 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 103 |
+
parser.add_argument('-b', '--batch_size', default=2, type=int,
|
| 104 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 105 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_dcp/models/best_model.t7', type=str,
|
| 106 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 107 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 108 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 109 |
+
|
| 110 |
+
args = parser.parse_args()
|
| 111 |
+
return args
|
| 112 |
+
|
| 113 |
+
def main():
|
| 114 |
+
args = options()
|
| 115 |
+
torch.backends.cudnn.deterministic = True
|
| 116 |
+
|
| 117 |
+
trainset = RegistrationData('DCP', ModelNet40Data(train=True))
|
| 118 |
+
testset = RegistrationData('DCP', ModelNet40Data(train=False))
|
| 119 |
+
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
|
| 120 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 121 |
+
|
| 122 |
+
if not torch.cuda.is_available():
|
| 123 |
+
args.device = 'cpu'
|
| 124 |
+
args.device = torch.device(args.device)
|
| 125 |
+
|
| 126 |
+
# Create PointNet Model.
|
| 127 |
+
dgcnn = DGCNN(emb_dims=args.emb_dims)
|
| 128 |
+
model = DCP(feature_model=dgcnn, cycle=True)
|
| 129 |
+
model = model.to(args.device)
|
| 130 |
+
|
| 131 |
+
if args.pretrained:
|
| 132 |
+
assert os.path.isfile(args.pretrained)
|
| 133 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'), strict=False)
|
| 134 |
+
model.to(args.device)
|
| 135 |
+
|
| 136 |
+
test(args, model, test_loader)
|
| 137 |
+
|
| 138 |
+
if __name__ == '__main__':
|
| 139 |
+
main()
|
thirdparty/learning3d/examples/test_deepgmr.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import DeepGMR
|
| 22 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 23 |
+
|
| 24 |
+
def display_open3d(template, source, transformed_source):
|
| 25 |
+
template_ = o3d.geometry.PointCloud()
|
| 26 |
+
source_ = o3d.geometry.PointCloud()
|
| 27 |
+
transformed_source_ = o3d.geometry.PointCloud()
|
| 28 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 29 |
+
source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
|
| 30 |
+
transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
|
| 31 |
+
template_.paint_uniform_color([1, 0, 0])
|
| 32 |
+
source_.paint_uniform_color([0, 1, 0])
|
| 33 |
+
transformed_source_.paint_uniform_color([0, 0, 1])
|
| 34 |
+
o3d.visualization.draw_geometries([template_, source_, transformed_source_])
|
| 35 |
+
|
| 36 |
+
def rotation_error(R, R_gt):
|
| 37 |
+
cos_theta = (torch.einsum('bij,bij->b', R, R_gt) - 1) / 2
|
| 38 |
+
cos_theta = torch.clamp(cos_theta, -1, 1)
|
| 39 |
+
return torch.acos(cos_theta) * 180 / math.pi
|
| 40 |
+
|
| 41 |
+
def translation_error(t, t_gt):
|
| 42 |
+
return torch.norm(t - t_gt, dim=1)
|
| 43 |
+
|
| 44 |
+
def rmse(pts, T, T_gt):
|
| 45 |
+
pts_pred = pts @ T[:, :3, :3].transpose(1, 2) + T[:, :3, 3].unsqueeze(1)
|
| 46 |
+
pts_gt = pts @ T_gt[:, :3, :3].transpose(1, 2) + T_gt[:, :3, 3].unsqueeze(1)
|
| 47 |
+
return torch.norm(pts_pred - pts_gt, dim=2).mean(dim=1)
|
| 48 |
+
|
| 49 |
+
def test_one_epoch(device, model, test_loader):
|
| 50 |
+
model.eval()
|
| 51 |
+
test_loss = 0.0
|
| 52 |
+
pred = 0.0
|
| 53 |
+
count = 0
|
| 54 |
+
rotation_errors, translation_errors, rmses = [], [], []
|
| 55 |
+
|
| 56 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 57 |
+
template, source, igt = data
|
| 58 |
+
|
| 59 |
+
template = template.to(device)
|
| 60 |
+
source = source.to(device)
|
| 61 |
+
igt = igt.to(device)
|
| 62 |
+
|
| 63 |
+
output = model(template, source)
|
| 64 |
+
display_open3d(template.detach().cpu().numpy()[0, :, :3], source.detach().cpu().numpy()[0, :, :3], output['transformed_source'].detach().cpu().numpy()[0])
|
| 65 |
+
|
| 66 |
+
eye = torch.eye(4).expand_as(igt).to(igt.device)
|
| 67 |
+
mse1 = F.mse_loss(output['est_T_inverse'] @ torch.inverse(igt), eye)
|
| 68 |
+
mse2 = F.mse_loss(output['est_T'] @ igt, eye)
|
| 69 |
+
loss = mse1 + mse2
|
| 70 |
+
|
| 71 |
+
r_err = rotation_error(est_T_inverse[:, :3, :3], igt[:, :3, :3])
|
| 72 |
+
t_err = translation_error(est_T_inverse[:, :3, 3], igt[:, :3, 3])
|
| 73 |
+
rmse_val = rmse(template[:, :100], est_T_inverse, igt)
|
| 74 |
+
rotation_errors.append(r_err)
|
| 75 |
+
translation_errors.append(t_err)
|
| 76 |
+
rmses.append(rmse_val)
|
| 77 |
+
|
| 78 |
+
test_loss += loss_val.item()
|
| 79 |
+
count += 1
|
| 80 |
+
|
| 81 |
+
test_loss = float(test_loss)/count
|
| 82 |
+
print("Mean rotation error: {}, Mean translation error: {} and Mean RMSE: {}".format(np.mean(rotation_errors), np.mean(translation_errors), np.mean(rmses)))
|
| 83 |
+
return test_loss
|
| 84 |
+
|
| 85 |
+
def test(args, model, test_loader):
|
| 86 |
+
test_loss = test_one_epoch(args.device, model, test_loader)
|
| 87 |
+
|
| 88 |
+
def options():
|
| 89 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 90 |
+
parser.add_argument('--exp_name', type=str, default='exp_deepgmr', metavar='N',
|
| 91 |
+
help='Name of the experiment')
|
| 92 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 93 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 94 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 95 |
+
|
| 96 |
+
# settings for input data
|
| 97 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 98 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 99 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 100 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 101 |
+
|
| 102 |
+
parser.add_argument('--nearest_neighbors', default=20, type=int,
|
| 103 |
+
metavar='K', help='No of nearest neighbors to be estimated.')
|
| 104 |
+
parser.add_argument('--use_rri', default=True, type=bool,
|
| 105 |
+
help='Find nearest neighbors to estimate features from PointNet.')
|
| 106 |
+
|
| 107 |
+
# settings for on training
|
| 108 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 109 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 110 |
+
parser.add_argument('-b', '--batch_size', default=2, type=int,
|
| 111 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 112 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_deepgmr/models/best_model.pth', type=str,
|
| 113 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 114 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 115 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 116 |
+
|
| 117 |
+
args = parser.parse_args()
|
| 118 |
+
return args
|
| 119 |
+
|
| 120 |
+
def main():
|
| 121 |
+
args = options()
|
| 122 |
+
torch.backends.cudnn.deterministic = True
|
| 123 |
+
|
| 124 |
+
trainset = RegistrationData('DeepGMR', ModelNet40Data(train=True))
|
| 125 |
+
testset = RegistrationData('DeepGMR', ModelNet40Data(train=False))
|
| 126 |
+
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
|
| 127 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 128 |
+
|
| 129 |
+
if not torch.cuda.is_available():
|
| 130 |
+
args.device = 'cpu'
|
| 131 |
+
args.device = torch.device(args.device)
|
| 132 |
+
|
| 133 |
+
model = DeepGMR(use_rri=args.use_rri, nearest_neighbors=args.nearest_neighbors)
|
| 134 |
+
model = model.to(args.device)
|
| 135 |
+
|
| 136 |
+
if args.pretrained:
|
| 137 |
+
assert os.path.isfile(args.pretrained)
|
| 138 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'), strict=False)
|
| 139 |
+
model.to(args.device)
|
| 140 |
+
|
| 141 |
+
test(args, model, test_loader)
|
| 142 |
+
|
| 143 |
+
if __name__ == '__main__':
|
| 144 |
+
main()
|
thirdparty/learning3d/examples/test_flownet.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import open3d as o3d
|
| 6 |
+
import os
|
| 7 |
+
import gc
|
| 8 |
+
import argparse
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
from torch.optim.lr_scheduler import MultiStepLR
|
| 14 |
+
from learning3d.models import FlowNet3D
|
| 15 |
+
from learning3d.data_utils import SceneflowDataset
|
| 16 |
+
import numpy as np
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
from tensorboardX import SummaryWriter
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
def display_open3d(template, source, transformed_source):
|
| 22 |
+
template_ = o3d.geometry.PointCloud()
|
| 23 |
+
source_ = o3d.geometry.PointCloud()
|
| 24 |
+
transformed_source_ = o3d.geometry.PointCloud()
|
| 25 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 26 |
+
source_.points = o3d.utility.Vector3dVector(source + np.array([0,0.5,0.5]))
|
| 27 |
+
transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
|
| 28 |
+
template_.paint_uniform_color([1, 0, 0])
|
| 29 |
+
source_.paint_uniform_color([0, 1, 0])
|
| 30 |
+
transformed_source_.paint_uniform_color([0, 0, 1])
|
| 31 |
+
o3d.visualization.draw_geometries([template_, source_, transformed_source_])
|
| 32 |
+
|
| 33 |
+
def test_one_epoch(args, net, test_loader):
|
| 34 |
+
net.eval()
|
| 35 |
+
|
| 36 |
+
total_loss = 0
|
| 37 |
+
num_examples = 0
|
| 38 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 39 |
+
data = [d.to(args.device) for d in data]
|
| 40 |
+
pc1, pc2, color1, color2, flow, mask1 = data
|
| 41 |
+
pc1 = pc1.transpose(2,1).contiguous()
|
| 42 |
+
pc2 = pc2.transpose(2,1).contiguous()
|
| 43 |
+
color1 = color1.transpose(2,1).contiguous()
|
| 44 |
+
color2 = color2.transpose(2,1).contiguous()
|
| 45 |
+
flow = flow
|
| 46 |
+
mask1 = mask1.float()
|
| 47 |
+
|
| 48 |
+
batch_size = pc1.size(0)
|
| 49 |
+
num_examples += batch_size
|
| 50 |
+
flow_pred = net(pc1, pc2, color1, color2).permute(0,2,1)
|
| 51 |
+
loss_1 = torch.mean(mask1 * torch.sum((flow_pred - flow) * (flow_pred - flow), -1) / 2.0)
|
| 52 |
+
|
| 53 |
+
pc1, pc2 = pc1.permute(0,2,1), pc2.permute(0,2,1)
|
| 54 |
+
pc1_ = pc1 - flow_pred
|
| 55 |
+
print("Loss: ", loss_1)
|
| 56 |
+
display_open3d(pc1.detach().cpu().numpy()[0], pc2.detach().cpu().numpy()[0], pc1_.detach().cpu().numpy()[0])
|
| 57 |
+
total_loss += loss_1.item() * batch_size
|
| 58 |
+
|
| 59 |
+
return total_loss * 1.0 / num_examples
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def test(args, net, test_loader):
|
| 63 |
+
test_loss = test_one_epoch(args, net, test_loader)
|
| 64 |
+
|
| 65 |
+
def main():
|
| 66 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 67 |
+
parser.add_argument('--model', type=str, default='flownet', metavar='N',
|
| 68 |
+
choices=['flownet'], help='Model to use, [flownet]')
|
| 69 |
+
parser.add_argument('--emb_dims', type=int, default=512, metavar='N',
|
| 70 |
+
help='Dimension of embeddings')
|
| 71 |
+
parser.add_argument('--num_points', type=int, default=2048,
|
| 72 |
+
help='Point Number [default: 2048]')
|
| 73 |
+
parser.add_argument('--test_batch_size', type=int, default=1, metavar='batch_size',
|
| 74 |
+
help='Size of batch)')
|
| 75 |
+
|
| 76 |
+
parser.add_argument('--gaussian_noise', type=bool, default=False, metavar='N',
|
| 77 |
+
help='Wheter to add gaussian noise')
|
| 78 |
+
parser.add_argument('--unseen', type=bool, default=False, metavar='N',
|
| 79 |
+
help='Whether to test on unseen category')
|
| 80 |
+
parser.add_argument('--dataset', type=str, default='SceneflowDataset',
|
| 81 |
+
choices=['SceneflowDataset'], metavar='N',
|
| 82 |
+
help='dataset to use')
|
| 83 |
+
parser.add_argument('--dataset_path', type=str, default='data_processed_maxcut_35_20k_2k_8192', metavar='N',
|
| 84 |
+
help='dataset to use')
|
| 85 |
+
parser.add_argument('--pretrained', type=str, default='learning3d/pretrained/exp_flownet/models/model.best.t7', metavar='N',
|
| 86 |
+
help='Pretrained model path')
|
| 87 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 88 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 89 |
+
|
| 90 |
+
args = parser.parse_args()
|
| 91 |
+
if not torch.cuda.is_available():
|
| 92 |
+
args.device = torch.device('cpu')
|
| 93 |
+
else:
|
| 94 |
+
args.device = torch.device('cuda')
|
| 95 |
+
|
| 96 |
+
if args.dataset == 'SceneflowDataset':
|
| 97 |
+
test_loader = DataLoader(
|
| 98 |
+
SceneflowDataset(npoints=args.num_points, partition='test'),
|
| 99 |
+
batch_size=args.test_batch_size, shuffle=False, drop_last=False)
|
| 100 |
+
else:
|
| 101 |
+
raise Exception("not implemented")
|
| 102 |
+
|
| 103 |
+
net = FlowNet3D()
|
| 104 |
+
assert os.path.exists(args.pretrained), "Pretrained Model Doesn't Exists!"
|
| 105 |
+
net.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 106 |
+
net = net.to(args.device)
|
| 107 |
+
|
| 108 |
+
test(args, net, test_loader)
|
| 109 |
+
print('FINISH')
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == '__main__':
|
| 113 |
+
main()
|
thirdparty/learning3d/examples/test_masknet.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import MaskNet
|
| 22 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 23 |
+
|
| 24 |
+
def pc2open3d(data):
|
| 25 |
+
if torch.is_tensor(data): data = data.detach().cpu().numpy()
|
| 26 |
+
if len(data.shape) == 2:
|
| 27 |
+
pc = o3d.geometry.PointCloud()
|
| 28 |
+
pc.points = o3d.utility.Vector3dVector(data)
|
| 29 |
+
return pc
|
| 30 |
+
else:
|
| 31 |
+
print("Error in the shape of data given to Open3D!, Shape is ", data.shape)
|
| 32 |
+
|
| 33 |
+
def display_results(template, source, masked_template):
|
| 34 |
+
template = pc2open3d(template)
|
| 35 |
+
source = pc2open3d(source)
|
| 36 |
+
masked_template = pc2open3d(masked_template)
|
| 37 |
+
|
| 38 |
+
template.paint_uniform_color([1, 0, 0])
|
| 39 |
+
source.paint_uniform_color([0, 1, 0])
|
| 40 |
+
masked_template.paint_uniform_color([0, 0, 1])
|
| 41 |
+
|
| 42 |
+
o3d.visualization.draw_geometries([template, source])
|
| 43 |
+
o3d.visualization.draw_geometries([masked_template, source])
|
| 44 |
+
|
| 45 |
+
def evaluate_metrics(TP, FP, FN, TN, gt_mask):
|
| 46 |
+
# TP, FP, FN, TN: True +ve, False +ve, False -ve, True -ve
|
| 47 |
+
# gt_mask: Ground Truth mask [Nt, 1]
|
| 48 |
+
|
| 49 |
+
accuracy = (TP + TN)/gt_mask.shape[1]
|
| 50 |
+
misclassification_rate = (FN + FP)/gt_mask.shape[1]
|
| 51 |
+
# Precision: (What portion of positive identifications are actually correct?)
|
| 52 |
+
precision = TP / (TP + FP)
|
| 53 |
+
# Recall: (What portion of actual positives are identified correctly?)
|
| 54 |
+
recall = TP / (TP + FN)
|
| 55 |
+
|
| 56 |
+
fscore = (2*precision*recall) / (precision + recall)
|
| 57 |
+
return accuracy, precision, recall, fscore
|
| 58 |
+
|
| 59 |
+
# Function used to evaluate the predicted mask with ground truth mask.
|
| 60 |
+
def evaluate_mask(gt_mask, predicted_mask, predicted_mask_idx):
|
| 61 |
+
# gt_mask: Ground Truth Mask [Nt, 1]
|
| 62 |
+
# predicted_mask: Mask predicted by network [Nt, 1]
|
| 63 |
+
# predicted_mask_idx: Point indices chosen by network [Ns, 1]
|
| 64 |
+
|
| 65 |
+
if torch.is_tensor(gt_mask): gt_mask = gt_mask.detach().cpu().numpy()
|
| 66 |
+
if torch.is_tensor(gt_mask): predicted_mask = predicted_mask.detach().cpu().numpy()
|
| 67 |
+
if torch.is_tensor(predicted_mask_idx): predicted_mask_idx = predicted_mask_idx.detach().cpu().numpy()
|
| 68 |
+
gt_mask, predicted_mask, predicted_mask_idx = gt_mask.reshape(1,-1), predicted_mask.reshape(1,-1), predicted_mask_idx.reshape(1,-1)
|
| 69 |
+
|
| 70 |
+
gt_idx = np.where(gt_mask == 1)[1].reshape(1,-1) # Find indices of points which are actually in source.
|
| 71 |
+
|
| 72 |
+
# TP + FP = number of source points.
|
| 73 |
+
TP = np.intersect1d(predicted_mask_idx[0], gt_idx[0]).shape[0] # is inliner and predicted as inlier (True Positive) (Find common indices in predicted_mask_idx, gt_idx)
|
| 74 |
+
FP = len([x for x in predicted_mask_idx[0] if x not in gt_idx]) # isn't inlier but predicted as inlier (False Positive)
|
| 75 |
+
FN = FP # is inlier but predicted as outlier (False Negative) (due to binary classification)
|
| 76 |
+
TN = gt_mask.shape[1] - gt_idx.shape[1] - FN # is outlier and predicted as outlier (True Negative)
|
| 77 |
+
return evaluate_metrics(TP, FP, FN, TN, gt_mask)
|
| 78 |
+
|
| 79 |
+
def test_one_epoch(args, model, test_loader):
|
| 80 |
+
model.eval()
|
| 81 |
+
test_loss = 0.0
|
| 82 |
+
pred = 0.0
|
| 83 |
+
count = 0
|
| 84 |
+
precision_list = []
|
| 85 |
+
|
| 86 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 87 |
+
template, source, igt, gt_mask = data
|
| 88 |
+
|
| 89 |
+
template = template.to(args.device)
|
| 90 |
+
source = source.to(args.device)
|
| 91 |
+
igt = igt.to(args.device) # [source] = [igt]*[template]
|
| 92 |
+
gt_mask = gt_mask.to(args.device)
|
| 93 |
+
|
| 94 |
+
masked_template, predicted_mask = model(template, source)
|
| 95 |
+
|
| 96 |
+
# Evaluate mask based on classification metrics.
|
| 97 |
+
accuracy, precision, recall, fscore = evaluate_mask(gt_mask, predicted_mask, predicted_mask_idx = model.mask_idx)
|
| 98 |
+
precision_list.append(precision)
|
| 99 |
+
|
| 100 |
+
# Different ways to visualize results.
|
| 101 |
+
display_results(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], masked_template.detach().cpu().numpy()[0])
|
| 102 |
+
|
| 103 |
+
print("Mean Precision: ", np.mean(precision_list))
|
| 104 |
+
|
| 105 |
+
def test(args, model, test_loader):
|
| 106 |
+
test_one_epoch(args, model, test_loader)
|
| 107 |
+
|
| 108 |
+
def options():
|
| 109 |
+
parser = argparse.ArgumentParser(description='MaskNet: A Fully-Convolutional Network For Inlier Estimation (Testing)')
|
| 110 |
+
|
| 111 |
+
# settings for input data
|
| 112 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 113 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 114 |
+
parser.add_argument('--partial_source', default=True, type=bool,
|
| 115 |
+
help='create partial source point cloud in dataset.')
|
| 116 |
+
parser.add_argument('--noise', default=False, type=bool,
|
| 117 |
+
help='Add noise in source point clouds.')
|
| 118 |
+
parser.add_argument('--outliers', default=False, type=bool,
|
| 119 |
+
help='Add outliers to template point cloud.')
|
| 120 |
+
|
| 121 |
+
# settings for on testing
|
| 122 |
+
parser.add_argument('-j', '--workers', default=1, type=int,
|
| 123 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 124 |
+
parser.add_argument('-b', '--test_batch_size', default=1, type=int,
|
| 125 |
+
metavar='N', help='test-mini-batch size (default: 1)')
|
| 126 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_masknet/models/best_model.t7', type=str,
|
| 127 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 128 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 129 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 130 |
+
parser.add_argument('--unseen', default=False, type=bool,
|
| 131 |
+
help='Use first 20 categories for training and last 20 for testing')
|
| 132 |
+
|
| 133 |
+
args = parser.parse_args()
|
| 134 |
+
return args
|
| 135 |
+
|
| 136 |
+
def main():
|
| 137 |
+
args = options()
|
| 138 |
+
torch.backends.cudnn.deterministic = True
|
| 139 |
+
|
| 140 |
+
testset = RegistrationData('PointNetLK', ModelNet40Data(train=False, num_points=args.num_points),
|
| 141 |
+
partial_source=args.partial_source, noise=args.noise,
|
| 142 |
+
additional_params={'use_masknet': True})
|
| 143 |
+
test_loader = DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 144 |
+
|
| 145 |
+
if not torch.cuda.is_available():
|
| 146 |
+
args.device = 'cpu'
|
| 147 |
+
args.device = torch.device(args.device)
|
| 148 |
+
|
| 149 |
+
# Load Pretrained MaskNet.
|
| 150 |
+
model = MaskNet()
|
| 151 |
+
if args.pretrained:
|
| 152 |
+
assert os.path.isfile(args.pretrained)
|
| 153 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 154 |
+
model = model.to(args.device)
|
| 155 |
+
|
| 156 |
+
test(args, model, test_loader)
|
| 157 |
+
|
| 158 |
+
if __name__ == '__main__':
|
| 159 |
+
main()
|
thirdparty/learning3d/examples/test_masknet2.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import numpy
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
# Only if the files are in example folder.
|
| 13 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 14 |
+
if BASE_DIR[-8:] == 'examples':
|
| 15 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 16 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 17 |
+
|
| 18 |
+
from learning3d.models import MaskNet2
|
| 19 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 20 |
+
|
| 21 |
+
def pc2open3d(data):
|
| 22 |
+
if torch.is_tensor(data): data = data.detach().cpu().numpy()
|
| 23 |
+
if len(data.shape) == 2:
|
| 24 |
+
pc = o3d.geometry.PointCloud()
|
| 25 |
+
pc.points = o3d.utility.Vector3dVector(data)
|
| 26 |
+
return pc
|
| 27 |
+
else:
|
| 28 |
+
print("Error in the shape of data given to Open3D!, Shape is ", data.shape)
|
| 29 |
+
|
| 30 |
+
def display_results(template, source, masked_template, masked_source):
|
| 31 |
+
template = pc2open3d(template)
|
| 32 |
+
source = pc2open3d(source)
|
| 33 |
+
masked_template = pc2open3d(masked_template)
|
| 34 |
+
masked_source = pc2open3d(masked_source)
|
| 35 |
+
|
| 36 |
+
template.paint_uniform_color([1, 0, 0])
|
| 37 |
+
source.paint_uniform_color([0, 1, 0])
|
| 38 |
+
# masked_template.paint_uniform_color([0, 0, 1])
|
| 39 |
+
masked_template.paint_uniform_color([1, 0, 0])
|
| 40 |
+
masked_source.paint_uniform_color([0, 1, 0])
|
| 41 |
+
|
| 42 |
+
o3d.visualization.draw_geometries([template, source])
|
| 43 |
+
o3d.visualization.draw_geometries([masked_template, masked_source])
|
| 44 |
+
|
| 45 |
+
def evaluate_metrics(TP, FP, FN, TN, gt_mask):
|
| 46 |
+
# TP, FP, FN, TN: True +ve, False +ve, False -ve, True -ve
|
| 47 |
+
# gt_mask: Ground Truth mask [Nt, 1]
|
| 48 |
+
|
| 49 |
+
accuracy = (TP + TN)/gt_mask.shape[1]
|
| 50 |
+
misclassification_rate = (FN + FP)/gt_mask.shape[1]
|
| 51 |
+
# Precision: (What portion of positive identifications are actually correct?)
|
| 52 |
+
precision = TP / (TP + FP)
|
| 53 |
+
# Recall: (What portion of actual positives are identified correctly?)
|
| 54 |
+
recall = TP / (TP + FN)
|
| 55 |
+
|
| 56 |
+
fscore = (2*precision*recall) / (precision + recall)
|
| 57 |
+
return accuracy, precision, recall, fscore
|
| 58 |
+
|
| 59 |
+
# Function used to evaluate the predicted mask with ground truth mask.
|
| 60 |
+
def evaluate_mask(gt_mask, predicted_mask, predicted_mask_idx):
|
| 61 |
+
# gt_mask: Ground Truth Mask [Nt, 1]
|
| 62 |
+
# predicted_mask: Mask predicted by network [Nt, 1]
|
| 63 |
+
# predicted_mask_idx: Point indices chosen by network [Ns, 1]
|
| 64 |
+
|
| 65 |
+
if torch.is_tensor(gt_mask): gt_mask = gt_mask.detach().cpu().numpy()
|
| 66 |
+
if torch.is_tensor(gt_mask): predicted_mask = predicted_mask.detach().cpu().numpy()
|
| 67 |
+
if torch.is_tensor(predicted_mask_idx): predicted_mask_idx = predicted_mask_idx.detach().cpu().numpy()
|
| 68 |
+
gt_mask, predicted_mask, predicted_mask_idx = gt_mask.reshape(1,-1), predicted_mask.reshape(1,-1), predicted_mask_idx.reshape(1,-1)
|
| 69 |
+
|
| 70 |
+
gt_idx = np.where(gt_mask == 1)[1].reshape(1,-1) # Find indices of points which are actually in source.
|
| 71 |
+
|
| 72 |
+
# TP + FP = number of source points.
|
| 73 |
+
TP = np.intersect1d(predicted_mask_idx[0], gt_idx[0]).shape[0] # is inliner and predicted as inlier (True Positive) (Find common indices in predicted_mask_idx, gt_idx)
|
| 74 |
+
FP = len([x for x in predicted_mask_idx[0] if x not in gt_idx]) # isn't inlier but predicted as inlier (False Positive)
|
| 75 |
+
FN = FP # is inlier but predicted as outlier (False Negative) (due to binary classification)
|
| 76 |
+
TN = gt_mask.shape[1] - gt_idx.shape[1] - FN # is outlier and predicted as outlier (True Negative)
|
| 77 |
+
return evaluate_metrics(TP, FP, FN, TN, gt_mask)
|
| 78 |
+
|
| 79 |
+
def test_one_epoch(args, model, test_loader):
|
| 80 |
+
model.eval()
|
| 81 |
+
test_loss = 0.0
|
| 82 |
+
pred = 0.0
|
| 83 |
+
count = 0
|
| 84 |
+
|
| 85 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 86 |
+
template, source, igt, gt_template_mask, gt_source_mask = data
|
| 87 |
+
|
| 88 |
+
template = template.to(args.device)
|
| 89 |
+
source = source.to(args.device)
|
| 90 |
+
igt = igt.to(args.device) # [source] = [igt]*[template]
|
| 91 |
+
gt_template_mask = gt_template_mask.to(args.device)
|
| 92 |
+
gt_source_mask = gt_source_mask.to(args.device)
|
| 93 |
+
|
| 94 |
+
masked_template, masked_source, template_mask, source_mask = model(template, source)
|
| 95 |
+
|
| 96 |
+
# TODO: Implement evaluation strategy.
|
| 97 |
+
'''
|
| 98 |
+
Evaluate mask based on classification metrics.
|
| 99 |
+
accuracy, precision, recall, fscore = evaluate_mask(gt_template_mask, template_mask, predicted_mask_idx = model.mask_idx)
|
| 100 |
+
precision_list.append(precision)
|
| 101 |
+
'''
|
| 102 |
+
|
| 103 |
+
# Different ways to visualize results.
|
| 104 |
+
display_results(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], masked_template.detach().cpu().numpy()[0], masked_source.detach().cpu().numpy()[0])
|
| 105 |
+
|
| 106 |
+
def test(args, model, test_loader):
|
| 107 |
+
test_one_epoch(args, model, test_loader)
|
| 108 |
+
|
| 109 |
+
def options():
|
| 110 |
+
parser = argparse.ArgumentParser(description='MaskNet: A Fully-Convolutional Network For Inlier Estimation (Testing)')
|
| 111 |
+
|
| 112 |
+
# settings for input data
|
| 113 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 114 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 115 |
+
parser.add_argument('--partial_source', default=True, type=bool,
|
| 116 |
+
help='create partial source point cloud in dataset.')
|
| 117 |
+
parser.add_argument('--partial_template', default=True, type=bool,
|
| 118 |
+
help='create partial source point cloud in dataset.')
|
| 119 |
+
parser.add_argument('--noise', default=False, type=bool,
|
| 120 |
+
help='Add noise in source point clouds.')
|
| 121 |
+
parser.add_argument('--outliers', default=False, type=bool,
|
| 122 |
+
help='Add outliers to template point cloud.')
|
| 123 |
+
|
| 124 |
+
# settings for on testing
|
| 125 |
+
parser.add_argument('-j', '--workers', default=1, type=int,
|
| 126 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 127 |
+
parser.add_argument('-b', '--test_batch_size', default=1, type=int,
|
| 128 |
+
metavar='N', help='test-mini-batch size (default: 1)')
|
| 129 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_masknet2/models/best_model_0.7.t7', type=str,
|
| 130 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 131 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 132 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 133 |
+
parser.add_argument('--unseen', default=False, type=bool,
|
| 134 |
+
help='Use first 20 categories for training and last 20 for testing')
|
| 135 |
+
|
| 136 |
+
args = parser.parse_args()
|
| 137 |
+
return args
|
| 138 |
+
|
| 139 |
+
def main():
|
| 140 |
+
args = options()
|
| 141 |
+
torch.backends.cudnn.deterministic = True
|
| 142 |
+
|
| 143 |
+
testset = RegistrationData('PointNetLK', ModelNet40Data(train=False, num_points=args.num_points),
|
| 144 |
+
partial_template=args.partial_template, partial_source=args.partial_source,
|
| 145 |
+
noise=args.noise, additional_params={'use_masknet': True, 'partial_point_cloud_method': 'planar_crop'})
|
| 146 |
+
test_loader = DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 147 |
+
|
| 148 |
+
if not torch.cuda.is_available():
|
| 149 |
+
args.device = 'cpu'
|
| 150 |
+
args.device = torch.device(args.device)
|
| 151 |
+
|
| 152 |
+
# Load Pretrained MaskNet.
|
| 153 |
+
model = MaskNet2()
|
| 154 |
+
if args.pretrained:
|
| 155 |
+
assert os.path.isfile(args.pretrained)
|
| 156 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 157 |
+
model = model.to(args.device)
|
| 158 |
+
|
| 159 |
+
test(args, model, test_loader)
|
| 160 |
+
|
| 161 |
+
if __name__ == '__main__':
|
| 162 |
+
main()
|
thirdparty/learning3d/examples/test_pcn.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# author: Vinit Sarode (vinitsarode5@gmail.com) 03/23/2020
|
| 2 |
+
|
| 3 |
+
import open3d as o3d
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import logging
|
| 8 |
+
import numpy
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.utils.data
|
| 12 |
+
import torchvision
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
from tensorboardX import SummaryWriter
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
# Only if the files are in example folder.
|
| 18 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
+
if BASE_DIR[-8:] == 'examples':
|
| 20 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 21 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 22 |
+
|
| 23 |
+
from learning3d.models import PCN
|
| 24 |
+
from learning3d.data_utils import ModelNet40Data, ClassificationData
|
| 25 |
+
from learning3d.losses import ChamferDistanceLoss
|
| 26 |
+
|
| 27 |
+
def display_open3d(input_pc, output):
|
| 28 |
+
input_pc_ = o3d.geometry.PointCloud()
|
| 29 |
+
output_ = o3d.geometry.PointCloud()
|
| 30 |
+
input_pc_.points = o3d.utility.Vector3dVector(input_pc)
|
| 31 |
+
output_.points = o3d.utility.Vector3dVector(output + np.array([1,0,0]))
|
| 32 |
+
input_pc_.paint_uniform_color([1, 0, 0])
|
| 33 |
+
output_.paint_uniform_color([0, 1, 0])
|
| 34 |
+
o3d.visualization.draw_geometries([input_pc_, output_])
|
| 35 |
+
|
| 36 |
+
def test_one_epoch(device, model, test_loader):
|
| 37 |
+
model.eval()
|
| 38 |
+
test_loss = 0.0
|
| 39 |
+
pred = 0.0
|
| 40 |
+
count = 0
|
| 41 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 42 |
+
points, _ = data
|
| 43 |
+
|
| 44 |
+
points = points.to(device)
|
| 45 |
+
|
| 46 |
+
output = model(points)
|
| 47 |
+
loss_val = ChamferDistanceLoss()(points, output['coarse_output'])
|
| 48 |
+
print("Loss Val: ", loss_val)
|
| 49 |
+
display_open3d(points[0].detach().cpu().numpy(), output['coarse_output'][0].detach().cpu().numpy())
|
| 50 |
+
|
| 51 |
+
test_loss += loss_val.item()
|
| 52 |
+
count += 1
|
| 53 |
+
|
| 54 |
+
test_loss = float(test_loss)/count
|
| 55 |
+
return test_loss
|
| 56 |
+
|
| 57 |
+
def test(args, model, test_loader):
|
| 58 |
+
test_loss = test_one_epoch(args.device, model, test_loader)
|
| 59 |
+
|
| 60 |
+
def options():
|
| 61 |
+
parser = argparse.ArgumentParser(description='Point Completion Network')
|
| 62 |
+
parser.add_argument('--exp_name', type=str, default='exp_pcn', metavar='N',
|
| 63 |
+
help='Name of the experiment')
|
| 64 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 65 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 66 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 67 |
+
|
| 68 |
+
# settings for input data
|
| 69 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 70 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 71 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 72 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 73 |
+
|
| 74 |
+
# settings for PCN
|
| 75 |
+
parser.add_argument('--emb_dims', default=1024, type=int,
|
| 76 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 77 |
+
parser.add_argument('--detailed_output', default=False, type=bool,
|
| 78 |
+
help='Coarse + Fine Output')
|
| 79 |
+
|
| 80 |
+
# settings for on training
|
| 81 |
+
parser.add_argument('--seed', type=int, default=1234)
|
| 82 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 83 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 84 |
+
parser.add_argument('-b', '--batch_size', default=32, type=int,
|
| 85 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 86 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_pcn/models/best_model.t7', type=str,
|
| 87 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 88 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 89 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 90 |
+
|
| 91 |
+
args = parser.parse_args()
|
| 92 |
+
return args
|
| 93 |
+
|
| 94 |
+
def main():
|
| 95 |
+
args = options()
|
| 96 |
+
args.dataset_path = os.path.join(os.getcwd(), os.pardir, os.pardir, 'ModelNet40', 'ModelNet40')
|
| 97 |
+
|
| 98 |
+
trainset = ClassificationData(ModelNet40Data(train=True))
|
| 99 |
+
testset = ClassificationData(ModelNet40Data(train=False))
|
| 100 |
+
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
|
| 101 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 102 |
+
|
| 103 |
+
if not torch.cuda.is_available():
|
| 104 |
+
args.device = 'cpu'
|
| 105 |
+
args.device = torch.device(args.device)
|
| 106 |
+
|
| 107 |
+
# Create PointNet Model.
|
| 108 |
+
model = PCN(emb_dims=args.emb_dims, detailed_output=args.detailed_output)
|
| 109 |
+
|
| 110 |
+
if args.pretrained:
|
| 111 |
+
assert os.path.isfile(args.pretrained)
|
| 112 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 113 |
+
model.to(args.device)
|
| 114 |
+
|
| 115 |
+
test(args, model, test_loader)
|
| 116 |
+
|
| 117 |
+
if __name__ == '__main__':
|
| 118 |
+
main()
|
thirdparty/learning3d/examples/test_pcrnet.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import PointNet, iPCRNet
|
| 22 |
+
from learning3d.losses import ChamferDistanceLoss
|
| 23 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def display_open3d(template, source, transformed_source):
|
| 27 |
+
template_ = o3d.geometry.PointCloud()
|
| 28 |
+
source_ = o3d.geometry.PointCloud()
|
| 29 |
+
transformed_source_ = o3d.geometry.PointCloud()
|
| 30 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 31 |
+
source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
|
| 32 |
+
transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
|
| 33 |
+
template_.paint_uniform_color([1, 0, 0])
|
| 34 |
+
source_.paint_uniform_color([0, 1, 0])
|
| 35 |
+
transformed_source_.paint_uniform_color([0, 0, 1])
|
| 36 |
+
o3d.visualization.draw_geometries([template_, source_, transformed_source_])
|
| 37 |
+
|
| 38 |
+
def test_one_epoch(device, model, test_loader):
|
| 39 |
+
model.eval()
|
| 40 |
+
test_loss = 0.0
|
| 41 |
+
pred = 0.0
|
| 42 |
+
count = 0
|
| 43 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 44 |
+
template, source, igt = data
|
| 45 |
+
|
| 46 |
+
template = template.to(device)
|
| 47 |
+
source = source.to(device)
|
| 48 |
+
igt = igt.to(device)
|
| 49 |
+
|
| 50 |
+
output = model(template, source)
|
| 51 |
+
display_open3d(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], output['transformed_source'].detach().cpu().numpy()[0])
|
| 52 |
+
loss_val = ChamferDistanceLoss()(template, output['transformed_source'])
|
| 53 |
+
|
| 54 |
+
test_loss += loss_val.item()
|
| 55 |
+
count += 1
|
| 56 |
+
|
| 57 |
+
test_loss = float(test_loss)/count
|
| 58 |
+
return test_loss
|
| 59 |
+
|
| 60 |
+
def test(args, model, test_loader):
|
| 61 |
+
test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def options():
|
| 65 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 66 |
+
parser.add_argument('--exp_name', type=str, default='exp_ipcrnet', metavar='N',
|
| 67 |
+
help='Name of the experiment')
|
| 68 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 69 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 70 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 71 |
+
|
| 72 |
+
# settings for input data
|
| 73 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 74 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 75 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 76 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 77 |
+
|
| 78 |
+
# settings for PointNet
|
| 79 |
+
parser.add_argument('--emb_dims', default=1024, type=int,
|
| 80 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 81 |
+
parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
|
| 82 |
+
help='symmetric function (default: max)')
|
| 83 |
+
|
| 84 |
+
# settings for on training
|
| 85 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 86 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 87 |
+
parser.add_argument('-b', '--batch_size', default=20, type=int,
|
| 88 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 89 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_ipcrnet/models/best_model.t7', type=str,
|
| 90 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 91 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 92 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 93 |
+
|
| 94 |
+
args = parser.parse_args()
|
| 95 |
+
return args
|
| 96 |
+
|
| 97 |
+
def main():
|
| 98 |
+
args = options()
|
| 99 |
+
|
| 100 |
+
testset = RegistrationData('PCRNet', ModelNet40Data(train=False))
|
| 101 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 102 |
+
|
| 103 |
+
if not torch.cuda.is_available():
|
| 104 |
+
args.device = 'cpu'
|
| 105 |
+
args.device = torch.device(args.device)
|
| 106 |
+
|
| 107 |
+
# Create PointNet Model.
|
| 108 |
+
ptnet = PointNet(emb_dims=args.emb_dims)
|
| 109 |
+
model = iPCRNet(feature_model=ptnet)
|
| 110 |
+
model = model.to(args.device)
|
| 111 |
+
|
| 112 |
+
if args.pretrained:
|
| 113 |
+
assert os.path.isfile(args.pretrained)
|
| 114 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 115 |
+
model.to(args.device)
|
| 116 |
+
|
| 117 |
+
test(args, model, test_loader)
|
| 118 |
+
|
| 119 |
+
if __name__ == '__main__':
|
| 120 |
+
main()
|
thirdparty/learning3d/examples/test_pnlk.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import PointNet, PointNetLK
|
| 22 |
+
from learning3d.losses import FrobeniusNormLoss, RMSEFeaturesLoss
|
| 23 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 24 |
+
|
| 25 |
+
def display_open3d(template, source, transformed_source):
|
| 26 |
+
template_ = o3d.geometry.PointCloud()
|
| 27 |
+
source_ = o3d.geometry.PointCloud()
|
| 28 |
+
transformed_source_ = o3d.geometry.PointCloud()
|
| 29 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 30 |
+
source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
|
| 31 |
+
transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
|
| 32 |
+
template_.paint_uniform_color([1, 0, 0])
|
| 33 |
+
source_.paint_uniform_color([0, 1, 0])
|
| 34 |
+
transformed_source_.paint_uniform_color([0, 0, 1])
|
| 35 |
+
o3d.visualization.draw_geometries([template_, source_, transformed_source_])
|
| 36 |
+
|
| 37 |
+
def test_one_epoch(device, model, test_loader):
|
| 38 |
+
model.eval()
|
| 39 |
+
test_loss = 0.0
|
| 40 |
+
pred = 0.0
|
| 41 |
+
count = 0
|
| 42 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 43 |
+
template, source, igt = data
|
| 44 |
+
|
| 45 |
+
template = template.to(device)
|
| 46 |
+
source = source.to(device)
|
| 47 |
+
igt = igt.to(device)
|
| 48 |
+
|
| 49 |
+
output = model(template, source)
|
| 50 |
+
|
| 51 |
+
display_open3d(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], output['transformed_source'].detach().cpu().numpy()[0])
|
| 52 |
+
loss_val = FrobeniusNormLoss()(output['est_T'], igt) + RMSEFeaturesLoss()(output['r'])
|
| 53 |
+
|
| 54 |
+
test_loss += loss_val.item()
|
| 55 |
+
count += 1
|
| 56 |
+
|
| 57 |
+
test_loss = float(test_loss)/count
|
| 58 |
+
return test_loss
|
| 59 |
+
|
| 60 |
+
def test(args, model, test_loader):
|
| 61 |
+
test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def options():
|
| 65 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 66 |
+
parser.add_argument('--exp_name', type=str, default='exp_pnlk_v1', metavar='N',
|
| 67 |
+
help='Name of the experiment')
|
| 68 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 69 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 70 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 71 |
+
|
| 72 |
+
# settings for input data
|
| 73 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 74 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 75 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 76 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 77 |
+
|
| 78 |
+
# settings for PointNet
|
| 79 |
+
parser.add_argument('--emb_dims', default=1024, type=int,
|
| 80 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 81 |
+
parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
|
| 82 |
+
help='symmetric function (default: max)')
|
| 83 |
+
|
| 84 |
+
# settings for on training
|
| 85 |
+
parser.add_argument('--seed', type=int, default=1234)
|
| 86 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 87 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 88 |
+
parser.add_argument('-b', '--batch_size', default=10, type=int,
|
| 89 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 90 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_pnlk/models/best_model.t7', type=str,
|
| 91 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 92 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 93 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 94 |
+
|
| 95 |
+
args = parser.parse_args()
|
| 96 |
+
return args
|
| 97 |
+
|
| 98 |
+
def main():
|
| 99 |
+
args = options()
|
| 100 |
+
|
| 101 |
+
testset = RegistrationData('PointNetLK', ModelNet40Data(train=False))
|
| 102 |
+
test_loader = DataLoader(testset, batch_size=8, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 103 |
+
|
| 104 |
+
if not torch.cuda.is_available():
|
| 105 |
+
args.device = 'cpu'
|
| 106 |
+
args.device = torch.device(args.device)
|
| 107 |
+
|
| 108 |
+
# Create PointNet Model.
|
| 109 |
+
ptnet = PointNet(emb_dims=args.emb_dims, use_bn=True)
|
| 110 |
+
model = PointNetLK(feature_model=ptnet)
|
| 111 |
+
model = model.to(args.device)
|
| 112 |
+
|
| 113 |
+
if args.pretrained:
|
| 114 |
+
assert os.path.isfile(args.pretrained)
|
| 115 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 116 |
+
model.to(args.device)
|
| 117 |
+
|
| 118 |
+
test(args, model, test_loader)
|
| 119 |
+
|
| 120 |
+
if __name__ == '__main__':
|
| 121 |
+
main()
|
thirdparty/learning3d/examples/test_pointconv.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import create_pointconv
|
| 22 |
+
from learning3d.models import Classifier
|
| 23 |
+
from learning3d.data_utils import ClassificationData, ModelNet40Data
|
| 24 |
+
|
| 25 |
+
def display_open3d(template):
|
| 26 |
+
template_ = o3d.geometry.PointCloud()
|
| 27 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 28 |
+
# template_.paint_uniform_color([1, 0, 0])
|
| 29 |
+
o3d.visualization.draw_geometries([template_])
|
| 30 |
+
|
| 31 |
+
def test_one_epoch(device, model, test_loader, testset):
|
| 32 |
+
model.eval()
|
| 33 |
+
test_loss = 0.0
|
| 34 |
+
pred = 0.0
|
| 35 |
+
count = 0
|
| 36 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 37 |
+
points, target = data
|
| 38 |
+
target = target[:,0]
|
| 39 |
+
|
| 40 |
+
points = points.to(device)
|
| 41 |
+
target = target.to(device)
|
| 42 |
+
|
| 43 |
+
output = model(points)
|
| 44 |
+
loss_val = torch.nn.functional.nll_loss(
|
| 45 |
+
torch.nn.functional.log_softmax(output, dim=1), target, size_average=False)
|
| 46 |
+
print("Ground Truth Label: ", testset.get_shape(target[0].item()))
|
| 47 |
+
print("Predicted Label: ", testset.get_shape(torch.argmax(output[0]).item()))
|
| 48 |
+
display_open3d(points.detach().cpu().numpy()[0])
|
| 49 |
+
|
| 50 |
+
test_loss += loss_val.item()
|
| 51 |
+
count += output.size(0)
|
| 52 |
+
|
| 53 |
+
_, pred1 = output.max(dim=1)
|
| 54 |
+
ag = (pred1 == target)
|
| 55 |
+
am = ag.sum()
|
| 56 |
+
pred += am.item()
|
| 57 |
+
|
| 58 |
+
test_loss = float(test_loss)/count
|
| 59 |
+
accuracy = float(pred)/count
|
| 60 |
+
return test_loss, accuracy
|
| 61 |
+
|
| 62 |
+
def test(args, model, test_loader, testset):
|
| 63 |
+
test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader, testset)
|
| 64 |
+
|
| 65 |
+
def options():
|
| 66 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 67 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 68 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 69 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 70 |
+
|
| 71 |
+
# settings for input data
|
| 72 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 73 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 74 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 75 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 76 |
+
|
| 77 |
+
# settings for PointNet
|
| 78 |
+
parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
|
| 79 |
+
help='train pointnet (default: tune)')
|
| 80 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 81 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 82 |
+
parser.add_argument('-b', '--batch_size', default=32, type=int,
|
| 83 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 84 |
+
parser.add_argument('--emb_dims', default=1024, type=int,
|
| 85 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 86 |
+
parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
|
| 87 |
+
help='symmetric function (default: max)')
|
| 88 |
+
|
| 89 |
+
# settings for on training
|
| 90 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_classifier/models/best_model.t7', type=str,
|
| 91 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 92 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 93 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 94 |
+
|
| 95 |
+
args = parser.parse_args()
|
| 96 |
+
return args
|
| 97 |
+
|
| 98 |
+
def main():
|
| 99 |
+
args = options()
|
| 100 |
+
args.dataset_path = os.path.join(os.getcwd(), os.pardir, os.pardir, 'ModelNet40', 'ModelNet40')
|
| 101 |
+
|
| 102 |
+
testset = ClassificationData(ModelNet40Data(train=False))
|
| 103 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 104 |
+
|
| 105 |
+
if not torch.cuda.is_available():
|
| 106 |
+
args.device = 'cpu'
|
| 107 |
+
args.device = torch.device(args.device)
|
| 108 |
+
|
| 109 |
+
# To use pretrained model provided by authors.
|
| 110 |
+
# PointConv = create_pointconv(classifier=True, pretrained='path of pretrained model.')
|
| 111 |
+
# model = PointConv(emb_dims=args.emb_dims, classifier=True, pretrained='path of pretrained model.')
|
| 112 |
+
|
| 113 |
+
# To use your own pretrained model.
|
| 114 |
+
PointConv = create_pointconv(classifier=False, pretrained=None)
|
| 115 |
+
ptconv = PointConv(emb_dims=args.emb_dims, classifier=True, pretrained=None)
|
| 116 |
+
model = Classifier(feature_model=ptconv)
|
| 117 |
+
|
| 118 |
+
if args.pretrained:
|
| 119 |
+
assert os.path.isfile(args.pretrained)
|
| 120 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 121 |
+
model.to(args.device)
|
| 122 |
+
|
| 123 |
+
test(args, model, test_loader, testset)
|
| 124 |
+
|
| 125 |
+
if __name__ == '__main__':
|
| 126 |
+
main()
|
thirdparty/learning3d/examples/test_pointnet.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import PointNet
|
| 22 |
+
from learning3d.models import Classifier
|
| 23 |
+
from learning3d.data_utils import ClassificationData, ModelNet40Data
|
| 24 |
+
|
| 25 |
+
def display_open3d(template):
|
| 26 |
+
template_ = o3d.geometry.PointCloud()
|
| 27 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 28 |
+
# template_.paint_uniform_color([1, 0, 0])
|
| 29 |
+
o3d.visualization.draw_geometries([template_])
|
| 30 |
+
|
| 31 |
+
def test_one_epoch(device, model, test_loader, testset):
|
| 32 |
+
model.eval()
|
| 33 |
+
test_loss = 0.0
|
| 34 |
+
pred = 0.0
|
| 35 |
+
count = 0
|
| 36 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 37 |
+
points, target = data
|
| 38 |
+
target = target[:,0]
|
| 39 |
+
|
| 40 |
+
points = points.to(device)
|
| 41 |
+
target = target.to(device)
|
| 42 |
+
|
| 43 |
+
output = model(points)
|
| 44 |
+
loss_val = torch.nn.functional.nll_loss(
|
| 45 |
+
torch.nn.functional.log_softmax(output, dim=1), target, size_average=False)
|
| 46 |
+
print("Ground Truth Label: ", testset.get_shape(target[0].item()))
|
| 47 |
+
print("Predicted Label: ", testset.get_shape(torch.argmax(output[0]).item()))
|
| 48 |
+
display_open3d(points.detach().cpu().numpy()[0])
|
| 49 |
+
|
| 50 |
+
test_loss += loss_val.item()
|
| 51 |
+
count += output.size(0)
|
| 52 |
+
|
| 53 |
+
_, pred1 = output.max(dim=1)
|
| 54 |
+
ag = (pred1 == target)
|
| 55 |
+
am = ag.sum()
|
| 56 |
+
pred += am.item()
|
| 57 |
+
|
| 58 |
+
test_loss = float(test_loss)/count
|
| 59 |
+
accuracy = float(pred)/count
|
| 60 |
+
return test_loss, accuracy
|
| 61 |
+
|
| 62 |
+
def test(args, model, test_loader, testset):
|
| 63 |
+
test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader, testset)
|
| 64 |
+
|
| 65 |
+
def options():
|
| 66 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 67 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 68 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 69 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 70 |
+
|
| 71 |
+
# settings for input data
|
| 72 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 73 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 74 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 75 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 76 |
+
|
| 77 |
+
# settings for PointNet
|
| 78 |
+
parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
|
| 79 |
+
help='train pointnet (default: tune)')
|
| 80 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 81 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 82 |
+
parser.add_argument('-b', '--batch_size', default=32, type=int,
|
| 83 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 84 |
+
parser.add_argument('--emb_dims', default=1024, type=int,
|
| 85 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 86 |
+
parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
|
| 87 |
+
help='symmetric function (default: max)')
|
| 88 |
+
|
| 89 |
+
# settings for on training
|
| 90 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_classifier/models/best_model.t7', type=str,
|
| 91 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 92 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 93 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 94 |
+
|
| 95 |
+
args = parser.parse_args()
|
| 96 |
+
return args
|
| 97 |
+
|
| 98 |
+
def main():
|
| 99 |
+
args = options()
|
| 100 |
+
args.dataset_path = os.path.join(os.getcwd(), os.pardir, os.pardir, 'ModelNet40', 'ModelNet40')
|
| 101 |
+
|
| 102 |
+
testset = ClassificationData(ModelNet40Data(train=False))
|
| 103 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 104 |
+
|
| 105 |
+
if not torch.cuda.is_available():
|
| 106 |
+
args.device = 'cpu'
|
| 107 |
+
args.device = torch.device(args.device)
|
| 108 |
+
|
| 109 |
+
# Create PointNet Model.
|
| 110 |
+
ptnet = PointNet(emb_dims=args.emb_dims, use_bn=True)
|
| 111 |
+
model = Classifier(feature_model=ptnet)
|
| 112 |
+
|
| 113 |
+
if args.pretrained:
|
| 114 |
+
assert os.path.isfile(args.pretrained)
|
| 115 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 116 |
+
model.to(args.device)
|
| 117 |
+
|
| 118 |
+
test(args, model, test_loader, testset)
|
| 119 |
+
|
| 120 |
+
if __name__ == '__main__':
|
| 121 |
+
main()
|
thirdparty/learning3d/examples/test_prnet.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import PRNet
|
| 22 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 23 |
+
|
| 24 |
+
def get_transformations(igt):
|
| 25 |
+
R_ba = igt[:, 0:3, 0:3] # Ps = R_ba * Pt
|
| 26 |
+
translation_ba = igt[:, 0:3, 3].unsqueeze(2) # Ps = Pt + t_ba
|
| 27 |
+
R_ab = R_ba.permute(0, 2, 1) # Pt = R_ab * Ps
|
| 28 |
+
translation_ab = -torch.bmm(R_ab, translation_ba) # Pt = Ps + t_ab
|
| 29 |
+
return R_ab, translation_ab, R_ba, translation_ba
|
| 30 |
+
|
| 31 |
+
def display_open3d(template, source, transformed_source):
|
| 32 |
+
template_ = o3d.geometry.PointCloud()
|
| 33 |
+
source_ = o3d.geometry.PointCloud()
|
| 34 |
+
transformed_source_ = o3d.geometry.PointCloud()
|
| 35 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 36 |
+
source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
|
| 37 |
+
transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
|
| 38 |
+
template_.paint_uniform_color([1, 0, 0])
|
| 39 |
+
source_.paint_uniform_color([0, 1, 0])
|
| 40 |
+
transformed_source_.paint_uniform_color([0, 0, 1])
|
| 41 |
+
o3d.visualization.draw_geometries([template_, source_, transformed_source_])
|
| 42 |
+
|
| 43 |
+
def test_one_epoch(device, model, test_loader):
|
| 44 |
+
model.eval()
|
| 45 |
+
test_loss = 0.0
|
| 46 |
+
pred = 0.0
|
| 47 |
+
count = 0
|
| 48 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 49 |
+
template, source, igt = data
|
| 50 |
+
|
| 51 |
+
transformations = get_transformations(igt)
|
| 52 |
+
transformations = [t.to(device) for t in transformations]
|
| 53 |
+
R_ab, translation_ab, R_ba, translation_ba = transformations
|
| 54 |
+
|
| 55 |
+
template = template.to(device)
|
| 56 |
+
source = source.to(device)
|
| 57 |
+
igt = igt.to(device)
|
| 58 |
+
|
| 59 |
+
output = model(template, source, R_ab, translation_ab.squeeze(2))
|
| 60 |
+
display_open3d(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], output['transformed_source'].detach().cpu().numpy()[0])
|
| 61 |
+
|
| 62 |
+
test_loss += output['loss'].item()
|
| 63 |
+
count += 1
|
| 64 |
+
|
| 65 |
+
test_loss = float(test_loss)/count
|
| 66 |
+
return test_loss
|
| 67 |
+
|
| 68 |
+
def test(args, model, test_loader):
|
| 69 |
+
test_loss = test_one_epoch(args.device, model, test_loader)
|
| 70 |
+
|
| 71 |
+
def options():
|
| 72 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 73 |
+
parser.add_argument('--exp_name', type=str, default='exp_prnet', metavar='N',
|
| 74 |
+
help='Name of the experiment')
|
| 75 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 76 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 77 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 78 |
+
|
| 79 |
+
# settings for input data
|
| 80 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 81 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 82 |
+
|
| 83 |
+
# settings for PointNet
|
| 84 |
+
parser.add_argument('--emb_dims', default=512, type=int,
|
| 85 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 86 |
+
parser.add_argument('--num_iterations', default=3, type=int,
|
| 87 |
+
help='Number of Iterations')
|
| 88 |
+
|
| 89 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 90 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 91 |
+
parser.add_argument('-b', '--batch_size', default=1, type=int,
|
| 92 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 93 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_prnet/models/best_model.t7', type=str,
|
| 94 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 95 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 96 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 97 |
+
|
| 98 |
+
args = parser.parse_args()
|
| 99 |
+
return args
|
| 100 |
+
|
| 101 |
+
def main():
|
| 102 |
+
args = options()
|
| 103 |
+
torch.backends.cudnn.deterministic = True
|
| 104 |
+
|
| 105 |
+
trainset = RegistrationData('PRNet', ModelNet40Data(train=True), partial_source=True, partial_template=True)
|
| 106 |
+
testset = RegistrationData('PRNet', ModelNet40Data(train=False), partial_source=True, partial_template=True)
|
| 107 |
+
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
|
| 108 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 109 |
+
|
| 110 |
+
if not torch.cuda.is_available():
|
| 111 |
+
args.device = 'cpu'
|
| 112 |
+
args.device = torch.device(args.device)
|
| 113 |
+
|
| 114 |
+
# Create PointNet Model.
|
| 115 |
+
model = PRNet(emb_dims=args.emb_dims, num_iters=args.num_iterations)
|
| 116 |
+
model = model.to(args.device)
|
| 117 |
+
|
| 118 |
+
if args.pretrained:
|
| 119 |
+
assert os.path.isfile(args.pretrained)
|
| 120 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'), strict=False)
|
| 121 |
+
model.to(args.device)
|
| 122 |
+
|
| 123 |
+
test(args, model, test_loader)
|
| 124 |
+
|
| 125 |
+
if __name__ == '__main__':
|
| 126 |
+
main()
|
thirdparty/learning3d/examples/test_rpmnet.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import RPMNet, PPFNet
|
| 22 |
+
from learning3d.losses import FrobeniusNormLoss, RMSEFeaturesLoss
|
| 23 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 24 |
+
|
| 25 |
+
def display_open3d(template, source, transformed_source):
|
| 26 |
+
template_ = o3d.geometry.PointCloud()
|
| 27 |
+
source_ = o3d.geometry.PointCloud()
|
| 28 |
+
transformed_source_ = o3d.geometry.PointCloud()
|
| 29 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 30 |
+
source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
|
| 31 |
+
transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
|
| 32 |
+
template_.paint_uniform_color([1, 0, 0])
|
| 33 |
+
source_.paint_uniform_color([0, 1, 0])
|
| 34 |
+
transformed_source_.paint_uniform_color([0, 0, 1])
|
| 35 |
+
o3d.visualization.draw_geometries([template_, source_, transformed_source_])
|
| 36 |
+
|
| 37 |
+
def test_one_epoch(device, model, test_loader):
|
| 38 |
+
model.eval()
|
| 39 |
+
test_loss = 0.0
|
| 40 |
+
pred = 0.0
|
| 41 |
+
count = 0
|
| 42 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 43 |
+
template, source, igt = data
|
| 44 |
+
|
| 45 |
+
template = template.to(device)
|
| 46 |
+
source = source.to(device)
|
| 47 |
+
igt = igt.to(device)
|
| 48 |
+
|
| 49 |
+
output = model(template, source)
|
| 50 |
+
|
| 51 |
+
display_open3d(template.detach().cpu().numpy()[0,:,:3], source.detach().cpu().numpy()[0,:,:3], output['transformed_source'].detach().cpu().numpy()[0])
|
| 52 |
+
loss_val = FrobeniusNormLoss()(output['est_T'], igt)
|
| 53 |
+
|
| 54 |
+
test_loss += loss_val.item()
|
| 55 |
+
count += 1
|
| 56 |
+
|
| 57 |
+
test_loss = float(test_loss)/count
|
| 58 |
+
return test_loss
|
| 59 |
+
|
| 60 |
+
def test(args, model, test_loader):
|
| 61 |
+
test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def options():
|
| 65 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 66 |
+
parser.add_argument('--exp_name', type=str, default='exp_rpmnet', metavar='N',
|
| 67 |
+
help='Name of the experiment')
|
| 68 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 69 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 70 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 71 |
+
|
| 72 |
+
# settings for input data
|
| 73 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 74 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 75 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 76 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 77 |
+
|
| 78 |
+
# settings for PointNet
|
| 79 |
+
parser.add_argument('--emb_dims', default=1024, type=int,
|
| 80 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 81 |
+
parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
|
| 82 |
+
help='symmetric function (default: max)')
|
| 83 |
+
|
| 84 |
+
# settings for on training
|
| 85 |
+
parser.add_argument('--seed', type=int, default=1234)
|
| 86 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 87 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 88 |
+
parser.add_argument('-b', '--batch_size', default=10, type=int,
|
| 89 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 90 |
+
parser.add_argument('--pretrained', default='learning3d/pretrained/exp_rpmnet/models/partial-trained.pth', type=str,
|
| 91 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 92 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 93 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 94 |
+
|
| 95 |
+
args = parser.parse_args()
|
| 96 |
+
return args
|
| 97 |
+
|
| 98 |
+
def main():
|
| 99 |
+
args = options()
|
| 100 |
+
|
| 101 |
+
testset = RegistrationData('RPMNet', ModelNet40Data(train=False, num_points=args.num_points, use_normals=True), partial_source=True, partial_template=False)
|
| 102 |
+
test_loader = DataLoader(testset, batch_size=1, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 103 |
+
|
| 104 |
+
if not torch.cuda.is_available():
|
| 105 |
+
args.device = 'cpu'
|
| 106 |
+
args.device = torch.device(args.device)
|
| 107 |
+
|
| 108 |
+
# Create RPMNet Model.
|
| 109 |
+
model = RPMNet(feature_model=PPFNet())
|
| 110 |
+
model = model.to(args.device)
|
| 111 |
+
|
| 112 |
+
if args.pretrained:
|
| 113 |
+
assert os.path.isfile(args.pretrained)
|
| 114 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu')['state_dict'])
|
| 115 |
+
model.to(args.device)
|
| 116 |
+
|
| 117 |
+
test(args, model, test_loader)
|
| 118 |
+
|
| 119 |
+
if __name__ == '__main__':
|
| 120 |
+
main()
|
thirdparty/learning3d/examples/train_PointNetLK.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import logging
|
| 5 |
+
import numpy
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
import torchvision
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from tensorboardX import SummaryWriter
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
# Only if the files are in example folder.
|
| 15 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
if BASE_DIR[-8:] == 'examples':
|
| 17 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 18 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
|
| 20 |
+
from learning3d.models import PointNet
|
| 21 |
+
from learning3d.models import PointNetLK
|
| 22 |
+
from learning3d.losses import FrobeniusNormLoss, RMSEFeaturesLoss
|
| 23 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 24 |
+
|
| 25 |
+
def _init_(args):
|
| 26 |
+
if not os.path.exists('checkpoints'):
|
| 27 |
+
os.makedirs('checkpoints')
|
| 28 |
+
if not os.path.exists('checkpoints/' + args.exp_name):
|
| 29 |
+
os.makedirs('checkpoints/' + args.exp_name)
|
| 30 |
+
if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
|
| 31 |
+
os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
|
| 32 |
+
os.system('cp main.py checkpoints' + '/' + args.exp_name + '/' + 'main.py.backup')
|
| 33 |
+
os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup')
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class IOStream:
|
| 37 |
+
def __init__(self, path):
|
| 38 |
+
self.f = open(path, 'a')
|
| 39 |
+
|
| 40 |
+
def cprint(self, text):
|
| 41 |
+
print(text)
|
| 42 |
+
self.f.write(text + '\n')
|
| 43 |
+
self.f.flush()
|
| 44 |
+
|
| 45 |
+
def close(self):
|
| 46 |
+
self.f.close()
|
| 47 |
+
|
| 48 |
+
def test_one_epoch(device, model, test_loader):
|
| 49 |
+
model.eval()
|
| 50 |
+
test_loss = 0.0
|
| 51 |
+
pred = 0.0
|
| 52 |
+
count = 0
|
| 53 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 54 |
+
template, source, igt = data
|
| 55 |
+
|
| 56 |
+
template = template.to(device)
|
| 57 |
+
source = source.to(device)
|
| 58 |
+
igt = igt.to(device)
|
| 59 |
+
|
| 60 |
+
output = model(template, source)
|
| 61 |
+
loss_val = FrobeniusNormLoss()(output['est_T'], igt) + RMSEFeaturesLoss()(output['r'])
|
| 62 |
+
|
| 63 |
+
test_loss += loss_val.item()
|
| 64 |
+
count += 1
|
| 65 |
+
|
| 66 |
+
test_loss = float(test_loss)/count
|
| 67 |
+
return test_loss
|
| 68 |
+
|
| 69 |
+
def test(args, model, test_loader, textio):
|
| 70 |
+
test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
|
| 71 |
+
textio.cprint('Validation Loss: %f & Validation Accuracy: %f'%(test_loss, test_accuracy))
|
| 72 |
+
|
| 73 |
+
def train_one_epoch(device, model, train_loader, optimizer):
|
| 74 |
+
model.train()
|
| 75 |
+
train_loss = 0.0
|
| 76 |
+
pred = 0.0
|
| 77 |
+
count = 0
|
| 78 |
+
for i, data in enumerate(tqdm(train_loader)):
|
| 79 |
+
template, source, igt = data
|
| 80 |
+
|
| 81 |
+
template = template.to(device)
|
| 82 |
+
source = source.to(device)
|
| 83 |
+
igt = igt.to(device)
|
| 84 |
+
|
| 85 |
+
output = model(template, source)
|
| 86 |
+
loss_val = FrobeniusNormLoss()(output['est_T'], igt) + RMSEFeaturesLoss()(output['r'])
|
| 87 |
+
# print(loss_val.item())
|
| 88 |
+
|
| 89 |
+
# forward + backward + optimize
|
| 90 |
+
optimizer.zero_grad()
|
| 91 |
+
loss_val.backward()
|
| 92 |
+
optimizer.step()
|
| 93 |
+
|
| 94 |
+
train_loss += loss_val.item()
|
| 95 |
+
count += 1
|
| 96 |
+
|
| 97 |
+
train_loss = float(train_loss)/count
|
| 98 |
+
return train_loss
|
| 99 |
+
|
| 100 |
+
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
|
| 101 |
+
learnable_params = filter(lambda p: p.requires_grad, model.parameters())
|
| 102 |
+
if args.optimizer == 'Adam':
|
| 103 |
+
optimizer = torch.optim.Adam(learnable_params)
|
| 104 |
+
else:
|
| 105 |
+
optimizer = torch.optim.SGD(learnable_params, lr=0.1)
|
| 106 |
+
|
| 107 |
+
if checkpoint is not None:
|
| 108 |
+
min_loss = checkpoint['min_loss']
|
| 109 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 110 |
+
|
| 111 |
+
best_test_loss = np.inf
|
| 112 |
+
|
| 113 |
+
for epoch in range(args.start_epoch, args.epochs):
|
| 114 |
+
train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
|
| 115 |
+
test_loss = test_one_epoch(args.device, model, test_loader)
|
| 116 |
+
|
| 117 |
+
if test_loss<best_test_loss:
|
| 118 |
+
best_test_loss = test_loss
|
| 119 |
+
snap = {'epoch': epoch + 1,
|
| 120 |
+
'model': model.state_dict(),
|
| 121 |
+
'min_loss': best_test_loss,
|
| 122 |
+
'optimizer' : optimizer.state_dict(),}
|
| 123 |
+
torch.save(snap, 'checkpoints/%s/models/best_model_snap.t7' % (args.exp_name))
|
| 124 |
+
torch.save(model.state_dict(), 'checkpoints/%s/models/best_model.t7' % (args.exp_name))
|
| 125 |
+
torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/best_ptnet_model.t7' % (args.exp_name))
|
| 126 |
+
|
| 127 |
+
torch.save(snap, 'checkpoints/%s/models/model_snap.t7' % (args.exp_name))
|
| 128 |
+
torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % (args.exp_name))
|
| 129 |
+
torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/ptnet_model.t7' % (args.exp_name))
|
| 130 |
+
|
| 131 |
+
boardio.add_scalar('Train Loss', train_loss, epoch+1)
|
| 132 |
+
boardio.add_scalar('Test Loss', test_loss, epoch+1)
|
| 133 |
+
boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)
|
| 134 |
+
|
| 135 |
+
textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
|
| 136 |
+
|
| 137 |
+
def options():
|
| 138 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 139 |
+
parser.add_argument('--exp_name', type=str, default='exp_pnlk', metavar='N',
|
| 140 |
+
help='Name of the experiment')
|
| 141 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 142 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 143 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 144 |
+
|
| 145 |
+
# settings for input data
|
| 146 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 147 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 148 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 149 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 150 |
+
|
| 151 |
+
# settings for PointNet
|
| 152 |
+
parser.add_argument('--fine_tune_pointnet', default='tune', type=str, choices=['fixed', 'tune'],
|
| 153 |
+
help='train pointnet (default: tune)')
|
| 154 |
+
parser.add_argument('--transfer_ptnet_weights', default='./checkpoints/exp_classifier/models/best_ptnet_model.t7', type=str,
|
| 155 |
+
metavar='PATH', help='path to pointnet features file')
|
| 156 |
+
parser.add_argument('--emb_dims', default=1024, type=int,
|
| 157 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 158 |
+
parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
|
| 159 |
+
help='symmetric function (default: max)')
|
| 160 |
+
|
| 161 |
+
# settings for on training
|
| 162 |
+
parser.add_argument('--seed', type=int, default=1234)
|
| 163 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 164 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 165 |
+
parser.add_argument('-b', '--batch_size', default=10, type=int,
|
| 166 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 167 |
+
parser.add_argument('--epochs', default=200, type=int,
|
| 168 |
+
metavar='N', help='number of total epochs to run')
|
| 169 |
+
parser.add_argument('--start_epoch', default=0, type=int,
|
| 170 |
+
metavar='N', help='manual epoch number (useful on restarts)')
|
| 171 |
+
parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
|
| 172 |
+
metavar='METHOD', help='name of an optimizer (default: Adam)')
|
| 173 |
+
parser.add_argument('--resume', default='', type=str,
|
| 174 |
+
metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
|
| 175 |
+
parser.add_argument('--pretrained', default='', type=str,
|
| 176 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 177 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 178 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 179 |
+
|
| 180 |
+
args = parser.parse_args()
|
| 181 |
+
return args
|
| 182 |
+
|
| 183 |
+
def main():
|
| 184 |
+
args = options()
|
| 185 |
+
|
| 186 |
+
torch.backends.cudnn.deterministic = True
|
| 187 |
+
torch.manual_seed(args.seed)
|
| 188 |
+
torch.cuda.manual_seed_all(args.seed)
|
| 189 |
+
np.random.seed(args.seed)
|
| 190 |
+
|
| 191 |
+
boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
|
| 192 |
+
_init_(args)
|
| 193 |
+
|
| 194 |
+
textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
|
| 195 |
+
textio.cprint(str(args))
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
trainset = RegistrationData('PointNetLK', ModelNet40Data(train=True))
|
| 199 |
+
testset = RegistrationData('PointNetLK', ModelNet40Data(train=False))
|
| 200 |
+
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
|
| 201 |
+
test_loader = DataLoader(testset, batch_size=8, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 202 |
+
|
| 203 |
+
if not torch.cuda.is_available():
|
| 204 |
+
args.device = 'cpu'
|
| 205 |
+
args.device = torch.device(args.device)
|
| 206 |
+
|
| 207 |
+
# Create PointNet Model.
|
| 208 |
+
ptnet = PointNet(emb_dims=args.emb_dims, use_bn=True)
|
| 209 |
+
|
| 210 |
+
if args.transfer_ptnet_weights and os.path.isfile(args.transfer_ptnet_weights):
|
| 211 |
+
ptnet.load_state_dict(torch.load(args.transfer_ptnet_weights, map_location='cpu'))
|
| 212 |
+
|
| 213 |
+
if args.fine_tune_pointnet == 'tune':
|
| 214 |
+
pass
|
| 215 |
+
elif args.fine_tune_pointnet == 'fixed':
|
| 216 |
+
for param in ptnet.parameters():
|
| 217 |
+
param.requires_grad_(False)
|
| 218 |
+
|
| 219 |
+
model = PointNetLK(feature_model=ptnet)
|
| 220 |
+
model = model.to(args.device)
|
| 221 |
+
|
| 222 |
+
checkpoint = None
|
| 223 |
+
if args.resume:
|
| 224 |
+
assert os.path.isfile(args.resume)
|
| 225 |
+
checkpoint = torch.load(args.resume)
|
| 226 |
+
args.start_epoch = checkpoint['epoch']
|
| 227 |
+
model.load_state_dict(checkpoint['model'])
|
| 228 |
+
|
| 229 |
+
if args.pretrained:
|
| 230 |
+
assert os.path.isfile(args.pretrained)
|
| 231 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 232 |
+
model.to(args.device)
|
| 233 |
+
|
| 234 |
+
if args.eval:
|
| 235 |
+
test(args, model, test_loader, textio)
|
| 236 |
+
else:
|
| 237 |
+
train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
|
| 238 |
+
|
| 239 |
+
if __name__ == '__main__':
|
| 240 |
+
main()
|
thirdparty/learning3d/examples/train_dcp.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import logging
|
| 5 |
+
import numpy
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
import torchvision
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from tensorboardX import SummaryWriter
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
# Only if the files are in example folder.
|
| 15 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
if BASE_DIR[-8:] == 'examples':
|
| 17 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 18 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
|
| 20 |
+
from learning3d.models import DGCNN, DCP
|
| 21 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 22 |
+
|
| 23 |
+
def _init_(args):
|
| 24 |
+
if not os.path.exists('checkpoints'):
|
| 25 |
+
os.makedirs('checkpoints')
|
| 26 |
+
if not os.path.exists('checkpoints/' + args.exp_name):
|
| 27 |
+
os.makedirs('checkpoints/' + args.exp_name)
|
| 28 |
+
if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
|
| 29 |
+
os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
|
| 30 |
+
os.system('cp train_dcp.py checkpoints' + '/' + args.exp_name + '/' + 'train.py.backup')
|
| 31 |
+
|
| 32 |
+
class IOStream:
|
| 33 |
+
def __init__(self, path):
|
| 34 |
+
self.f = open(path, 'a')
|
| 35 |
+
|
| 36 |
+
def cprint(self, text):
|
| 37 |
+
print(text)
|
| 38 |
+
self.f.write(text + '\n')
|
| 39 |
+
self.f.flush()
|
| 40 |
+
|
| 41 |
+
def close(self):
|
| 42 |
+
self.f.close()
|
| 43 |
+
|
| 44 |
+
def get_transformations(igt):
|
| 45 |
+
R_ba = igt[:, 0:3, 0:3] # Ps = R_ba * Pt
|
| 46 |
+
translation_ba = igt[:, 0:3, 3].unsqueeze(2) # Ps = Pt + t_ba
|
| 47 |
+
R_ab = R_ba.permute(0, 2, 1) # Pt = R_ab * Ps
|
| 48 |
+
translation_ab = -torch.bmm(R_ab, translation_ba) # Pt = Ps + t_ab
|
| 49 |
+
return R_ab, translation_ab, R_ba, translation_ba
|
| 50 |
+
|
| 51 |
+
def test_one_epoch(device, model, test_loader):
|
| 52 |
+
model.eval()
|
| 53 |
+
test_loss = 0.0
|
| 54 |
+
pred = 0.0
|
| 55 |
+
count = 0
|
| 56 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 57 |
+
template, source, igt = data
|
| 58 |
+
transformations = get_transformations(igt)
|
| 59 |
+
transformations = [t.to(device) for t in transformations]
|
| 60 |
+
R_ab, translation_ab, R_ba, translation_ba = transformations
|
| 61 |
+
|
| 62 |
+
template = template.to(device)
|
| 63 |
+
source = source.to(device)
|
| 64 |
+
igt = igt.to(device)
|
| 65 |
+
|
| 66 |
+
output = model(template, source)
|
| 67 |
+
identity = torch.eye(3).cuda().unsqueeze(0).repeat(template.shape[0], 1, 1)
|
| 68 |
+
loss_val = torch.nn.functional.mse_loss(torch.matmul(output['est_R'].transpose(2, 1), R_ab), identity) \
|
| 69 |
+
+ torch.nn.functional.mse_loss(output['est_t'], translation_ab[:,:,0])
|
| 70 |
+
|
| 71 |
+
cycle_loss = torch.nn.functional.mse_loss(torch.matmul(output['est_R_'].transpose(2, 1), R_ba), identity) \
|
| 72 |
+
+ torch.nn.functional.mse_loss(output['est_t_'], translation_ba[:,:,0])
|
| 73 |
+
loss_val = loss_val + cycle_loss * 0.1
|
| 74 |
+
|
| 75 |
+
test_loss += loss_val.item()
|
| 76 |
+
count += 1
|
| 77 |
+
|
| 78 |
+
test_loss = float(test_loss)/count
|
| 79 |
+
return test_loss
|
| 80 |
+
|
| 81 |
+
def test(args, model, test_loader, textio):
|
| 82 |
+
test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
|
| 83 |
+
textio.cprint('Validation Loss: %f & Validation Accuracy: %f'%(test_loss, test_accuracy))
|
| 84 |
+
|
| 85 |
+
def train_one_epoch(device, model, train_loader, optimizer):
|
| 86 |
+
model.train()
|
| 87 |
+
train_loss = 0.0
|
| 88 |
+
pred = 0.0
|
| 89 |
+
count = 0
|
| 90 |
+
for i, data in enumerate(tqdm(train_loader)):
|
| 91 |
+
template, source, igt = data
|
| 92 |
+
transformations = get_transformations(igt)
|
| 93 |
+
transformations = [t.to(device) for t in transformations]
|
| 94 |
+
R_ab, translation_ab, R_ba, translation_ba = transformations
|
| 95 |
+
|
| 96 |
+
template = template.to(device)
|
| 97 |
+
source = source.to(device)
|
| 98 |
+
igt = igt.to(device)
|
| 99 |
+
|
| 100 |
+
output = model(template, source)
|
| 101 |
+
identity = torch.eye(3).cuda().unsqueeze(0).repeat(template.shape[0], 1, 1)
|
| 102 |
+
loss_val = torch.nn.functional.mse_loss(torch.matmul(output['est_R'].transpose(2, 1), R_ab), identity) \
|
| 103 |
+
+ torch.nn.functional.mse_loss(output['est_t'], translation_ab[:,:,0])
|
| 104 |
+
|
| 105 |
+
cycle_loss = torch.nn.functional.mse_loss(torch.matmul(output['est_R_'].transpose(2, 1), R_ba), identity) \
|
| 106 |
+
+ torch.nn.functional.mse_loss(output['est_t_'], translation_ba[:,:,0])
|
| 107 |
+
loss_val = loss_val + cycle_loss * 0.1
|
| 108 |
+
# print(loss_val.item())
|
| 109 |
+
|
| 110 |
+
# forward + backward + optimize
|
| 111 |
+
optimizer.zero_grad()
|
| 112 |
+
loss_val.backward()
|
| 113 |
+
optimizer.step()
|
| 114 |
+
|
| 115 |
+
train_loss += loss_val.item()
|
| 116 |
+
count += 1
|
| 117 |
+
|
| 118 |
+
train_loss = float(train_loss)/count
|
| 119 |
+
return train_loss
|
| 120 |
+
|
| 121 |
+
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
|
| 122 |
+
learnable_params = filter(lambda p: p.requires_grad, model.parameters())
|
| 123 |
+
if args.optimizer == 'Adam':
|
| 124 |
+
optimizer = torch.optim.Adam(learnable_params)
|
| 125 |
+
else:
|
| 126 |
+
optimizer = torch.optim.SGD(learnable_params, lr=0.1)
|
| 127 |
+
|
| 128 |
+
if checkpoint is not None:
|
| 129 |
+
min_loss = checkpoint['min_loss']
|
| 130 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 131 |
+
|
| 132 |
+
best_test_loss = np.inf
|
| 133 |
+
|
| 134 |
+
for epoch in range(args.start_epoch, args.epochs):
|
| 135 |
+
train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
|
| 136 |
+
test_loss = test_one_epoch(args.device, model, test_loader)
|
| 137 |
+
|
| 138 |
+
if test_loss<best_test_loss:
|
| 139 |
+
best_test_loss = test_loss
|
| 140 |
+
snap = {'epoch': epoch + 1,
|
| 141 |
+
'model': model.state_dict(),
|
| 142 |
+
'min_loss': best_test_loss,
|
| 143 |
+
'optimizer' : optimizer.state_dict(),}
|
| 144 |
+
torch.save(snap, 'checkpoints/%s/models/best_model_snap.t7' % (args.exp_name))
|
| 145 |
+
torch.save(model.state_dict(), 'checkpoints/%s/models/best_model.t7' % (args.exp_name))
|
| 146 |
+
torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/best_ptnet_model.t7' % (args.exp_name))
|
| 147 |
+
|
| 148 |
+
torch.save(snap, 'checkpoints/%s/models/model_snap.t7' % (args.exp_name))
|
| 149 |
+
torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % (args.exp_name))
|
| 150 |
+
torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/ptnet_model.t7' % (args.exp_name))
|
| 151 |
+
|
| 152 |
+
boardio.add_scalar('Train Loss', train_loss, epoch+1)
|
| 153 |
+
boardio.add_scalar('Test Loss', test_loss, epoch+1)
|
| 154 |
+
boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)
|
| 155 |
+
|
| 156 |
+
textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
|
| 157 |
+
|
| 158 |
+
def options():
|
| 159 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 160 |
+
parser.add_argument('--exp_name', type=str, default='exp_dcp', metavar='N',
|
| 161 |
+
help='Name of the experiment')
|
| 162 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 163 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 164 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 165 |
+
|
| 166 |
+
# settings for input data
|
| 167 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 168 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 169 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 170 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 171 |
+
|
| 172 |
+
# settings for PointNet
|
| 173 |
+
parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
|
| 174 |
+
help='train pointnet (default: tune)')
|
| 175 |
+
parser.add_argument('--emb_dims', default=1024, type=int,
|
| 176 |
+
metavar='K', help='dim. of the feature vector (default: 1024)')
|
| 177 |
+
parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
|
| 178 |
+
help='symmetric function (default: max)')
|
| 179 |
+
|
| 180 |
+
# settings for on training
|
| 181 |
+
parser.add_argument('--seed', type=int, default=1234)
|
| 182 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 183 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 184 |
+
parser.add_argument('-b', '--batch_size', default=32, type=int,
|
| 185 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 186 |
+
parser.add_argument('--epochs', default=200, type=int,
|
| 187 |
+
metavar='N', help='number of total epochs to run')
|
| 188 |
+
parser.add_argument('--start_epoch', default=0, type=int,
|
| 189 |
+
metavar='N', help='manual epoch number (useful on restarts)')
|
| 190 |
+
parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
|
| 191 |
+
metavar='METHOD', help='name of an optimizer (default: Adam)')
|
| 192 |
+
parser.add_argument('--resume', default='', type=str,
|
| 193 |
+
metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
|
| 194 |
+
parser.add_argument('--pretrained', default='', type=str,
|
| 195 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 196 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 197 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 198 |
+
|
| 199 |
+
args = parser.parse_args()
|
| 200 |
+
return args
|
| 201 |
+
|
| 202 |
+
def main():
|
| 203 |
+
args = options()
|
| 204 |
+
|
| 205 |
+
torch.backends.cudnn.deterministic = True
|
| 206 |
+
torch.manual_seed(args.seed)
|
| 207 |
+
torch.cuda.manual_seed_all(args.seed)
|
| 208 |
+
np.random.seed(args.seed)
|
| 209 |
+
|
| 210 |
+
boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
|
| 211 |
+
_init_(args)
|
| 212 |
+
|
| 213 |
+
textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
|
| 214 |
+
textio.cprint(str(args))
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
trainset = RegistrationData('DCP', ModelNet40Data(train=True))
|
| 218 |
+
testset = RegistrationData('DCP', ModelNet40Data(train=False))
|
| 219 |
+
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
|
| 220 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 221 |
+
|
| 222 |
+
if not torch.cuda.is_available():
|
| 223 |
+
args.device = 'cpu'
|
| 224 |
+
args.device = torch.device(args.device)
|
| 225 |
+
|
| 226 |
+
# Create PointNet Model.
|
| 227 |
+
dgcnn = DGCNN(emb_dims=args.emb_dims)
|
| 228 |
+
model = DCP(feature_model=dgcnn, cycle=True)
|
| 229 |
+
model = model.to(args.device)
|
| 230 |
+
|
| 231 |
+
checkpoint = None
|
| 232 |
+
if args.resume:
|
| 233 |
+
assert os.path.isfile(args.resume)
|
| 234 |
+
checkpoint = torch.load(args.resume)
|
| 235 |
+
args.start_epoch = checkpoint['epoch']
|
| 236 |
+
model.load_state_dict(checkpoint['model'])
|
| 237 |
+
|
| 238 |
+
if args.pretrained:
|
| 239 |
+
assert os.path.isfile(args.pretrained)
|
| 240 |
+
model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
|
| 241 |
+
model.to(args.device)
|
| 242 |
+
|
| 243 |
+
if args.eval:
|
| 244 |
+
test(args, model, test_loader, textio)
|
| 245 |
+
else:
|
| 246 |
+
train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
|
| 247 |
+
|
| 248 |
+
if __name__ == '__main__':
|
| 249 |
+
main()
|
thirdparty/learning3d/examples/train_deepgmr.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open3d as o3d
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import numpy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Only if the files are in example folder.
|
| 16 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
if BASE_DIR[-8:] == 'examples':
|
| 18 |
+
sys.path.append(os.path.join(BASE_DIR, os.pardir))
|
| 19 |
+
os.chdir(os.path.join(BASE_DIR, os.pardir))
|
| 20 |
+
|
| 21 |
+
from learning3d.models import DeepGMR
|
| 22 |
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
| 23 |
+
|
| 24 |
+
def display_open3d(template, source, transformed_source):
|
| 25 |
+
template_ = o3d.geometry.PointCloud()
|
| 26 |
+
source_ = o3d.geometry.PointCloud()
|
| 27 |
+
transformed_source_ = o3d.geometry.PointCloud()
|
| 28 |
+
template_.points = o3d.utility.Vector3dVector(template)
|
| 29 |
+
source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
|
| 30 |
+
transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
|
| 31 |
+
template_.paint_uniform_color([1, 0, 0])
|
| 32 |
+
source_.paint_uniform_color([0, 1, 0])
|
| 33 |
+
transformed_source_.paint_uniform_color([0, 0, 1])
|
| 34 |
+
o3d.visualization.draw_geometries([template_, source_, transformed_source_])
|
| 35 |
+
|
| 36 |
+
def rotation_error(R, R_gt):
|
| 37 |
+
cos_theta = (torch.einsum('bij,bij->b', R, R_gt) - 1) / 2
|
| 38 |
+
cos_theta = torch.clamp(cos_theta, -1, 1)
|
| 39 |
+
return torch.acos(cos_theta) * 180 / math.pi
|
| 40 |
+
|
| 41 |
+
def translation_error(t, t_gt):
|
| 42 |
+
return torch.norm(t - t_gt, dim=1)
|
| 43 |
+
|
| 44 |
+
def rmse(pts, T, T_gt):
|
| 45 |
+
pts_pred = pts @ T[:, :3, :3].transpose(1, 2) + T[:, :3, 3].unsqueeze(1)
|
| 46 |
+
pts_gt = pts @ T_gt[:, :3, :3].transpose(1, 2) + T_gt[:, :3, 3].unsqueeze(1)
|
| 47 |
+
return torch.norm(pts_pred - pts_gt, dim=2).mean(dim=1)
|
| 48 |
+
|
| 49 |
+
def test_one_epoch(device, model, test_loader):
|
| 50 |
+
model.eval()
|
| 51 |
+
test_loss = 0.0
|
| 52 |
+
pred = 0.0
|
| 53 |
+
count = 0
|
| 54 |
+
rotation_errors, translation_errors, rmses = [], [], []
|
| 55 |
+
|
| 56 |
+
for i, data in enumerate(tqdm(test_loader)):
|
| 57 |
+
template, source, igt = data
|
| 58 |
+
|
| 59 |
+
template = template.to(device)
|
| 60 |
+
source = source.to(device)
|
| 61 |
+
igt = igt.to(device)
|
| 62 |
+
|
| 63 |
+
output = model(template, source)
|
| 64 |
+
|
| 65 |
+
eye = torch.eye(4).expand_as(igt).to(igt.device)
|
| 66 |
+
mse1 = F.mse_loss(output['est_T_inverse'] @ torch.inverse(igt), eye)
|
| 67 |
+
mse2 = F.mse_loss(output['est_T'] @ igt, eye)
|
| 68 |
+
loss = mse1 + mse2
|
| 69 |
+
|
| 70 |
+
r_err = rotation_error(est_T_inverse[:, :3, :3], igt[:, :3, :3])
|
| 71 |
+
t_err = translation_error(est_T_inverse[:, :3, 3], igt[:, :3, 3])
|
| 72 |
+
rmse_val = rmse(template[:, :100], est_T_inverse, igt)
|
| 73 |
+
rotation_errors.append(r_err)
|
| 74 |
+
translation_errors.append(t_err)
|
| 75 |
+
rmses.append(rmse_val)
|
| 76 |
+
|
| 77 |
+
test_loss += loss_val.item()
|
| 78 |
+
count += 1
|
| 79 |
+
|
| 80 |
+
test_loss = float(test_loss)/count
|
| 81 |
+
print("Mean rotation error: {}, Mean translation error: {} and Mean RMSE: {}".format(np.mean(rotation_errors), np.mean(translation_errors), np.mean(rmses)))
|
| 82 |
+
return test_loss
|
| 83 |
+
|
| 84 |
+
def test(args, model, test_loader, textio):
|
| 85 |
+
test_loss = test_one_epoch(args.device, model, test_loader)
|
| 86 |
+
textio.cprint('Validation Loss: %f'%(test_loss))
|
| 87 |
+
|
| 88 |
+
def train_one_epoch(device, model, train_loader, optimizer):
|
| 89 |
+
model.train()
|
| 90 |
+
train_loss = 0.0
|
| 91 |
+
pred = 0.0
|
| 92 |
+
count = 0
|
| 93 |
+
for i, data in enumerate(tqdm(train_loader)):
|
| 94 |
+
template, source, igt = data
|
| 95 |
+
|
| 96 |
+
template = template.to(device)
|
| 97 |
+
source = source.to(device)
|
| 98 |
+
igt = igt.to(device)
|
| 99 |
+
|
| 100 |
+
output = model(template, source)
|
| 101 |
+
|
| 102 |
+
eye = torch.eye(4).expand_as(igt).to(igt.device)
|
| 103 |
+
mse1 = F.mse_loss(output['est_T_inverse'] @ torch.inverse(igt), eye)
|
| 104 |
+
mse2 = F.mse_loss(output['est_T'] @ igt, eye)
|
| 105 |
+
loss = mse1 + mse2
|
| 106 |
+
|
| 107 |
+
# forward + backward + optimize
|
| 108 |
+
optimizer.zero_grad()
|
| 109 |
+
loss_val.backward()
|
| 110 |
+
optimizer.step()
|
| 111 |
+
|
| 112 |
+
train_loss += loss_val.item()
|
| 113 |
+
count += 1
|
| 114 |
+
|
| 115 |
+
train_loss = float(train_loss)/count
|
| 116 |
+
return train_loss
|
| 117 |
+
|
| 118 |
+
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
|
| 119 |
+
learnable_params = filter(lambda p: p.requires_grad, model.parameters())
|
| 120 |
+
if args.optimizer == 'Adam':
|
| 121 |
+
optimizer = torch.optim.Adam(learnable_params)
|
| 122 |
+
else:
|
| 123 |
+
optimizer = torch.optim.SGD(learnable_params, lr=0.1)
|
| 124 |
+
|
| 125 |
+
if checkpoint is not None:
|
| 126 |
+
min_loss = checkpoint['min_loss']
|
| 127 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 128 |
+
|
| 129 |
+
best_test_loss = np.inf
|
| 130 |
+
|
| 131 |
+
for epoch in range(args.start_epoch, args.epochs):
|
| 132 |
+
train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
|
| 133 |
+
test_loss = test_one_epoch(args.device, model, test_loader)
|
| 134 |
+
|
| 135 |
+
if test_loss<best_test_loss:
|
| 136 |
+
best_test_loss = test_loss
|
| 137 |
+
snap = {'epoch': epoch + 1,
|
| 138 |
+
'model': model.state_dict(),
|
| 139 |
+
'min_loss': best_test_loss,
|
| 140 |
+
'optimizer' : optimizer.state_dict(),}
|
| 141 |
+
torch.save(snap, 'checkpoints/%s/models/best_model_snap.t7' % (args.exp_name))
|
| 142 |
+
torch.save(model.state_dict(), 'checkpoints/%s/models/best_model.t7' % (args.exp_name))
|
| 143 |
+
torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/best_ptnet_model.t7' % (args.exp_name))
|
| 144 |
+
|
| 145 |
+
torch.save(snap, 'checkpoints/%s/models/model_snap.t7' % (args.exp_name))
|
| 146 |
+
torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % (args.exp_name))
|
| 147 |
+
torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/ptnet_model.t7' % (args.exp_name))
|
| 148 |
+
|
| 149 |
+
boardio.add_scalar('Train Loss', train_loss, epoch+1)
|
| 150 |
+
boardio.add_scalar('Test Loss', test_loss, epoch+1)
|
| 151 |
+
boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)
|
| 152 |
+
|
| 153 |
+
textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
|
| 154 |
+
|
| 155 |
+
def options():
|
| 156 |
+
parser = argparse.ArgumentParser(description='Point Cloud Registration')
|
| 157 |
+
parser.add_argument('--exp_name', type=str, default='exp_deepgmr', metavar='N',
|
| 158 |
+
help='Name of the experiment')
|
| 159 |
+
parser.add_argument('--dataset_path', type=str, default='ModelNet40',
|
| 160 |
+
metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
|
| 161 |
+
parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
|
| 162 |
+
|
| 163 |
+
# settings for input data
|
| 164 |
+
parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
|
| 165 |
+
metavar='DATASET', help='dataset type (default: modelnet)')
|
| 166 |
+
parser.add_argument('--num_points', default=1024, type=int,
|
| 167 |
+
metavar='N', help='points in point-cloud (default: 1024)')
|
| 168 |
+
|
| 169 |
+
parser.add_argument('--nearest_neighbors', default=20, type=int,
|
| 170 |
+
metavar='K', help='No of nearest neighbors to be estimated.')
|
| 171 |
+
parser.add_argument('--use_rri', default=True, type=bool,
|
| 172 |
+
help='Find nearest neighbors to estimate features from PointNet.')
|
| 173 |
+
|
| 174 |
+
# settings for on training
|
| 175 |
+
parser.add_argument('-j', '--workers', default=4, type=int,
|
| 176 |
+
metavar='N', help='number of data loading workers (default: 4)')
|
| 177 |
+
parser.add_argument('-b', '--batch_size', default=32, type=int,
|
| 178 |
+
metavar='N', help='mini-batch size (default: 32)')
|
| 179 |
+
parser.add_argument('--pretrained', default='', type=str,
|
| 180 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 181 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 182 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 183 |
+
parser.add_argument('--epochs', default=200, type=int,
|
| 184 |
+
metavar='N', help='number of total epochs to run')
|
| 185 |
+
parser.add_argument('--start_epoch', default=0, type=int,
|
| 186 |
+
metavar='N', help='manual epoch number (useful on restarts)')
|
| 187 |
+
parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
|
| 188 |
+
metavar='METHOD', help='name of an optimizer (default: Adam)')
|
| 189 |
+
parser.add_argument('--resume', default='', type=str,
|
| 190 |
+
metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
|
| 191 |
+
parser.add_argument('--pretrained', default='', type=str,
|
| 192 |
+
metavar='PATH', help='path to pretrained model file (default: null (no-use))')
|
| 193 |
+
parser.add_argument('--device', default='cuda:0', type=str,
|
| 194 |
+
metavar='DEVICE', help='use CUDA if available')
|
| 195 |
+
|
| 196 |
+
args = parser.parse_args()
|
| 197 |
+
if args.nearest_neighbors > 0:
|
| 198 |
+
args.use_rri = True
|
| 199 |
+
return args
|
| 200 |
+
|
| 201 |
+
def main():
|
| 202 |
+
args = options()
|
| 203 |
+
torch.backends.cudnn.deterministic = True
|
| 204 |
+
torch.manual_seed(args.seed)
|
| 205 |
+
torch.cuda.manual_seed_all(args.seed)
|
| 206 |
+
np.random.seed(args.seed)
|
| 207 |
+
|
| 208 |
+
boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
|
| 209 |
+
_init_(args)
|
| 210 |
+
|
| 211 |
+
textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
|
| 212 |
+
textio.cprint(str(args))
|
| 213 |
+
|
| 214 |
+
trainset = RegistrationData('DeepGMR', ModelNet40Data(train=True), additional_params={'nearest_neighbors': args.nearest_neighbors})
|
| 215 |
+
testset = RegistrationData('DeepGMR', ModelNet40Data(train=False), additional_params={'nearest_neighbors': args.nearest_neighbors})
|
| 216 |
+
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
|
| 217 |
+
test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
|
| 218 |
+
|
| 219 |
+
if not torch.cuda.is_available():
|
| 220 |
+
args.device = 'cpu'
|
| 221 |
+
args.device = torch.device(args.device)
|
| 222 |
+
|
| 223 |
+
model = DeepGMR(use_rri=args.use_rri, nearest_neighbors=args.nearest_neighbors)
|
| 224 |
+
model = model.to(args.device)
|
| 225 |
+
|
| 226 |
+
checkpoint = None
|
| 227 |
+
if args.resume:
|
| 228 |
+
assert os.path.isfile(args.resume)
|
| 229 |
+
checkpoint = torch.load(args.resume)
|
| 230 |
+
args.start_epoch = checkpoint['epoch']
|
| 231 |
+
model.load_state_dict(checkpoint['model'])
|
| 232 |
+
|
| 233 |
+
if args.pretrained:
|
| 234 |
+
assert os.path.isfile(args.pretrained)
|
| 235 |
+
model.load_state_dict(torch.load(args.pretrained), strict=False)
|
| 236 |
+
model.to(args.device)
|
| 237 |
+
|
| 238 |
+
if args.eval:
|
| 239 |
+
test(args, model, test_loader, textio)
|
| 240 |
+
else:
|
| 241 |
+
train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
|
| 242 |
+
|
| 243 |
+
if __name__ == '__main__':
|
| 244 |
+
main()
|