FangSen9000 commited on
Commit
7f585cf
·
verified ·
1 Parent(s): 735ee2e

Upload EMS-superquadric_fitting_inference

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +10 -0
  2. EMS-superquadric_fitting_inference/LICENSE +21 -0
  3. EMS-superquadric_fitting_inference/README.md +37 -0
  4. EMS-superquadric_fitting_inference/README_MOVI.md +188 -0
  5. EMS-superquadric_fitting_inference/__pycache__/process_movi_validation.cpython-311.pyc +0 -0
  6. EMS-superquadric_fitting_inference/__pycache__/process_movi_validation.cpython-312.pyc +0 -0
  7. EMS-superquadric_fitting_inference/download_movi_a.py +462 -0
  8. EMS-superquadric_fitting_inference/process_movi_train.py +886 -0
  9. EMS-superquadric_fitting_inference/process_movi_validation.py +886 -0
  10. EMS-superquadric_fitting_inference/process_viser_hierarchical.py +486 -0
  11. EMS-superquadric_fitting_inference/process_viser_single.py +263 -0
  12. EMS-superquadric_fitting_inference/pyproject.toml +3 -0
  13. EMS-superquadric_fitting_inference/setup.py +32 -0
  14. EMS-superquadric_fitting_inference/src/EMS/EMS_recovery.py +378 -0
  15. EMS-superquadric_fitting_inference/src/EMS/__init__.py +0 -0
  16. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.BoundVolume-279.py311.1.nbc +0 -0
  17. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.BoundVolume-279.py311.nbi +0 -0
  18. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.BoundVolume-279.py312.1.nbc +0 -0
  19. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.BoundVolume-279.py312.nbi +0 -0
  20. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py311.1.nbc +3 -0
  21. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py311.nbi +0 -0
  22. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py312.1.nbc +3 -0
  23. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py312.nbi +0 -0
  24. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py311.1.nbc +3 -0
  25. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py311.nbi +0 -0
  26. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py312.1.nbc +3 -0
  27. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py312.nbi +0 -0
  28. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py311.1.nbc +3 -0
  29. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py311.nbi +0 -0
  30. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py312.1.nbc +3 -0
  31. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py312.nbi +0 -0
  32. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Euler2RotM-339.py311.1.nbc +0 -0
  33. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Euler2RotM-339.py311.nbi +0 -0
  34. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Euler2RotM-339.py312.1.nbc +0 -0
  35. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Euler2RotM-339.py312.nbi +0 -0
  36. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.OutlierProb-316.py311.1.nbc +0 -0
  37. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.OutlierProb-316.py311.nbi +0 -0
  38. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.OutlierProb-316.py312.1.nbc +0 -0
  39. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.OutlierProb-316.py312.nbi +0 -0
  40. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.RotM2Euler-363.py311.1.nbc +0 -0
  41. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.RotM2Euler-363.py311.nbi +0 -0
  42. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.RotM2Euler-363.py312.1.nbc +0 -0
  43. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.RotM2Euler-363.py312.nbi +0 -0
  44. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py311.1.nbc +3 -0
  45. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py311.nbi +0 -0
  46. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py312.1.nbc +3 -0
  47. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py312.nbi +0 -0
  48. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SurfaceArea-324.py311.1.nbc +0 -0
  49. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SurfaceArea-324.py311.nbi +0 -0
  50. EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SurfaceArea-324.py312.1.nbc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py312.1.nbc filter=lfs diff=lfs merge=lfs -text
37
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py311.1.nbc filter=lfs diff=lfs merge=lfs -text
38
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py312.1.nbc filter=lfs diff=lfs merge=lfs -text
39
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SwitchCost-265.py311.1.nbc filter=lfs diff=lfs merge=lfs -text
40
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py311.1.nbc filter=lfs diff=lfs merge=lfs -text
41
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py312.1.nbc filter=lfs diff=lfs merge=lfs -text
42
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py311.1.nbc filter=lfs diff=lfs merge=lfs -text
43
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SwitchCost-265.py312.1.nbc filter=lfs diff=lfs merge=lfs -text
44
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py312.1.nbc filter=lfs diff=lfs merge=lfs -text
45
+ EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py311.1.nbc filter=lfs diff=lfs merge=lfs -text
EMS-superquadric_fitting_inference/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Weixiao Liu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
EMS-superquadric_fitting_inference/README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Guidelines for Python Implementation
2
+
3
+ This is the guideline and structural explanation of the Python implementation of the EMS algorithm.
4
+
5
+ ### Dependency
6
+
7
+ The code is tested under Python 3.8.8, but should have few compatibility concerns.
8
+ The following packages are required to run the EMS algorithm:
9
+
10
+ 1. numpy 1.19.2
11
+ 2. scipy 1.5.2
12
+ 3. numba 0.53.1 -- for acceleration based on JIT (Just-In-Time compiler)
13
+
14
+
15
+ For demo, the following packages are needed:
16
+
17
+ 1. plyfile -- for loading `.ply` point cloud files
18
+ 2. mayavi -- for visualization of meshes and point clouds
19
+
20
+ ### Installation
21
+
22
+ We recommend installing the EMS package with `pip`.
23
+
24
+ 1. Change directory to `/Python`
25
+ 2. Install package: `pip install .`
26
+
27
+
28
+ ### Run Demo
29
+
30
+ The demo script is `/Python/tests/test_script.py`.
31
+ The demo reads a `.ply` point cloud and returns the parameters of the recovered superquadric, runtime, and visualization as required.
32
+
33
+ For example, in terminal type in
34
+
35
+ python test_script.py *.ply file location* --result --runtime --visualize
36
+
37
+ Note the first run of the code takes longer, since the JIT will translate the Python and NumPy code into fast machine code (and will be cached for futher calls).
EMS-superquadric_fitting_inference/README_MOVI.md ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MOVi-A Dataset for Superquadric Fitting
2
+
3
+ This directory contains scripts to download and process the MOVi-A dataset for superquadric fitting experiments.
4
+
5
+ ● 完成了MOVi-A数据集下载脚本!
6
+
7
+ 我在 EMS-superquadric_fitting_inference
8
+ 目录下创建了:
9
+
10
+ 1. download_movi_a.py - 主下载脚本
11
+
12
+ - 交互式选择下载数量(10/100/全部样本)
13
+ - 自动提取点云(从深度图)
14
+ - 保存为numpy格式,方便后续使用
15
+ - 分离每个物体的点云
16
+
17
+ 2. 数据组织结构
18
+
19
+ data/
20
+ └── movi_a_128x128_train/
21
+ ├── sample_00000/
22
+ │ ├── metadata.json #
23
+ 物体形状、材质等信息
24
+ │ ├── rgb/ # RGB图像
25
+ │ ├── depth/ # 深度图
26
+ │ ├── segmentation/ # 实例分割
27
+ │ ├── point_clouds/ # 提取的3D点云
28
+ │ └── trajectories.npz # 运动轨迹
29
+ └── ...
30
+
31
+ 3. 加载器脚本 (自动生成)
32
+
33
+ 下载完成后会生成
34
+ load_movi_a.py,无需TensorFlow即可加载数据:
35
+ loader = MOViALoader('data/movi_a_128x128_train')
36
+ frame = loader.load_frame(0, 0) #
37
+ 加载第0个样本的第0帧
38
+
39
+ ## Dataset Overview
40
+
41
+ MOVi-A is a synthetic dataset with:
42
+ - Simple geometric shapes (cube, sphere, cylinder) - perfect for superquadric fitting
43
+ - 3-10 objects per scene with physics simulation
44
+ - 128x128 resolution (smaller, faster to process)
45
+ - Rich annotations including depth, segmentation, 3D trajectories, physics properties, and collision events
46
+
47
+ ## Quick Start
48
+
49
+ 1. **Install minimal dependencies** (in WaveGen environment):
50
+ ```bash
51
+ conda create -n movi python=3.9
52
+ conda activate movi
53
+ pip install tensorflow-cpu tensorflow-datasets opencv-python tqdm
54
+ ```
55
+
56
+ 2. **Download the dataset**:
57
+ ```bash
58
+ python download_movi_a.py
59
+ ```
60
+
61
+ Press Enter to download ALL samples, or specify a number for testing.
62
+
63
+ 3. **After download, you can uninstall TensorFlow** if desired:
64
+ ```bash
65
+ pip uninstall tensorflow tensorflow-datasets
66
+ ```
67
+
68
+ ## Data Structure
69
+
70
+ After downloading, the data is organized as:
71
+
72
+ ```
73
+ data/
74
+ └── movi_a_128x128_train/
75
+ ├── dataset_info.json # Overall dataset metadata
76
+ ├── sample_00000/ # Each sample
77
+ │ ├── metadata.json # Sample metadata + physics properties
78
+ │ ├── rgb/ # RGB frames (PNG)
79
+ │ │ ├── frame_000.png
80
+ │ │ └── ...
81
+ │ ├── depth/ # Depth maps (NPY)
82
+ │ │ ├── frame_000.npy
83
+ │ │ └── ...
84
+ │ ├── segmentation/ # Instance masks (NPY)
85
+ │ │ ├── frame_000.npy
86
+ │ │ └── ...
87
+ │ ├── normal/ # Surface normals (NPY)
88
+ │ │ ├── frame_000.npy
89
+ │ │ └── ...
90
+ │ ├── object_coordinates/ # Object-relative coordinates (NPY)
91
+ │ │ ├── frame_000.npy
92
+ │ │ └── ...
93
+ │ ├── point_clouds/ # Extracted 3D points
94
+ │ │ ├── frame_000_full.npy
95
+ │ │ ├── frame_000_instance_1.npy
96
+ │ │ └── ...
97
+ │ ├── trajectories.npz # Full object motion data
98
+ │ ├── camera_trajectory.npz # Camera positions and rotations
99
+ │ └── collisions.npz # Collision events data
100
+ └── ...
101
+ ```
102
+
103
+ ## Loading Data (No TensorFlow Required!)
104
+
105
+ ```python
106
+ from load_movi_a import MOViALoader
107
+
108
+ # Initialize loader
109
+ loader = MOViALoader('data/movi_a_128x128_train')
110
+
111
+ # Load sample metadata
112
+ sample = loader.load_sample(0)
113
+ print(f"Objects: {sample['metadata']['num_instances']}")
114
+
115
+ # Load frame data
116
+ frame = loader.load_frame(sample_idx=0, frame_idx=0)
117
+ rgb = frame['rgb'] # (128, 128, 3)
118
+ depth = frame['depth'] # (128, 128, 1)
119
+ points = frame['point_cloud'] # (N, 3)
120
+
121
+ # Load instance-specific point cloud
122
+ instance_pc = loader.load_instance_point_cloud(0, 0, instance_id=1)
123
+ ```
124
+
125
+ ## Using with Superquadric Fitting
126
+
127
+ ```python
128
+ from EMS.EMS_recovery import EMS_recovery
129
+ from load_movi_a import MOViALoader
130
+
131
+ loader = MOViALoader('data/movi_a_128x128_train')
132
+
133
+ # Fit superquadric to first object in first frame
134
+ instance_pc = loader.load_instance_point_cloud(0, 0, 1)
135
+ if instance_pc is not None and len(instance_pc) > 100:
136
+ sq, p = EMS_recovery(instance_pc, OutlierRatio=0.2)
137
+ print(f"Shape parameters: {sq.shape}")
138
+ ```
139
+
140
+ ## Object Properties
141
+
142
+ Each object has (stored in `metadata.json`):
143
+ - **Shape**: cube, sphere, or cylinder
144
+ - **Size**: small or large
145
+ - **Material**: metal or rubber
146
+ - **Color**: 8 different colors + RGB values
147
+ - **Physics properties**:
148
+ - mass: Mass of the object
149
+ - friction: Friction coefficient (metal=0.4, rubber=0.8)
150
+ - restitution: Bounciness (metal=0.3, rubber=0.7)
151
+
152
+ ## Additional Data
153
+
154
+ ### Trajectories (`trajectories.npz`)
155
+ - positions: (num_objects, 24, 3) - 3D positions
156
+ - quaternions: (num_objects, 24, 4) - Rotations
157
+ - velocities: (num_objects, 24, 3) - Linear velocities
158
+ - angular_velocities: (num_objects, 24, 3) - Angular velocities
159
+ - visibility: (num_objects, 24) - Pixel count visibility
160
+ - bboxes_3d: (num_objects, 24, 8, 3) - 3D bounding box corners
161
+ - image_positions: (num_objects, 24, 2) - 2D center of mass
162
+
163
+ ### Camera (`camera_trajectory.npz`)
164
+ - positions: (24, 3) - Camera positions (static in MOVi-A)
165
+ - quaternions: (24, 4) - Camera rotations
166
+
167
+ ### Collisions (`collisions.npz`)
168
+ - instances: (N, 2) - Pairs of colliding objects
169
+ - frame: (N,) - Frame of collision
170
+ - force: (N,) - Collision force
171
+ - position: (N, 3) - 3D collision position
172
+ - image_position: (N, 2) - 2D collision position
173
+ - contact_normal: (N, 3) - Collision normal vector
174
+
175
+ ## Tips
176
+
177
+ 1. Start with 10 samples to test your pipeline
178
+ 2. The point clouds are already extracted from depth maps
179
+ 3. Instance segmentation helps separate objects
180
+ 4. Use trajectories.npz for temporal consistency
181
+
182
+ ## Storage Requirements (with all data)
183
+
184
+ - 10 samples: ~250 MB
185
+ - 100 samples: ~2.5 GB
186
+ - Full training set (9750 samples): ~25 GB
187
+ - Validation set (250 samples): ~650 MB
188
+ - **Total (train + validation): ~26 GB**
EMS-superquadric_fitting_inference/__pycache__/process_movi_validation.cpython-311.pyc ADDED
Binary file (37.3 kB). View file
 
EMS-superquadric_fitting_inference/__pycache__/process_movi_validation.cpython-312.pyc ADDED
Binary file (33.2 kB). View file
 
EMS-superquadric_fitting_inference/download_movi_a.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple MOVi-A dataset loader using the Kubric example code.
4
+ This version works with the MOVi datasets hosted on Google Cloud.
5
+ """
6
+
7
+ import os
8
+ import numpy as np
9
+ import tensorflow as tf
10
+ import tensorflow_datasets as tfds
11
+ from pathlib import Path
12
+ import json
13
+ import cv2
14
+ from tqdm import tqdm
15
+
16
+ # Reduce TF logging
17
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
18
+
19
+
20
+ def extract_point_cloud_from_depth(depth, camera_K, segmentation=None, instance_id=None):
21
+ """
22
+ Convert depth image to 3D point cloud using camera intrinsics.
23
+ """
24
+ H, W = depth.shape[:2]
25
+
26
+ # Create pixel coordinates
27
+ xx, yy = np.meshgrid(np.arange(W), np.arange(H))
28
+
29
+ # Get valid depth values
30
+ if instance_id is not None and segmentation is not None:
31
+ mask = (segmentation == instance_id) & (depth > 0)
32
+ else:
33
+ mask = depth > 0
34
+
35
+ # Extract valid coordinates
36
+ valid_x = xx[mask]
37
+ valid_y = yy[mask]
38
+ valid_z = depth[mask]
39
+
40
+ # Unproject to 3D using camera intrinsics
41
+ # K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
42
+ fx, fy = camera_K[0, 0], camera_K[1, 1]
43
+ cx, cy = camera_K[0, 2], camera_K[1, 2]
44
+
45
+ x_3d = (valid_x - cx) * valid_z / fx
46
+ y_3d = (valid_y - cy) * valid_z / fy
47
+ z_3d = valid_z
48
+
49
+ return np.stack([x_3d, y_3d, z_3d], axis=-1)
50
+
51
+
52
+ def process_and_save_sample(sample, output_dir, sample_idx):
53
+ """Process a single MOVi sample and save to disk."""
54
+ sample_dir = Path(output_dir) / f"sample_{sample_idx:05d}"
55
+
56
+ # Check if sample already fully processed
57
+ if sample_dir.exists():
58
+ # Check if key files exist to determine if download was complete
59
+ required_files = [
60
+ sample_dir / "metadata.json",
61
+ sample_dir / "trajectories.npz",
62
+ sample_dir / "camera_trajectory.npz"
63
+ ]
64
+ if all(f.exists() for f in required_files):
65
+ # Also check if all frames are downloaded
66
+ num_frames = 24 # MOVi-A has 24 frames
67
+ frame_files_exist = all(
68
+ (sample_dir / "rgb" / f"frame_{i:03d}.png").exists()
69
+ for i in range(num_frames)
70
+ )
71
+ if frame_files_exist:
72
+ print(f" Sample {sample_idx:05d} already downloaded, skipping...")
73
+ return None
74
+
75
+ sample_dir.mkdir(parents=True, exist_ok=True)
76
+
77
+ # Decode depth values
78
+ minv, maxv = sample["metadata"]["depth_range"]
79
+ depth = sample["depth"] / 65535 * (maxv - minv) + minv
80
+
81
+ # Get camera info
82
+ focal_length = float(sample["camera"]["focal_length"])
83
+ sensor_width = float(sample["camera"]["sensor_width"])
84
+ field_of_view = float(sample["camera"]["field_of_view"])
85
+ resolution = sample["video"].shape[1] # Assuming square
86
+
87
+ # Compute camera intrinsics
88
+ fx = fy = focal_length * resolution / sensor_width
89
+ cx = cy = resolution / 2
90
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
91
+
92
+ # Extract metadata
93
+ metadata = {
94
+ "num_frames": int(sample["metadata"]["num_frames"]),
95
+ "num_instances": int(sample["metadata"]["num_instances"]),
96
+ "resolution": resolution,
97
+ "depth_range": [float(minv), float(maxv)],
98
+ "camera": {
99
+ "focal_length": focal_length,
100
+ "sensor_width": sensor_width,
101
+ "field_of_view": field_of_view,
102
+ "K": K.tolist()
103
+ },
104
+ "instances": []
105
+ }
106
+
107
+ # Process instance information
108
+ for i in range(metadata["num_instances"]):
109
+ # Handle both string and integer labels
110
+ shape_label = sample["instances"]["shape_label"][i]
111
+ size_label = sample["instances"]["size_label"][i]
112
+ color_label = sample["instances"]["color_label"][i]
113
+ material_label = sample["instances"]["material_label"][i]
114
+
115
+ # Decode if bytes, otherwise convert to string
116
+ if hasattr(shape_label, 'decode'):
117
+ shape = shape_label.decode()
118
+ else:
119
+ # Map integer labels to names
120
+ shape_names = ["cube", "cylinder", "sphere"]
121
+ shape = shape_names[int(shape_label)] if int(shape_label) < len(shape_names) else str(shape_label)
122
+
123
+ if hasattr(size_label, 'decode'):
124
+ size = size_label.decode()
125
+ else:
126
+ size_names = ["small", "large"]
127
+ size = size_names[int(size_label)] if int(size_label) < len(size_names) else str(size_label)
128
+
129
+ if hasattr(color_label, 'decode'):
130
+ color = color_label.decode()
131
+ else:
132
+ color_names = ["blue", "brown", "cyan", "gray", "green", "purple", "red", "yellow"]
133
+ color = color_names[int(color_label)] if int(color_label) < len(color_names) else str(color_label)
134
+
135
+ if hasattr(material_label, 'decode'):
136
+ material = material_label.decode()
137
+ else:
138
+ material_names = ["metal", "rubber"]
139
+ material = material_names[int(material_label)] if int(material_label) < len(material_names) else str(material_label)
140
+
141
+ # Extract physics properties
142
+ mass = float(sample["instances"]["mass"][i])
143
+ friction = float(sample["instances"]["friction"][i])
144
+ restitution = float(sample["instances"]["restitution"][i])
145
+
146
+ # Extract color RGB values
147
+ color_rgb = sample["instances"]["color"][i].tolist()
148
+
149
+ instance_info = {
150
+ "id": i + 1, # 1-indexed in segmentation
151
+ "shape": shape,
152
+ "size": size,
153
+ "color": color,
154
+ "color_rgb": color_rgb,
155
+ "material": material,
156
+ "mass": mass,
157
+ "friction": friction,
158
+ "restitution": restitution
159
+ }
160
+ metadata["instances"].append(instance_info)
161
+
162
+ # Save metadata
163
+ with open(sample_dir / "metadata.json", 'w') as f:
164
+ json.dump(metadata, f, indent=2)
165
+
166
+ # Save trajectories with all motion data
167
+ np.savez_compressed(
168
+ sample_dir / "trajectories.npz",
169
+ positions=sample["instances"]["positions"],
170
+ quaternions=sample["instances"]["quaternions"],
171
+ velocities=sample["instances"]["velocities"],
172
+ angular_velocities=sample["instances"]["angular_velocities"],
173
+ visibility=sample["instances"]["visibility"],
174
+ bboxes_3d=sample["instances"]["bboxes_3d"],
175
+ image_positions=sample["instances"]["image_positions"]
176
+ )
177
+
178
+ # Save camera trajectory
179
+ np.savez_compressed(
180
+ sample_dir / "camera_trajectory.npz",
181
+ positions=sample["camera"]["positions"],
182
+ quaternions=sample["camera"]["quaternions"]
183
+ )
184
+
185
+ # Save collision events
186
+ if "events" in sample and "collisions" in sample["events"]:
187
+ collisions = sample["events"]["collisions"]
188
+ collision_data = {
189
+ "instances": collisions["instances"],
190
+ "frame": collisions["frame"],
191
+ "force": collisions["force"],
192
+ "position": collisions["position"],
193
+ "image_position": collisions["image_position"],
194
+ "contact_normal": collisions["contact_normal"]
195
+ }
196
+ np.savez_compressed(sample_dir / "collisions.npz", **collision_data)
197
+
198
+ # Process and save frames
199
+ (sample_dir / "rgb").mkdir(exist_ok=True)
200
+ (sample_dir / "depth").mkdir(exist_ok=True)
201
+ (sample_dir / "segmentation").mkdir(exist_ok=True)
202
+ (sample_dir / "normal").mkdir(exist_ok=True)
203
+ (sample_dir / "object_coordinates").mkdir(exist_ok=True)
204
+ (sample_dir / "point_clouds").mkdir(exist_ok=True)
205
+
206
+ for t in range(metadata["num_frames"]):
207
+ # Save RGB
208
+ rgb = sample["video"][t]
209
+ cv2.imwrite(str(sample_dir / "rgb" / f"frame_{t:03d}.png"),
210
+ cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
211
+
212
+ # Save depth
213
+ np.save(sample_dir / "depth" / f"frame_{t:03d}.npy", depth[t])
214
+
215
+ # Save segmentation
216
+ seg = sample["segmentations"][t, :, :, 0]
217
+ np.save(sample_dir / "segmentation" / f"frame_{t:03d}.npy", seg)
218
+
219
+ # Save normal
220
+ normal = sample["normal"][t]
221
+ np.save(sample_dir / "normal" / f"frame_{t:03d}.npy", normal)
222
+
223
+ # Save object coordinates
224
+ obj_coords = sample["object_coordinates"][t]
225
+ np.save(sample_dir / "object_coordinates" / f"frame_{t:03d}.npy", obj_coords)
226
+
227
+ # Extract and save point clouds
228
+ # Full scene
229
+ pc_full = extract_point_cloud_from_depth(depth[t, :, :, 0], K)
230
+ np.save(sample_dir / "point_clouds" / f"frame_{t:03d}_full.npy", pc_full)
231
+
232
+ # Per-instance
233
+ for i in range(metadata["num_instances"]):
234
+ instance_id = i + 1
235
+ pc_instance = extract_point_cloud_from_depth(
236
+ depth[t, :, :, 0], K, seg, instance_id
237
+ )
238
+ if len(pc_instance) > 0:
239
+ np.save(sample_dir / "point_clouds" / f"frame_{t:03d}_obj{instance_id}.npy",
240
+ pc_instance)
241
+
242
+ return metadata
243
+
244
+
245
+ def main():
246
+ print("MOVi-A Dataset Downloader (Simple Version)")
247
+ print("=" * 50)
248
+
249
+ # Output directory
250
+ script_dir = Path(__file__).parent
251
+ output_base_dir = script_dir / ".." / "data" / "movi_a_128x128"
252
+ output_base_dir.mkdir(parents=True, exist_ok=True)
253
+
254
+ print(f"Output directory: {output_base_dir}")
255
+ print("\nAttempting to load MOVi-A from Google Cloud Storage...")
256
+
257
+ try:
258
+ # Download both train and validation splits
259
+ for split_name in ["train", "validation"]:
260
+ print(f"\n{'='*50}")
261
+ print(f"Processing {split_name.upper()} split...")
262
+ print(f"{'='*50}")
263
+
264
+ output_dir = output_base_dir / split_name
265
+ output_dir.mkdir(exist_ok=True)
266
+
267
+ # Check existing samples
268
+ existing_samples = len(list(output_dir.glob("sample_*")))
269
+ if existing_samples > 0:
270
+ print(f"Found {existing_samples} existing samples in {split_name} directory")
271
+
272
+ # Load split
273
+ ds = tfds.load(
274
+ "movi_a/128x128",
275
+ split=split_name,
276
+ data_dir="gs://kubric-public/tfds",
277
+ with_info=False
278
+ )
279
+
280
+ print(f"Successfully connected to MOVi-A {split_name} dataset!")
281
+ print(f"Processing ALL {split_name} samples (will skip existing)...")
282
+
283
+ # Process all samples
284
+ total_processed = 0
285
+ total_skipped = 0
286
+ for idx, sample in enumerate(tqdm(tfds.as_numpy(ds), desc=f"Processing {split_name}")):
287
+ metadata = process_and_save_sample(sample, output_dir, idx)
288
+ if metadata is None:
289
+ total_skipped += 1
290
+ else:
291
+ total_processed += 1
292
+
293
+ print(f"\nProcessed {split_name} split:")
294
+ print(f" - Downloaded: {total_processed} samples")
295
+ print(f" - Skipped (already exist): {total_skipped} samples")
296
+ print(f" - Total: {total_processed + total_skipped} samples")
297
+
298
+ # Save split info
299
+ dataset_info = {
300
+ "dataset": "movi_a",
301
+ "split": split_name,
302
+ "resolution": "128x128",
303
+ "num_samples": total_processed,
304
+ "fps": 12,
305
+ "num_frames_per_sample": 24
306
+ }
307
+ with open(output_dir / "dataset_info.json", 'w') as f:
308
+ json.dump(dataset_info, f, indent=2)
309
+
310
+ print(f"\n{'='*50}")
311
+ print(f"All downloads complete!")
312
+ print(f"Data saved to: {output_base_dir}")
313
+
314
+ # Count final samples
315
+ train_count = len(list((output_base_dir / "train").glob("sample_*")))
316
+ val_count = len(list((output_base_dir / "validation").glob("sample_*"))) if (output_base_dir / "validation").exists() else 0
317
+
318
+ print(f"\nFinal dataset size:")
319
+ print(f" - Train samples: {train_count} in {output_base_dir}/train")
320
+ print(f" - Validation samples: {val_count} in {output_base_dir}/validation")
321
+
322
+ # Create simple loader
323
+ create_loader_script(output_base_dir)
324
+
325
+ except KeyboardInterrupt:
326
+ print("\n\nDownload interrupted by user. You can run the script again to resume.")
327
+ return
328
+ except Exception as e:
329
+ print(f"\nError: {e}")
330
+ import traceback
331
+ traceback.print_exc()
332
+ print("\nNote: You can run the script again to resume downloading.")
333
+ print("\nAlternative: Download manually using gsutil")
334
+ print("1. Install: conda install -c conda-forge google-cloud-sdk")
335
+ print("2. Download a few samples manually:")
336
+ print(" gsutil -m cp -r gs://kubric-public/tfds/movi_a/128x128/1.0.0/movi_a-train.tfrecord-00000-of-00256 ./")
337
+ print("\nOr try the original Kubric repository:")
338
+ print(" https://github.com/google-research/kubric")
339
+
340
+
341
+ def create_loader_script(output_dir):
342
+ """Create a simple loader script."""
343
+ script = '''#!/usr/bin/env python3
344
+ """Simple MOVi-A data loader - no TensorFlow required!"""
345
+
346
+ import numpy as np
347
+ import json
348
+ import cv2
349
+ from pathlib import Path
350
+
351
+
352
+ class MOViLoader:
353
+ def __init__(self, data_dir, split="train"):
354
+ self.data_dir = Path(data_dir)
355
+ self.split = split
356
+ self.split_dir = self.data_dir / split
357
+ self.samples = sorted(list(self.split_dir.glob("sample_*")))
358
+
359
+ def load_sample(self, idx):
360
+ """Load metadata and trajectories for a sample."""
361
+ sample_dir = self.samples[idx]
362
+
363
+ with open(sample_dir / "metadata.json", 'r') as f:
364
+ metadata = json.load(f)
365
+
366
+ trajectories = np.load(sample_dir / "trajectories.npz")
367
+
368
+ return {
369
+ "metadata": metadata,
370
+ "trajectories": trajectories,
371
+ "sample_dir": sample_dir,
372
+ "camera": None,
373
+ "collisions": None
374
+ }
375
+
376
+ # Load camera trajectory if exists
377
+ camera_path = sample_dir / "camera_trajectory.npz"
378
+ if camera_path.exists():
379
+ data["camera"] = np.load(camera_path)
380
+
381
+ # Load collision data if exists
382
+ collision_path = sample_dir / "collisions.npz"
383
+ if collision_path.exists():
384
+ data["collisions"] = np.load(collision_path)
385
+
386
+ return data
387
+
388
+ def load_frame(self, sample_idx, frame_idx):
389
+ """Load all data for a specific frame."""
390
+ sample_dir = self.samples[sample_idx]
391
+
392
+ # Load RGB
393
+ rgb = cv2.imread(str(sample_dir / "rgb" / f"frame_{frame_idx:03d}.png"))
394
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
395
+
396
+ # Load depth
397
+ depth = np.load(sample_dir / "depth" / f"frame_{frame_idx:03d}.npy")
398
+
399
+ # Load segmentation
400
+ seg = np.load(sample_dir / "segmentation" / f"frame_{frame_idx:03d}.npy")
401
+
402
+ # Load full point cloud
403
+ pc = np.load(sample_dir / "point_clouds" / f"frame_{frame_idx:03d}_full.npy")
404
+
405
+ return {
406
+ "rgb": rgb,
407
+ "depth": depth,
408
+ "segmentation": seg,
409
+ "point_cloud": pc,
410
+ "normal": None,
411
+ "object_coordinates": None
412
+ }
413
+
414
+ # Load normal if exists
415
+ normal_path = sample_dir / "normal" / f"frame_{frame_idx:03d}.npy"
416
+ if normal_path.exists():
417
+ frame_data["normal"] = np.load(normal_path)
418
+
419
+ # Load object coordinates if exists
420
+ obj_coord_path = sample_dir / "object_coordinates" / f"frame_{frame_idx:03d}.npy"
421
+ if obj_coord_path.exists():
422
+ frame_data["object_coordinates"] = np.load(obj_coord_path)
423
+
424
+ return frame_data
425
+
426
+ def load_object_points(self, sample_idx, frame_idx, object_id):
427
+ """Load point cloud for a specific object."""
428
+ sample_dir = self.samples[sample_idx]
429
+ pc_file = sample_dir / "point_clouds" / f"frame_{frame_idx:03d}_obj{object_id}.npy"
430
+
431
+ if pc_file.exists():
432
+ return np.load(pc_file)
433
+ return None
434
+
435
+
436
+ if __name__ == "__main__":
437
+ # Example usage
438
+ loader_train = MOViLoader(".", split="train")
439
+ loader_val = MOViLoader(".", split="validation")
440
+ print(f"Found {len(loader_train.samples)} training samples")
441
+ print(f"Found {len(loader_val.samples)} validation samples")
442
+
443
+ if len(loader_train.samples) > 0:
444
+ # Load first sample
445
+ sample = loader_train.load_sample(0)
446
+ print(f"\\nTrain sample 0: {sample['metadata']['num_instances']} objects")
447
+
448
+ # Load first frame
449
+ frame = loader_train.load_frame(0, 0)
450
+ print(f"Point cloud shape: {frame['point_cloud'].shape}")
451
+ '''
452
+
453
+ loader_path = output_dir.parent / "load_movi.py"
454
+ with open(loader_path, 'w') as f:
455
+ f.write(script)
456
+
457
+ print(f"\nCreated loader script: {loader_path}")
458
+
459
+
460
+ if __name__ == "__main__":
461
+ print("\nTip: This script supports resuming downloads. If interrupted, just run it again!\n")
462
+ main()
EMS-superquadric_fitting_inference/process_movi_train.py ADDED
@@ -0,0 +1,886 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Process MOVi-A train set with hierarchical multi-superquadric fitting
4
+ Converts depth maps to normalized point clouds for superquadric fitting
5
+ """
6
+
7
+ import numpy as np
8
+ import sys
9
+ import os
10
+ import time
11
+ import viser
12
+ import json
13
+ import cv2
14
+ from pathlib import Path
15
+ from sklearn.cluster import DBSCAN
16
+
17
+ # Add the src directory to Python path
18
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
19
+
20
+ from EMS.EMS_recovery import EMS_recovery
21
+
22
+
23
def depth_to_normalized_pointcloud(depth, segmentation, camera_K, camera_position=None, camera_quaternion=None, resolution=128, convert_to_zdepth=True):
    """
    Convert a depth map to per-instance point clouds, with the whole scene
    normalized into the cube [-10, 10]^3.

    Args:
        depth: (H, W, 1) depth array. MOVi stores euclidean distance from the
            camera center, not planar z-depth.
        segmentation: (H, W) integer instance mask (0 = background).
        camera_K: 3x3 camera intrinsic matrix.
        camera_position: optional camera position in world coordinates.
        camera_quaternion: optional camera quaternion in [x, y, z, w] order.
            Both position and quaternion must be given for the camera-to-world
            transform to be applied.
        resolution: image resolution (unused here; kept for caller
            compatibility).
        convert_to_zdepth: if True, convert euclidean depth to planar z-depth
            before back-projection.

    Returns:
        Tuple of:
            instance_pointclouds: dict instance_id -> (N, 3) normalized points
                (instances with fewer than 50 pixels are skipped),
            points_3d_normalized: (H, W, 3) full normalized scene point cloud,
            segmentation: the input mask, passed through for visualization,
            scene_center: (3,) center used for normalization,
            scene_extent: float, largest axis extent before scaling
                (1.0 when no valid-depth pixels exist).
    """
    H, W = depth.shape[:2]

    # Camera intrinsics
    fx = camera_K[0, 0]
    fy = camera_K[1, 1]
    cx = camera_K[0, 2]
    cy = camera_K[1, 2]

    # Pixel grid -> normalized camera rays
    xx, yy = np.meshgrid(np.arange(W), np.arange(H))
    x_norm = (xx - cx) / fx
    y_norm = (yy - cy) / fy

    if convert_to_zdepth:
        # MOVi depth is euclidean distance d.  With x = x_norm * z and
        # y = y_norm * z we have d^2 = (x_norm^2 + y_norm^2 + 1) * z^2,
        # so recover the planar z-depth:
        z = depth[:, :, 0] / np.sqrt(x_norm**2 + y_norm**2 + 1)
    else:
        # Input is assumed to already be planar z-depth.
        z = depth[:, :, 0]

    # Back-project to 3D camera coordinates
    x = x_norm * z
    y = y_norm * z
    points_3d_camera = np.stack([x, y, z], axis=-1)

    # Optionally transform camera -> world: p_world = R @ p_cam + t
    if camera_position is not None and camera_quaternion is not None:
        from scipy.spatial.transform import Rotation

        # MOVi quaternions are [x, y, z, w], matching scipy's default order.
        cam_rot_matrix = Rotation.from_quat(camera_quaternion).as_matrix()

        points_3d_flat = points_3d_camera.reshape(-1, 3)
        points_3d_world = points_3d_flat @ cam_rot_matrix.T + camera_position
        points_3d = points_3d_world.reshape(points_3d_camera.shape)
    else:
        points_3d = points_3d_camera

    # Normalize the whole scene to [-10, 10] using only valid-depth pixels,
    # so all instances share one consistent coordinate frame.
    # Defaults cover the degenerate "no valid points" case explicitly
    # (the previous implementation relied on a fragile `'x' in locals()` check).
    scene_center = np.zeros(3)
    scene_extent = 1.0
    valid_mask = z > 0
    valid_points = points_3d[valid_mask]

    if len(valid_points) > 0:
        scene_min = np.min(valid_points, axis=0)
        scene_max = np.max(valid_points, axis=0)
        scene_center = (scene_min + scene_max) / 2
        scene_extent = np.max(scene_max - scene_min)

        if scene_extent > 0:
            scale_factor = 20.0 / scene_extent  # target range is -10..10
            points_3d_normalized = (points_3d - scene_center) * scale_factor
        else:
            # All valid points coincide: just recenter.
            points_3d_normalized = points_3d - scene_center
    else:
        points_3d_normalized = points_3d

    # Split points per instance (background id 0 excluded).
    instance_ids = np.unique(segmentation)
    instance_ids = instance_ids[instance_ids > 0]

    instance_pointclouds = {}
    for inst_id in instance_ids:
        instance_points = points_3d_normalized[segmentation == inst_id]
        if len(instance_points) < 50:  # too few points for a reliable fit
            continue
        instance_pointclouds[int(inst_id)] = instance_points

    # Also return the full scene point cloud and segmentation for visualization.
    return instance_pointclouds, points_3d_normalized, segmentation, scene_center, scene_extent
129
+
130
+
131
def hierarchical_ems(
    point,
    OutlierRatio=0.5,
    MaxIterationEM=20,
    ToleranceEM=1e-3,
    RelativeToleranceEM=2e-1,
    MaxOptiIterations=2,
    Sigma=0.3,
    MaxiSwitch=2,
    AdaptiveUpperBound=True,
    Rescale=False,
    MaxLayer=3,
    Eps=1.0,  # Adjusted for normalized [-10, 10] point clouds
    MinPoints=50,
):
    """Extract multiple superquadrics from one point cloud, layer by layer.

    Each layer fits a superquadric to every pending segment with EMS_recovery;
    a fit is accepted when more than 30% of the segment are inliers.  Points
    the fit confidently rejects are clustered with DBSCAN and the clusters
    become the segments of the next layer.

    Returns:
        (list_quadrics, quadric_info): accepted superquadric parameters and a
        parallel list of dicts with layer/segment indices, inlier ratio,
        segment size and the inlier points themselves.
    """
    segments_per_layer = {layer: [] for layer in range(MaxLayer + 1)}
    outliers_per_layer = {layer: [] for layer in range(MaxLayer + 1)}  # kept for structural parity; not consumed
    segments_per_layer[0] = [point]

    accepted = []
    accepted_info = []

    for layer in range(MaxLayer):
        segments = segments_per_layer[layer]
        if not segments:
            break

        for seg_idx, seg_points in enumerate(segments):
            # Segments too small to both fit and spawn children are skipped.
            if len(seg_points) < MinPoints * 2:
                continue

            try:
                fitted, probs = EMS_recovery(
                    seg_points,
                    OutlierRatio,
                    MaxIterationEM,
                    ToleranceEM,
                    RelativeToleranceEM,
                    MaxOptiIterations,
                    Sigma,
                    MaxiSwitch,
                    AdaptiveUpperBound,
                    Rescale,
                )

                # A point counts as an inlier when its membership prob > 0.5.
                inliers = probs > 0.5
                inlier_ratio = np.sum(inliers) / len(probs)

                if inlier_ratio > 0.3:  # keep only fits explaining >= 30%
                    accepted.append(fitted)
                    accepted_info.append({
                        'layer': layer,
                        'segment': seg_idx,
                        'inlier_ratio': inlier_ratio,
                        'num_points': len(seg_points),
                        'inlier_points': seg_points[inliers]
                    })

                # Confident outliers (prob < 0.1) feed the next layer.
                rejected = seg_points[probs < 0.1]

                if len(rejected) > MinPoints * 2 and layer < MaxLayer - 1:
                    clustering = DBSCAN(eps=Eps, min_samples=MinPoints).fit(rejected)
                    cluster_ids = [lbl for lbl in set(clustering.labels_) if lbl >= 0]
                    for cid in cluster_ids:
                        members = rejected[clustering.labels_ == cid]
                        if len(members) > MinPoints:
                            segments_per_layer[layer + 1].append(members)

            except Exception:
                # Best-effort: a failed fit on one segment must not abort the rest.
                continue

    return accepted, accepted_info
213
+
214
+
215
def generate_superquadric_mesh(sq, num_samples=25):
    """Triangulate a superquadric surface.

    Samples the standard (eta, omega) parametrization on a
    num_samples x num_samples angle grid, applies the superquadric's scale,
    rotation and translation, and connects neighbouring samples with two
    triangles per grid quad.

    Args:
        sq: object exposing shape (2 exponents), scale (3,), RotM (3x3)
            and translation (3,).
        num_samples: grid resolution along each angle.

    Returns:
        (vertices, faces): (num_samples^2, 3) float array and
        (2*(num_samples-1)^2, 3) int index array.
    """
    eta = np.linspace(-np.pi / 2, np.pi / 2, num_samples)
    omega = np.linspace(-np.pi, np.pi, num_samples)

    def _signed_pow(v, e):
        # Standard superquadric exponentiation: sign(v) * |v|^e.
        return np.sign(v) * np.abs(v) ** e

    vertices = []
    for e_ang in eta:
        # eta-dependent terms are constant along the inner loop.
        ce = _signed_pow(np.cos(e_ang), sq.shape[0])
        se = _signed_pow(np.sin(e_ang), sq.shape[0])
        for w_ang in omega:
            cw = _signed_pow(np.cos(w_ang), sq.shape[1])
            sw = _signed_pow(np.sin(w_ang), sq.shape[1])

            # Point in the superquadric's local frame, then into world frame.
            local = np.array([
                sq.scale[0] * ce * cw,
                sq.scale[1] * ce * sw,
                sq.scale[2] * se,
            ])
            vertices.append(sq.RotM @ local + sq.translation)

    vertices = np.array(vertices)

    # Two triangles per grid quad; row-major vertex layout (eta outer).
    faces = []
    for i in range(num_samples - 1):
        row = i * num_samples
        next_row = (i + 1) * num_samples
        for j in range(num_samples - 1):
            a = row + j
            b = row + j + 1
            c = next_row + j
            d = next_row + j + 1
            faces.append([a, b, c])
            faces.append([b, d, c])

    return vertices, np.array(faces)
259
+
260
+
261
def preprocess_all_frames(samples_info):
    """Preprocess all frames for all samples.

    For every frame of every sample this loads depth/segmentation/RGB plus the
    per-frame camera pose, back-projects to a normalized scene point cloud via
    depth_to_normalized_pointcloud, and fits one superquadric per instance
    with EMS_recovery.  Any frame-level failure is reported inline and stored
    as None so the viewer can skip that frame.

    Args:
        samples_info: list of dicts with keys 'dir' (Path), 'name', 'metadata'
            and 'num_frames', as assembled in main().

    Returns:
        dict: sample_name -> {frame_idx -> frame-result dict or None}.
    """
    print("\n" + "="*60)
    print("PREPROCESSING ALL FRAMES")
    print("="*60)

    all_results = {}

    for sample_idx, sample in enumerate(samples_info):
        print(f"\nProcessing {sample['name']} ({sample_idx + 1}/{len(samples_info)})")
        sample_results = {}

        for frame_idx in range(sample['num_frames']):
            print(f" Frame {frame_idx}/{sample['num_frames']-1}", end='', flush=True)

            try:
                # Load depth, segmentation and RGB
                depth = np.load(sample['dir'] / "depth" / f"frame_{frame_idx:03d}.npy")
                segmentation = np.load(sample['dir'] / "segmentation" / f"frame_{frame_idx:03d}.npy")

                # Load RGB image (cv2 reads BGR; convert for display)
                # NOTE(review): cv2.imread returns None on a missing/corrupt file,
                # which makes cvtColor raise — caught by the outer except below.
                rgb_path = sample['dir'] / "rgb" / f"frame_{frame_idx:03d}.png"
                rgb_image = cv2.imread(str(rgb_path))
                rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)

                # Get camera intrinsics
                camera_K = np.array(sample['metadata']['camera']['K'])

                # Load this frame's camera pose from the per-sample trajectory
                camera_traj = np.load(sample['dir'] / "camera_trajectory.npz")
                camera_position = camera_traj['positions'][frame_idx]
                camera_quaternion = camera_traj['quaternions'][frame_idx]

                # Convert to normalized point clouds (world frame, [-10, 10])
                instance_pointclouds, scene_points, scene_seg, scene_center, scene_extent = depth_to_normalized_pointcloud(
                    depth, segmentation, camera_K,
                    camera_position=camera_position,
                    camera_quaternion=camera_quaternion,
                    convert_to_zdepth=True
                )

                # Fit a single superquadric to each instance's points
                instances = []
                for inst_id, points in instance_pointclouds.items():
                    # Assumes segmentation ids are 1-based indices into the
                    # metadata 'instances' list — TODO confirm against the
                    # dataset writer.
                    inst_info = sample['metadata']['instances'][inst_id - 1]

                    try:
                        # Fit superquadric
                        sq, p = EMS_recovery(
                            points,
                            OutlierRatio=0.13,
                            MaxIterationEM=20,
                            AdaptiveUpperBound=True,
                            Rescale=False
                        )

                        # Fraction of points with membership prob > 0.5
                        inlier_ratio = np.sum(p > 0.5) / len(p)

                        instances.append({
                            'id': inst_id,
                            'info': inst_info,
                            'points': points,
                            'quadric': sq,
                            'inlier_ratio': inlier_ratio,
                            'inlier_points': points[p > 0.5]
                        })
                    except Exception as e:
                        # A failed per-instance fit is skipped, not fatal.
                        print(f" [Failed instance {inst_id}: {str(e)[:30]}...]", end='')

                # Store frame result; scene_scale/scene_center let the viewer
                # map world-space quantities (e.g. camera pose) into the
                # normalized frame.
                sample_results[frame_idx] = {
                    'metadata': sample['metadata'],
                    'instances': instances,
                    'scene_points': scene_points,
                    'scene_seg': scene_seg,
                    'rgb_image': rgb_image,
                    'camera_position': camera_position,
                    'camera_quaternion': camera_quaternion,
                    'scene_scale': 20.0 / scene_extent if scene_extent > 0 else 1.0,
                    'scene_center': scene_center
                }
                print(" ✓", end='', flush=True)

            except Exception as e:
                # Frame-level failure: record None so playback can skip it.
                print(f" [Error: {str(e)}]", end='')
                sample_results[frame_idx] = None

        all_results[sample['name']] = sample_results
        print()

    print(f"\nPreprocessing complete! Processed {len(all_results)} samples")
    return all_results
353
+
354
+
355
def main():
    """Preprocess the MOVi-A train split and serve an interactive viser viewer.

    Pipeline: load up to 10 samples' metadata, preprocess every frame
    (point clouds + superquadric fits), then start a viser server with GUI
    controls for sample/frame selection, playback, per-instance toggles and
    camera-view matching.  Blocks until Ctrl+C.
    """
    # Load MOVi-A train data (hard-coded cluster path)
    data_dir = Path("/research/cbim/vast/sf895/code/WaveGen/WaveGen_v33_使用超二次元函数_Transformer/data/movi_a_128x128/train")

    if not data_dir.exists():
        print(f"Error: train data not found at {data_dir}")
        print("Please run download_movi_simple.py first to download the MOVi-A dataset")
        return

    # Get all train samples
    sample_dirs = sorted(list(data_dir.glob("sample_*")))
    print(f"Found {len(sample_dirs)} train samples")

    if len(sample_dirs) == 0:
        print("No train samples found!")
        return

    # Pre-load sample metadata
    samples_info = []
    print("\nLoading sample metadata...")
    for sample_dir in sample_dirs[:10]:  # Process first 10 samples
        with open(sample_dir / "metadata.json", 'r') as f:
            metadata = json.load(f)
        samples_info.append({
            'dir': sample_dir,
            'name': sample_dir.name,
            'metadata': metadata,
            'num_frames': metadata['num_frames']
        })

    print(f"Loaded metadata for {len(samples_info)} samples")

    # Preprocess all frames for all samples (up-front, so playback is smooth)
    all_preprocessed_results = preprocess_all_frames(samples_info)

    # Start viser visualization
    server = viser.ViserServer(port=8080)
    print(f"\n{'='*60}")
    print(f"Viser server started at http://localhost:8080")
    print("Open this URL in your browser to view the 3D visualization")
    print("Press Ctrl+C to stop the server")
    print('='*60)

    # Colors for different objects (keyed by MOVi-A shape name)
    object_colors = {
        'cube': (255, 0, 0),      # Red
        'sphere': (0, 255, 0),    # Green
        'cylinder': (0, 0, 255),  # Blue
    }

    # Colors for instances (cycled by segmentation id)
    instance_colors = [
        (255, 0, 0),    # Red
        (0, 255, 0),    # Green
        (0, 0, 255),    # Blue
        (255, 255, 0),  # Yellow
        (255, 0, 255),  # Magenta
        (0, 255, 255),  # Cyan
        (255, 128, 0),  # Orange
        (128, 0, 255),  # Purple
    ]

    # Create GUI
    with server.gui.add_folder("Controls"):
        # Sample selector
        sample_names = [s['name'] for s in samples_info]
        current_sample = server.gui.add_dropdown(
            "Select Sample",
            options=sample_names,
            initial_value=sample_names[0] if sample_names else None
        )

        # Frame selector
        frame_slider = server.gui.add_slider(
            "Frame",
            min=0,
            max=23,  # MOVi-A has 24 frames
            step=1,
            initial_value=0
        )

        # Playback controls
        with server.gui.add_folder("Playback Controls"):
            play_button = server.gui.add_button("Play ▶")
            pause_button = server.gui.add_button("Pause ⏸")
            fps_slider = server.gui.add_slider(
                "Playback FPS",
                min=1,
                max=24,
                step=1,
                initial_value=12
            )

        # Status display
        status_display = server.gui.add_markdown("**Status:** Ready")

        # Instance selector will be updated dynamically per frame
        instance_folder = server.gui.add_folder("Instances")

        # Visibility controls
        show_scene = server.gui.add_checkbox("Show Background Points", initial_value=True)
        show_points = server.gui.add_checkbox("Highlight only the identified instance points", initial_value=True)
        show_quadrics = server.gui.add_checkbox("Show Superquadrics", initial_value=True)
        show_labels = server.gui.add_checkbox("Show Labels", initial_value=False)
        show_camera = server.gui.add_checkbox("Show Camera", initial_value=False)
        use_rgb_colors = server.gui.add_checkbox("Show Point Colors", initial_value=True)

        # Camera view button
        match_camera_view = server.gui.add_button("Match Frame Camera View")

        # Visual parameters
        point_size = server.gui.add_slider(
            "Point Size",
            min=0.001,
            max=0.05,
            step=0.001,
            initial_value=0.01
        )

        mesh_opacity = server.gui.add_slider(
            "Mesh Opacity",
            min=0.0,
            max=1.0,
            step=0.1,
            initial_value=0.7
        )

        # Info display
        info_display = server.gui.add_markdown("**Sample Info:**\n\nSelect a sample to view")

    # Store current visualization handles and results; mutated in-place by the
    # closures below (dict access avoids the need for `nonlocal`).
    current_viz = {
        'scene_cloud': None,
        'points': {},
        'meshes': {},
        'labels': {},
        'camera_frustum': None,
        'camera_label': None,
        'instance_toggles': [],
        'current_result': None,
        'all_results': all_preprocessed_results,  # Store preprocessed results
        'is_playing': False
    }

    def load_frame():
        """Load the current frame from preprocessed results"""
        sample_name = current_sample.value
        frame_idx = int(frame_slider.value)

        if sample_name not in current_viz['all_results']:
            status_display.value = f"**Status:** Sample {sample_name} not found in preprocessed results"
            return

        sample_results = current_viz['all_results'][sample_name]
        if frame_idx not in sample_results or sample_results[frame_idx] is None:
            status_display.value = f"**Status:** Frame {frame_idx} not available"
            return

        # Get preprocessed result and tag it with its identity for display
        current_viz['current_result'] = sample_results[frame_idx]
        current_viz['current_result']['name'] = sample_name
        current_viz['current_result']['frame'] = frame_idx

        # Update visualization
        num_instances = len(current_viz['current_result']['instances'])
        status_display.value = f"**Status:** Loaded frame {frame_idx} - {num_instances} instances"
        update_scene()

    def update_scene():
        """Update the 3D scene based on current result"""
        # Clear existing visualization (scene handles must be removed
        # explicitly; viser does not garbage-collect them)
        if current_viz['scene_cloud'] is not None:
            current_viz['scene_cloud'].remove()
            current_viz['scene_cloud'] = None

        if current_viz['camera_frustum'] is not None:
            current_viz['camera_frustum'].remove()
            current_viz['camera_frustum'] = None

        if current_viz['camera_label'] is not None:
            current_viz['camera_label'].remove()
            current_viz['camera_label'] = None

        for pc in current_viz['points'].values():
            pc.remove()
        current_viz['points'] = {}

        for mesh in current_viz['meshes'].values():
            mesh.remove()
        current_viz['meshes'] = {}

        for label in current_viz['labels'].values():
            label.remove()
        current_viz['labels'] = {}

        # Clear instance toggles
        for toggle in current_viz['instance_toggles']:
            toggle.remove()
        current_viz['instance_toggles'] = []

        # Get current result
        selected = current_viz['current_result']

        if selected is None:
            info_display.value = "**Sample Info:**\n\nClick 'Process Current Frame' to start"
            return

        # Update info
        info_text = f"**{selected['name']} - Frame {selected['frame']}**\n\n"
        info_text += f"Total instances: {len(selected['instances'])}\n"

        # Show full scene point cloud if requested
        if show_scene.value:
            scene_points_flat = selected['scene_points'].reshape(-1, 3)
            scene_seg_flat = selected['scene_seg'].reshape(-1)

            # Filter out invalid points
            valid_mask = ~np.isnan(scene_points_flat).any(axis=1)
            scene_points_valid = scene_points_flat[valid_mask]
            scene_seg_valid = scene_seg_flat[valid_mask]

            if use_rgb_colors.value and 'rgb_image' in selected:
                # Use RGB colors from image (pixel grid aligns with the
                # back-projected point grid, so the same flat mask applies)
                rgb_flat = selected['rgb_image'].reshape(-1, 3)
                rgb_valid = rgb_flat[valid_mask]
                colors = rgb_valid.astype(np.uint8)
            else:
                # Use segmentation colors
                colors = np.zeros((len(scene_points_valid), 3), dtype=np.uint8)
                for i, seg_id in enumerate(scene_seg_valid):
                    if seg_id == 0:
                        colors[i] = [128, 128, 128]  # Gray for background
                    else:
                        colors[i] = instance_colors[(seg_id - 1) % len(instance_colors)]

            current_viz['scene_cloud'] = server.scene.add_point_cloud(
                "/scene_points",
                points=scene_points_valid,
                colors=colors,
                point_size=point_size.value,
            )
            info_text += f"Scene points shown: {len(scene_points_valid)}\n"

        info_text += "\n"

        # Show camera if requested
        if show_camera.value and 'camera_position' in selected:
            # Transform camera position into the normalized scene frame
            cam_pos = selected['camera_position']
            scale = selected.get('scene_scale', 1.0)
            center = selected.get('scene_center', np.zeros(3))
            cam_pos_normalized = (cam_pos - center) * scale

            # Get camera parameters from metadata
            focal_length = selected['metadata']['camera']['focal_length']
            sensor_width = selected['metadata']['camera']['sensor_width']
            resolution = selected['metadata']['resolution']

            # Calculate field of view
            fov = 2 * np.arctan(sensor_width / (2 * focal_length))

            # Get camera orientation
            cam_quat = selected['camera_quaternion']

            # Convert quaternion to wxyz format (viser uses w first)
            # MOVi quaternion is [x, y, z, w], viser needs [w, x, y, z]
            wxyz = np.array([cam_quat[3], cam_quat[0], cam_quat[1], cam_quat[2]])

            # Get the RGB image for the camera frustum
            if 'rgb_image' in selected:
                # Use original resolution for camera frustum display
                # MOVi-A is 128x128, which should display clearly
                small_rgb = selected['rgb_image']

                # Add camera frustum with image
                current_viz['camera_frustum'] = server.scene.add_camera_frustum(
                    "/camera_frustum",
                    fov=fov,
                    aspect=1.0,  # Square aspect ratio for MOVi
                    scale=2.0,   # Size of frustum visualization
                    wxyz=wxyz,
                    position=cam_pos_normalized,
                    image=small_rgb,
                )
            else:
                # Add camera frustum without image
                current_viz['camera_frustum'] = server.scene.add_camera_frustum(
                    "/camera_frustum",
                    fov=fov,
                    aspect=1.0,
                    scale=2.0,
                    wxyz=wxyz,
                    position=cam_pos_normalized,
                    color=(255, 255, 0),
                )

            # Add camera label
            if show_labels.value:
                current_viz['camera_label'] = server.scene.add_label(
                    "/camera_label",
                    text=f"Camera Frame {selected['frame']}",
                    position=cam_pos_normalized + np.array([0, 0.5, 0]),
                )

        # Create instance toggles (rebuilt every frame; removed on next update)
        with instance_folder:
            for inst in selected['instances']:
                inst_info = inst['info']
                toggle = server.gui.add_checkbox(
                    f"Instance {inst['id']}: {inst_info['shape']} ({inst_info['color']})",
                    initial_value=True
                )
                current_viz['instance_toggles'].append(toggle)

        # Add instances
        for i, inst in enumerate(selected['instances']):
            inst_id = inst['id']
            inst_info = inst['info']

            # Check if this instance should be shown
            show_this = i < len(current_viz['instance_toggles']) and current_viz['instance_toggles'][i].value

            if not show_this:
                continue

            # Get color
            shape_name = inst_info['shape']
            color = object_colors.get(shape_name, (128, 128, 128))

            # Add point cloud
            if show_points.value:
                pc = server.scene.add_point_cloud(
                    f"/instance_{inst_id}/points",
                    points=inst['points'],
                    colors=np.array([color] * len(inst['points']), dtype=np.uint8),
                    point_size=point_size.value,
                )
                current_viz['points'][inst_id] = pc

            # Add superquadric
            if show_quadrics.value:
                try:
                    vertices, faces = generate_superquadric_mesh(inst['quadric'], num_samples=20)

                    mesh = server.scene.add_mesh_simple(
                        f"/instance_{inst_id}/mesh",
                        vertices=vertices,
                        faces=faces,
                        color=color,
                        opacity=mesh_opacity.value,
                    )
                    current_viz['meshes'][inst_id] = mesh

                    if show_labels.value:
                        sq = inst['quadric']
                        label_text = f"{inst_info['shape']}\n"
                        label_text += f"ε₁={sq.shape[0]:.2f}, ε₂={sq.shape[1]:.2f}\n"
                        label_text += f"Inliers: {inst['inlier_ratio']:.1%}\n"
                        label_text += f"Outliers: {(1 - inst['inlier_ratio']):.1%}"

                        label = server.scene.add_label(
                            f"/instance_{inst_id}/label",
                            text=label_text,
                            position=sq.translation,
                        )
                        current_viz['labels'][inst_id] = label

                except Exception as e:
                    print(f"Error visualizing instance {inst_id}: {e}")

            # Update info
            info_text += f"\n**Instance {inst_id}:**\n"
            info_text += f"- Shape: {inst_info['shape']}\n"
            info_text += f"- Size: {inst_info['size']}\n"
            info_text += f"- Color: {inst_info['color']}\n"
            info_text += f"- Points: {len(inst['points'])}\n"
            info_text += f"- ε₁={inst['quadric'].shape[0]:.3f}, ε₂={inst['quadric'].shape[1]:.3f}\n"
            info_text += f"- Inliers: {inst['inlier_ratio']:.1%}\n"

        info_display.value = info_text

    # Set up callbacks
    @current_sample.on_update
    def _(_):
        # Update frame slider max value based on selected sample
        for s in samples_info:
            if s['name'] == current_sample.value:
                frame_slider.max = s['num_frames'] - 1
                frame_slider.value = 0  # Reset to first frame
                break
        load_frame()  # Automatically load when sample changes

    @frame_slider.on_update
    def _(_):
        if not current_viz['is_playing']:  # Only load if not playing (playback will handle it)
            load_frame()

    # Playback functions
    import threading
    playback_thread = None

    def playback_loop():
        """Playback loop in separate thread"""
        while current_viz['is_playing']:
            # Move to next frame (wraps back to 0 after the last frame)
            current_frame = int(frame_slider.value)
            next_frame = (current_frame + 1) % (frame_slider.max + 1)
            frame_slider.value = next_frame
            load_frame()

            # Sleep based on FPS
            time.sleep(1.0 / fps_slider.value)

    @play_button.on_click
    def _(_):
        if not current_viz['is_playing']:
            current_viz['is_playing'] = True
            play_button.disabled = True
            pause_button.disabled = False
            # Start playback thread
            # NOTE(review): this assigns a *local* name, shadowing the outer
            # `playback_thread`; the thread handle is never joined — it exits
            # on its own when is_playing flips to False.
            playback_thread = threading.Thread(target=playback_loop)
            playback_thread.start()
            status_display.value = "**Status:** Playing..."

    @pause_button.on_click
    def _(_):
        if current_viz['is_playing']:
            current_viz['is_playing'] = False
            play_button.disabled = False
            pause_button.disabled = True
            status_display.value = "**Status:** Paused"

    @show_scene.on_update
    def _(_):
        update_scene()

    @use_rgb_colors.on_update
    def _(_):
        if show_scene.value:
            update_scene()

    @show_points.on_update
    def _(_):
        for pc in current_viz['points'].values():
            pc.visible = show_points.value

    @show_quadrics.on_update
    def _(_):
        for mesh in current_viz['meshes'].values():
            mesh.visible = show_quadrics.value

    @show_labels.on_update
    def _(_):
        for label in current_viz['labels'].values():
            label.visible = show_labels.value

    @point_size.on_update
    def _(event):
        if current_viz['scene_cloud'] is not None:
            current_viz['scene_cloud'].point_size = event.target.value
        for pc in current_viz['points'].values():
            pc.point_size = event.target.value

    @mesh_opacity.on_update
    def _(event):
        for mesh in current_viz['meshes'].values():
            mesh.opacity = event.target.value

    @match_camera_view.on_click
    def _(event: viser.GuiEvent):
        """Set the viewer camera to match the current frame's camera"""
        if current_viz['current_result'] is None:
            status_display.value = "**Status:** Please process a frame first"
            return

        result = current_viz['current_result']
        if 'camera_position' not in result:
            status_display.value = "**Status:** No camera data available"
            return

        # Get normalized camera position and orientation
        cam_pos = result['camera_position']
        cam_quat = result['camera_quaternion']  # [x, y, z, w] format

        # Apply scene normalization to camera position
        scale = result.get('scene_scale', 1.0)
        center = result.get('scene_center', np.zeros(3))
        cam_pos_normalized = (cam_pos - center) * scale

        # Convert quaternion from xyzw to wxyz format for viser
        wxyz = np.array([cam_quat[3], cam_quat[0], cam_quat[1], cam_quat[2]])

        # Set camera position and orientation
        client = event.client
        if client is not None:
            client.camera.position = cam_pos_normalized
            client.camera.wxyz = wxyz

            # Also update the up direction based on camera orientation
            from scipy.spatial.transform import Rotation
            rot = Rotation.from_quat(cam_quat)  # xyzw format
            # In MOVi, +Y is up in camera space, but viser might need -Y
            # Try negative Y to fix the upside-down issue
            camera_up = rot.apply([0, -1, 0])
            client.camera.up_direction = camera_up

        status_display.value = f"**Status:** Matched camera view for frame {result['frame']}"

    # Initial setup
    pause_button.disabled = True  # Initially disabled

    # Load first frame of first sample
    if len(samples_info) > 0:
        load_frame()
    else:
        info_display.value = "**Sample Info:**\n\nNo samples available"

    # Set initial viewer camera to a reasonable position
    # Look at the origin from a distance
    server.scene.set_up_direction("+y")

    # Keep server running until interrupted
    try:
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        print("\nShutting down server...")
        server.stop()
883
+
884
+
885
# Script entry point.
if __name__ == "__main__":
    main()
EMS-superquadric_fitting_inference/process_movi_validation.py ADDED
@@ -0,0 +1,886 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Process MOVi-A validation set with hierarchical multi-superquadric fitting
4
+ Converts depth maps to normalized point clouds for superquadric fitting
5
+ """
6
+
7
+ import numpy as np
8
+ import sys
9
+ import os
10
+ import time
11
+ import viser
12
+ import json
13
+ import cv2
14
+ from pathlib import Path
15
+ from sklearn.cluster import DBSCAN
16
+
17
+ # Add the src directory to Python path
18
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
19
+
20
+ from EMS.EMS_recovery import EMS_recovery
21
+
22
+
23
+ def depth_to_normalized_pointcloud(depth, segmentation, camera_K, camera_position=None, camera_quaternion=None, resolution=128, convert_to_zdepth=True):
24
+ """
25
+ Convert depth map to normalized point cloud in range [-10, 10]
26
+
27
+ Args:
28
+ depth: (H, W, 1) depth array (euclidean distance from camera center)
29
+ segmentation: (H, W) instance segmentation mask
30
+ camera_K: 3x3 camera intrinsic matrix
31
+ camera_position: camera position in world coordinates
32
+ camera_quaternion: camera quaternion (x,y,z,w) in world coordinates
33
+ resolution: image resolution (assuming square)
34
+ convert_to_zdepth: bool, convert euclidean depth to z-depth before processing
35
+
36
+ Returns:
37
+ dict: instance_id -> normalized point cloud in world coordinates
38
+ """
39
+ H, W = depth.shape[:2]
40
+
41
+ # Get camera parameters
42
+ fx = camera_K[0, 0]
43
+ fy = camera_K[1, 1]
44
+ cx = camera_K[0, 2]
45
+ cy = camera_K[1, 2]
46
+
47
+ # Create pixel grid
48
+ xx, yy = np.meshgrid(np.arange(W), np.arange(H))
49
+
50
+ # Convert to normalized camera coordinates
51
+ x_norm = (xx - cx) / fx
52
+ y_norm = (yy - cy) / fy
53
+
54
+ if convert_to_zdepth:
55
+ # MOVi uses euclidean distance, convert to z-depth (planar depth)
56
+ # For each pixel, we have: euclidean_dist^2 = x^2 + y^2 + z^2
57
+ # Where x = x_norm * z, y = y_norm * z
58
+ # So: euclidean_dist^2 = (x_norm^2 + y_norm^2 + 1) * z^2
59
+ z = depth[:, :, 0] / np.sqrt(x_norm**2 + y_norm**2 + 1)
60
+ else:
61
+ # Use depth as-is (assume it's already z-depth)
62
+ z = depth[:, :, 0]
63
+
64
+ # Get 3D points
65
+ x = x_norm * z
66
+ y = y_norm * z
67
+
68
+ # Stack to get point cloud (in camera coordinates)
69
+ points_3d_camera = np.stack([x, y, z], axis=-1)
70
+
71
+ # Transform from camera to world coordinates if camera pose is provided
72
+ if camera_position is not None and camera_quaternion is not None:
73
+ from scipy.spatial.transform import Rotation
74
+
75
+ # Convert quaternion to rotation matrix
76
+ # MOVi uses [x, y, z, w] format
77
+ cam_rot = Rotation.from_quat(camera_quaternion)
78
+ cam_rot_matrix = cam_rot.as_matrix()
79
+
80
+ # Transform points: World = R * Camera + T
81
+ points_3d_flat = points_3d_camera.reshape(-1, 3)
82
+ points_3d_world = points_3d_flat @ cam_rot_matrix.T + camera_position
83
+ points_3d = points_3d_world.reshape(points_3d_camera.shape)
84
+ else:
85
+ points_3d = points_3d_camera
86
+
87
+
88
+ # Normalize entire scene to [-10, 10] range
89
+ # Find scene bounds (only valid depth points)
90
+ valid_mask = z > 0
91
+ valid_points = points_3d[valid_mask]
92
+
93
+ if len(valid_points) > 0:
94
+ # Find scene extent
95
+ scene_min = np.min(valid_points, axis=0)
96
+ scene_max = np.max(valid_points, axis=0)
97
+ scene_center = (scene_min + scene_max) / 2
98
+ scene_extent = np.max(scene_max - scene_min)
99
+
100
+ # Scale to [-10, 10]
101
+ if scene_extent > 0:
102
+ scale_factor = 20.0 / scene_extent # 20 because we want -10 to 10
103
+ points_3d_normalized = (points_3d - scene_center) * scale_factor
104
+ else:
105
+ points_3d_normalized = points_3d - scene_center
106
+ else:
107
+ points_3d_normalized = points_3d
108
+
109
+ # Get unique instance IDs (excluding background=0)
110
+ instance_ids = np.unique(segmentation)
111
+ instance_ids = instance_ids[instance_ids > 0]
112
+
113
+ instance_pointclouds = {}
114
+
115
+ for inst_id in instance_ids:
116
+ # Get mask for this instance
117
+ mask = segmentation == inst_id
118
+
119
+ # Extract points for this instance (already normalized with scene)
120
+ instance_points = points_3d_normalized[mask]
121
+
122
+ if len(instance_points) < 50: # Skip if too few points
123
+ continue
124
+
125
+ instance_pointclouds[int(inst_id)] = instance_points
126
+
127
+ # Also return the full scene point cloud and segmentation for visualization
128
+ return instance_pointclouds, points_3d_normalized, segmentation, scene_center if 'scene_center' in locals() else np.zeros(3), scene_extent if 'scene_extent' in locals() else 1.0
129
+
130
+
131
+ def hierarchical_ems(
132
+ point,
133
+ OutlierRatio=0.5,
134
+ MaxIterationEM=20,
135
+ ToleranceEM=1e-3,
136
+ RelativeToleranceEM=2e-1,
137
+ MaxOptiIterations=2,
138
+ Sigma=0.3,
139
+ MaxiSwitch=2,
140
+ AdaptiveUpperBound=True,
141
+ Rescale=False,
142
+ MaxLayer=3,
143
+ Eps=1.0, # Adjusted for normalized [-10, 10] point clouds
144
+ MinPoints=50,
145
+ ):
146
+ """
147
+ Hierarchical EMS for extracting multiple superquadrics from a point cloud
148
+ """
149
+ point_seg = {key: [] for key in list(range(0, MaxLayer+1))}
150
+ point_outlier = {key: [] for key in list(range(0, MaxLayer+1))}
151
+ point_seg[0] = [point]
152
+ list_quadrics = []
153
+ quadric_info = []
154
+
155
+ for h in range(MaxLayer):
156
+ if len(point_seg[h]) == 0:
157
+ break
158
+
159
+ for c in range(len(point_seg[h])):
160
+ current_points = point_seg[h][c]
161
+ if len(current_points) < MinPoints * 2:
162
+ continue
163
+
164
+ try:
165
+ # Fit superquadric
166
+ x_raw, p_raw = EMS_recovery(
167
+ current_points,
168
+ OutlierRatio,
169
+ MaxIterationEM,
170
+ ToleranceEM,
171
+ RelativeToleranceEM,
172
+ MaxOptiIterations,
173
+ Sigma,
174
+ MaxiSwitch,
175
+ AdaptiveUpperBound,
176
+ Rescale,
177
+ )
178
+
179
+ # Calculate fitting quality
180
+ inlier_mask = p_raw > 0.5
181
+ inlier_ratio = np.sum(inlier_mask) / len(p_raw)
182
+
183
+ if inlier_ratio > 0.3: # Accept if at least 30% inliers
184
+ list_quadrics.append(x_raw)
185
+ quadric_info.append({
186
+ 'layer': h,
187
+ 'segment': c,
188
+ 'inlier_ratio': inlier_ratio,
189
+ 'num_points': len(current_points),
190
+ 'inlier_points': current_points[inlier_mask]
191
+ })
192
+
193
+ # Separate outliers for next layer
194
+ outlier_mask = p_raw < 0.1
195
+ outlier = current_points[outlier_mask]
196
+
197
+ # If many outliers and not last layer, try clustering
198
+ if len(outlier) > MinPoints * 2 and h < MaxLayer - 1:
199
+ clustering = DBSCAN(eps=Eps, min_samples=MinPoints).fit(outlier)
200
+ labels = list(set(clustering.labels_))
201
+ labels = [item for item in labels if item >= 0]
202
+
203
+ if len(labels) >= 1:
204
+ for i in range(len(labels)):
205
+ cluster_points = outlier[clustering.labels_ == labels[i]]
206
+ if len(cluster_points) > MinPoints:
207
+ point_seg[h + 1].append(cluster_points)
208
+
209
+ except Exception as e:
210
+ continue
211
+
212
+ return list_quadrics, quadric_info
213
+
214
+
215
+ def generate_superquadric_mesh(sq, num_samples=25):
216
+ """Generate mesh vertices and faces for superquadric surface"""
217
+ eta = np.linspace(-np.pi/2, np.pi/2, num_samples)
218
+ omega = np.linspace(-np.pi, np.pi, num_samples)
219
+
220
+ vertices = []
221
+ faces = []
222
+
223
+ # Generate vertices
224
+ for i, e in enumerate(eta):
225
+ for j, w in enumerate(omega):
226
+ # Superquadric parametric equations
227
+ cos_eta = np.sign(np.cos(e)) * np.abs(np.cos(e))**sq.shape[0]
228
+ sin_eta = np.sign(np.sin(e)) * np.abs(np.sin(e))**sq.shape[0]
229
+ cos_omega = np.sign(np.cos(w)) * np.abs(np.cos(w))**sq.shape[1]
230
+ sin_omega = np.sign(np.sin(w)) * np.abs(np.sin(w))**sq.shape[1]
231
+
232
+ # Local coordinates
233
+ x_local = sq.scale[0] * cos_eta * cos_omega
234
+ y_local = sq.scale[1] * cos_eta * sin_omega
235
+ z_local = sq.scale[2] * sin_eta
236
+
237
+ # Apply rotation and translation
238
+ point_local = np.array([x_local, y_local, z_local])
239
+ point_global = sq.RotM @ point_local + sq.translation
240
+
241
+ vertices.append(point_global)
242
+
243
+ vertices = np.array(vertices)
244
+
245
+ # Generate faces (triangles)
246
+ for i in range(num_samples - 1):
247
+ for j in range(num_samples - 1):
248
+ # Current vertex indices
249
+ idx1 = i * num_samples + j
250
+ idx2 = i * num_samples + (j + 1) % num_samples
251
+ idx3 = (i + 1) * num_samples + j
252
+ idx4 = (i + 1) * num_samples + (j + 1) % num_samples
253
+
254
+ # Two triangles per quad
255
+ faces.append([idx1, idx2, idx3])
256
+ faces.append([idx2, idx4, idx3])
257
+
258
+ return vertices, np.array(faces)
259
+
260
+
261
+ def preprocess_all_frames(samples_info):
262
+ """Preprocess all frames for all samples"""
263
+ print("\n" + "="*60)
264
+ print("PREPROCESSING ALL FRAMES")
265
+ print("="*60)
266
+
267
+ all_results = {}
268
+
269
+ for sample_idx, sample in enumerate(samples_info):
270
+ print(f"\nProcessing {sample['name']} ({sample_idx + 1}/{len(samples_info)})")
271
+ sample_results = {}
272
+
273
+ for frame_idx in range(sample['num_frames']):
274
+ print(f" Frame {frame_idx}/{sample['num_frames']-1}", end='', flush=True)
275
+
276
+ try:
277
+ # Load depth, segmentation and RGB
278
+ depth = np.load(sample['dir'] / "depth" / f"frame_{frame_idx:03d}.npy")
279
+ segmentation = np.load(sample['dir'] / "segmentation" / f"frame_{frame_idx:03d}.npy")
280
+
281
+ # Load RGB image
282
+ rgb_path = sample['dir'] / "rgb" / f"frame_{frame_idx:03d}.png"
283
+ rgb_image = cv2.imread(str(rgb_path))
284
+ rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)
285
+
286
+ # Get camera intrinsics
287
+ camera_K = np.array(sample['metadata']['camera']['K'])
288
+
289
+ # Load camera trajectory
290
+ camera_traj = np.load(sample['dir'] / "camera_trajectory.npz")
291
+ camera_position = camera_traj['positions'][frame_idx]
292
+ camera_quaternion = camera_traj['quaternions'][frame_idx]
293
+
294
+ # Convert to normalized point clouds
295
+ instance_pointclouds, scene_points, scene_seg, scene_center, scene_extent = depth_to_normalized_pointcloud(
296
+ depth, segmentation, camera_K,
297
+ camera_position=camera_position,
298
+ camera_quaternion=camera_quaternion,
299
+ convert_to_zdepth=True
300
+ )
301
+
302
+ # Process each instance
303
+ instances = []
304
+ for inst_id, points in instance_pointclouds.items():
305
+ inst_info = sample['metadata']['instances'][inst_id - 1]
306
+
307
+ try:
308
+ # Fit superquadric
309
+ sq, p = EMS_recovery(
310
+ points,
311
+ OutlierRatio=0.13,
312
+ MaxIterationEM=20,
313
+ AdaptiveUpperBound=True,
314
+ Rescale=False
315
+ )
316
+
317
+ inlier_ratio = np.sum(p > 0.5) / len(p)
318
+
319
+ instances.append({
320
+ 'id': inst_id,
321
+ 'info': inst_info,
322
+ 'points': points,
323
+ 'quadric': sq,
324
+ 'inlier_ratio': inlier_ratio,
325
+ 'inlier_points': points[p > 0.5]
326
+ })
327
+ except Exception as e:
328
+ print(f" [Failed instance {inst_id}: {str(e)[:30]}...]", end='')
329
+
330
+ # Store frame result
331
+ sample_results[frame_idx] = {
332
+ 'metadata': sample['metadata'],
333
+ 'instances': instances,
334
+ 'scene_points': scene_points,
335
+ 'scene_seg': scene_seg,
336
+ 'rgb_image': rgb_image,
337
+ 'camera_position': camera_position,
338
+ 'camera_quaternion': camera_quaternion,
339
+ 'scene_scale': 20.0 / scene_extent if scene_extent > 0 else 1.0,
340
+ 'scene_center': scene_center
341
+ }
342
+ print(" ✓", end='', flush=True)
343
+
344
+ except Exception as e:
345
+ print(f" [Error: {str(e)}]", end='')
346
+ sample_results[frame_idx] = None
347
+
348
+ all_results[sample['name']] = sample_results
349
+ print()
350
+
351
+ print(f"\nPreprocessing complete! Processed {len(all_results)} samples")
352
+ return all_results
353
+
354
+
355
+ def main():
356
+ # Load MOVi-A validation data
357
+ data_dir = Path("/research/cbim/vast/sf895/code/WaveGen/WaveGen_v33_使用超二次元函数_Transformer/data/movi_a_128x128/validation")
358
+
359
+ if not data_dir.exists():
360
+ print(f"Error: Validation data not found at {data_dir}")
361
+ print("Please run download_movi_simple.py first to download the MOVi-A dataset")
362
+ return
363
+
364
+ # Get all validation samples
365
+ sample_dirs = sorted(list(data_dir.glob("sample_*")))
366
+ print(f"Found {len(sample_dirs)} validation samples")
367
+
368
+ if len(sample_dirs) == 0:
369
+ print("No validation samples found!")
370
+ return
371
+
372
+ # Pre-load sample metadata
373
+ samples_info = []
374
+ print("\nLoading sample metadata...")
375
+ for sample_dir in sample_dirs[:10]: # Process first 10 samples
376
+ with open(sample_dir / "metadata.json", 'r') as f:
377
+ metadata = json.load(f)
378
+ samples_info.append({
379
+ 'dir': sample_dir,
380
+ 'name': sample_dir.name,
381
+ 'metadata': metadata,
382
+ 'num_frames': metadata['num_frames']
383
+ })
384
+
385
+ print(f"Loaded metadata for {len(samples_info)} samples")
386
+
387
+ # Preprocess all frames for all samples
388
+ all_preprocessed_results = preprocess_all_frames(samples_info)
389
+
390
+ # Start viser visualization
391
+ server = viser.ViserServer(port=8080)
392
+ print(f"\n{'='*60}")
393
+ print(f"Viser server started at http://localhost:8080")
394
+ print("Open this URL in your browser to view the 3D visualization")
395
+ print("Press Ctrl+C to stop the server")
396
+ print('='*60)
397
+
398
+ # Colors for different objects
399
+ object_colors = {
400
+ 'cube': (255, 0, 0), # Red
401
+ 'sphere': (0, 255, 0), # Green
402
+ 'cylinder': (0, 0, 255), # Blue
403
+ }
404
+
405
+ # Colors for instances
406
+ instance_colors = [
407
+ (255, 0, 0), # Red
408
+ (0, 255, 0), # Green
409
+ (0, 0, 255), # Blue
410
+ (255, 255, 0), # Yellow
411
+ (255, 0, 255), # Magenta
412
+ (0, 255, 255), # Cyan
413
+ (255, 128, 0), # Orange
414
+ (128, 0, 255), # Purple
415
+ ]
416
+
417
+ # Create GUI
418
+ with server.gui.add_folder("Controls"):
419
+ # Sample selector
420
+ sample_names = [s['name'] for s in samples_info]
421
+ current_sample = server.gui.add_dropdown(
422
+ "Select Sample",
423
+ options=sample_names,
424
+ initial_value=sample_names[0] if sample_names else None
425
+ )
426
+
427
+ # Frame selector
428
+ frame_slider = server.gui.add_slider(
429
+ "Frame",
430
+ min=0,
431
+ max=23, # MOVi-A has 24 frames
432
+ step=1,
433
+ initial_value=0
434
+ )
435
+
436
+ # Playback controls
437
+ with server.gui.add_folder("Playback Controls"):
438
+ play_button = server.gui.add_button("Play ▶")
439
+ pause_button = server.gui.add_button("Pause ⏸")
440
+ fps_slider = server.gui.add_slider(
441
+ "Playback FPS",
442
+ min=1,
443
+ max=24,
444
+ step=1,
445
+ initial_value=12
446
+ )
447
+
448
+ # Status display
449
+ status_display = server.gui.add_markdown("**Status:** Ready")
450
+
451
+ # Instance selector will be updated dynamically
452
+ instance_folder = server.gui.add_folder("Instances")
453
+
454
+ # Visibility controls
455
+ show_scene = server.gui.add_checkbox("Show Background Points", initial_value=True)
456
+ show_points = server.gui.add_checkbox("Highlight only the identified instance points", initial_value=True)
457
+ show_quadrics = server.gui.add_checkbox("Show Superquadrics", initial_value=True)
458
+ show_labels = server.gui.add_checkbox("Show Labels", initial_value=False)
459
+ show_camera = server.gui.add_checkbox("Show Camera", initial_value=False)
460
+ use_rgb_colors = server.gui.add_checkbox("Show Point Colors", initial_value=True)
461
+
462
+ # Camera view button
463
+ match_camera_view = server.gui.add_button("Match Frame Camera View")
464
+
465
+ # Visual parameters
466
+ point_size = server.gui.add_slider(
467
+ "Point Size",
468
+ min=0.001,
469
+ max=0.05,
470
+ step=0.001,
471
+ initial_value=0.01
472
+ )
473
+
474
+ mesh_opacity = server.gui.add_slider(
475
+ "Mesh Opacity",
476
+ min=0.0,
477
+ max=1.0,
478
+ step=0.1,
479
+ initial_value=0.7
480
+ )
481
+
482
+ # Info display
483
+ info_display = server.gui.add_markdown("**Sample Info:**\n\nSelect a sample to view")
484
+
485
+ # Store current visualization handles and results
486
+ current_viz = {
487
+ 'scene_cloud': None,
488
+ 'points': {},
489
+ 'meshes': {},
490
+ 'labels': {},
491
+ 'camera_frustum': None,
492
+ 'camera_label': None,
493
+ 'instance_toggles': [],
494
+ 'current_result': None,
495
+ 'all_results': all_preprocessed_results, # Store preprocessed results
496
+ 'is_playing': False
497
+ }
498
+
499
+ def load_frame():
500
+ """Load the current frame from preprocessed results"""
501
+ sample_name = current_sample.value
502
+ frame_idx = int(frame_slider.value)
503
+
504
+ if sample_name not in current_viz['all_results']:
505
+ status_display.value = f"**Status:** Sample {sample_name} not found in preprocessed results"
506
+ return
507
+
508
+ sample_results = current_viz['all_results'][sample_name]
509
+ if frame_idx not in sample_results or sample_results[frame_idx] is None:
510
+ status_display.value = f"**Status:** Frame {frame_idx} not available"
511
+ return
512
+
513
+ # Get preprocessed result
514
+ current_viz['current_result'] = sample_results[frame_idx]
515
+ current_viz['current_result']['name'] = sample_name
516
+ current_viz['current_result']['frame'] = frame_idx
517
+
518
+ # Update visualization
519
+ num_instances = len(current_viz['current_result']['instances'])
520
+ status_display.value = f"**Status:** Loaded frame {frame_idx} - {num_instances} instances"
521
+ update_scene()
522
+
523
+ def update_scene():
524
+ """Update the 3D scene based on current result"""
525
+ # Clear existing visualization
526
+ if current_viz['scene_cloud'] is not None:
527
+ current_viz['scene_cloud'].remove()
528
+ current_viz['scene_cloud'] = None
529
+
530
+ if current_viz['camera_frustum'] is not None:
531
+ current_viz['camera_frustum'].remove()
532
+ current_viz['camera_frustum'] = None
533
+
534
+ if current_viz['camera_label'] is not None:
535
+ current_viz['camera_label'].remove()
536
+ current_viz['camera_label'] = None
537
+
538
+ for pc in current_viz['points'].values():
539
+ pc.remove()
540
+ current_viz['points'] = {}
541
+
542
+ for mesh in current_viz['meshes'].values():
543
+ mesh.remove()
544
+ current_viz['meshes'] = {}
545
+
546
+ for label in current_viz['labels'].values():
547
+ label.remove()
548
+ current_viz['labels'] = {}
549
+
550
+ # Clear instance toggles
551
+ for toggle in current_viz['instance_toggles']:
552
+ toggle.remove()
553
+ current_viz['instance_toggles'] = []
554
+
555
+ # Get current result
556
+ selected = current_viz['current_result']
557
+
558
+ if selected is None:
559
+ info_display.value = "**Sample Info:**\n\nClick 'Process Current Frame' to start"
560
+ return
561
+
562
+ # Update info
563
+ info_text = f"**{selected['name']} - Frame {selected['frame']}**\n\n"
564
+ info_text += f"Total instances: {len(selected['instances'])}\n"
565
+
566
+ # Show full scene point cloud if requested
567
+ if show_scene.value:
568
+ scene_points_flat = selected['scene_points'].reshape(-1, 3)
569
+ scene_seg_flat = selected['scene_seg'].reshape(-1)
570
+
571
+ # Filter out invalid points
572
+ valid_mask = ~np.isnan(scene_points_flat).any(axis=1)
573
+ scene_points_valid = scene_points_flat[valid_mask]
574
+ scene_seg_valid = scene_seg_flat[valid_mask]
575
+
576
+ if use_rgb_colors.value and 'rgb_image' in selected:
577
+ # Use RGB colors from image
578
+ rgb_flat = selected['rgb_image'].reshape(-1, 3)
579
+ rgb_valid = rgb_flat[valid_mask]
580
+ colors = rgb_valid.astype(np.uint8)
581
+ else:
582
+ # Use segmentation colors
583
+ colors = np.zeros((len(scene_points_valid), 3), dtype=np.uint8)
584
+ for i, seg_id in enumerate(scene_seg_valid):
585
+ if seg_id == 0:
586
+ colors[i] = [128, 128, 128] # Gray for background
587
+ else:
588
+ colors[i] = instance_colors[(seg_id - 1) % len(instance_colors)]
589
+
590
+ current_viz['scene_cloud'] = server.scene.add_point_cloud(
591
+ "/scene_points",
592
+ points=scene_points_valid,
593
+ colors=colors,
594
+ point_size=point_size.value,
595
+ )
596
+ info_text += f"Scene points shown: {len(scene_points_valid)}\n"
597
+
598
+ info_text += "\n"
599
+
600
+ # Show camera if requested
601
+ if show_camera.value and 'camera_position' in selected:
602
+ # Transform camera position to normalized scene coordinates
603
+ cam_pos = selected['camera_position']
604
+ scale = selected.get('scene_scale', 1.0)
605
+ center = selected.get('scene_center', np.zeros(3))
606
+ cam_pos_normalized = (cam_pos - center) * scale
607
+
608
+ # Get camera parameters from metadata
609
+ focal_length = selected['metadata']['camera']['focal_length']
610
+ sensor_width = selected['metadata']['camera']['sensor_width']
611
+ resolution = selected['metadata']['resolution']
612
+
613
+ # Calculate field of view
614
+ fov = 2 * np.arctan(sensor_width / (2 * focal_length))
615
+
616
+ # Get camera orientation
617
+ cam_quat = selected['camera_quaternion']
618
+
619
+ # Convert quaternion to wxyz format (viser uses w first)
620
+ # MOVi quaternion is [x, y, z, w], viser needs [w, x, y, z]
621
+ wxyz = np.array([cam_quat[3], cam_quat[0], cam_quat[1], cam_quat[2]])
622
+
623
+ # Get the RGB image for the camera frustum
624
+ if 'rgb_image' in selected:
625
+ # Use original resolution for camera frustum display
626
+ # MOVi-A is 128x128, which should display clearly
627
+ small_rgb = selected['rgb_image']
628
+
629
+ # Add camera frustum with image
630
+ current_viz['camera_frustum'] = server.scene.add_camera_frustum(
631
+ "/camera_frustum",
632
+ fov=fov,
633
+ aspect=1.0, # Square aspect ratio for MOVi
634
+ scale=2.0, # Size of frustum visualization
635
+ wxyz=wxyz,
636
+ position=cam_pos_normalized,
637
+ image=small_rgb,
638
+ )
639
+ else:
640
+ # Add camera frustum without image
641
+ current_viz['camera_frustum'] = server.scene.add_camera_frustum(
642
+ "/camera_frustum",
643
+ fov=fov,
644
+ aspect=1.0,
645
+ scale=2.0,
646
+ wxyz=wxyz,
647
+ position=cam_pos_normalized,
648
+ color=(255, 255, 0),
649
+ )
650
+
651
+ # Add camera label
652
+ if show_labels.value:
653
+ current_viz['camera_label'] = server.scene.add_label(
654
+ "/camera_label",
655
+ text=f"Camera Frame {selected['frame']}",
656
+ position=cam_pos_normalized + np.array([0, 0.5, 0]),
657
+ )
658
+
659
+ # Create instance toggles
660
+ with instance_folder:
661
+ for inst in selected['instances']:
662
+ inst_info = inst['info']
663
+ toggle = server.gui.add_checkbox(
664
+ f"Instance {inst['id']}: {inst_info['shape']} ({inst_info['color']})",
665
+ initial_value=True
666
+ )
667
+ current_viz['instance_toggles'].append(toggle)
668
+
669
+ # Add instances
670
+ for i, inst in enumerate(selected['instances']):
671
+ inst_id = inst['id']
672
+ inst_info = inst['info']
673
+
674
+ # Check if this instance should be shown
675
+ show_this = i < len(current_viz['instance_toggles']) and current_viz['instance_toggles'][i].value
676
+
677
+ if not show_this:
678
+ continue
679
+
680
+ # Get color
681
+ shape_name = inst_info['shape']
682
+ color = object_colors.get(shape_name, (128, 128, 128))
683
+
684
+ # Add point cloud
685
+ if show_points.value:
686
+ pc = server.scene.add_point_cloud(
687
+ f"/instance_{inst_id}/points",
688
+ points=inst['points'],
689
+ colors=np.array([color] * len(inst['points']), dtype=np.uint8),
690
+ point_size=point_size.value,
691
+ )
692
+ current_viz['points'][inst_id] = pc
693
+
694
+ # Add superquadric
695
+ if show_quadrics.value:
696
+ try:
697
+ vertices, faces = generate_superquadric_mesh(inst['quadric'], num_samples=20)
698
+
699
+ mesh = server.scene.add_mesh_simple(
700
+ f"/instance_{inst_id}/mesh",
701
+ vertices=vertices,
702
+ faces=faces,
703
+ color=color,
704
+ opacity=mesh_opacity.value,
705
+ )
706
+ current_viz['meshes'][inst_id] = mesh
707
+
708
+ if show_labels.value:
709
+ sq = inst['quadric']
710
+ label_text = f"{inst_info['shape']}\n"
711
+ label_text += f"ε₁={sq.shape[0]:.2f}, ε₂={sq.shape[1]:.2f}\n"
712
+ label_text += f"Inliers: {inst['inlier_ratio']:.1%}\n"
713
+ label_text += f"Outliers: {(1 - inst['inlier_ratio']):.1%}"
714
+
715
+ label = server.scene.add_label(
716
+ f"/instance_{inst_id}/label",
717
+ text=label_text,
718
+ position=sq.translation,
719
+ )
720
+ current_viz['labels'][inst_id] = label
721
+
722
+ except Exception as e:
723
+ print(f"Error visualizing instance {inst_id}: {e}")
724
+
725
+ # Update info
726
+ info_text += f"\n**Instance {inst_id}:**\n"
727
+ info_text += f"- Shape: {inst_info['shape']}\n"
728
+ info_text += f"- Size: {inst_info['size']}\n"
729
+ info_text += f"- Color: {inst_info['color']}\n"
730
+ info_text += f"- Points: {len(inst['points'])}\n"
731
+ info_text += f"- ε₁={inst['quadric'].shape[0]:.3f}, ε₂={inst['quadric'].shape[1]:.3f}\n"
732
+ info_text += f"- Inliers: {inst['inlier_ratio']:.1%}\n"
733
+
734
+ info_display.value = info_text
735
+
736
+ # Set up callbacks
737
+ @current_sample.on_update
738
+ def _(_):
739
+ # Update frame slider max value based on selected sample
740
+ for s in samples_info:
741
+ if s['name'] == current_sample.value:
742
+ frame_slider.max = s['num_frames'] - 1
743
+ frame_slider.value = 0 # Reset to first frame
744
+ break
745
+ load_frame() # Automatically load when sample changes
746
+
747
+ @frame_slider.on_update
748
+ def _(_):
749
+ if not current_viz['is_playing']: # Only load if not playing (playback will handle it)
750
+ load_frame()
751
+
752
+ # Playback functions
753
+ import threading
754
+ playback_thread = None
755
+
756
+ def playback_loop():
757
+ """Playback loop in separate thread"""
758
+ while current_viz['is_playing']:
759
+ # Move to next frame
760
+ current_frame = int(frame_slider.value)
761
+ next_frame = (current_frame + 1) % (frame_slider.max + 1)
762
+ frame_slider.value = next_frame
763
+ load_frame()
764
+
765
+ # Sleep based on FPS
766
+ time.sleep(1.0 / fps_slider.value)
767
+
768
+ @play_button.on_click
769
+ def _(_):
770
+ if not current_viz['is_playing']:
771
+ current_viz['is_playing'] = True
772
+ play_button.disabled = True
773
+ pause_button.disabled = False
774
+ # Start playback thread
775
+ playback_thread = threading.Thread(target=playback_loop)
776
+ playback_thread.start()
777
+ status_display.value = "**Status:** Playing..."
778
+
779
+ @pause_button.on_click
780
+ def _(_):
781
+ if current_viz['is_playing']:
782
+ current_viz['is_playing'] = False
783
+ play_button.disabled = False
784
+ pause_button.disabled = True
785
+ status_display.value = "**Status:** Paused"
786
+
787
+ @show_scene.on_update
788
+ def _(_):
789
+ update_scene()
790
+
791
+ @use_rgb_colors.on_update
792
+ def _(_):
793
+ if show_scene.value:
794
+ update_scene()
795
+
796
+ @show_points.on_update
797
+ def _(_):
798
+ for pc in current_viz['points'].values():
799
+ pc.visible = show_points.value
800
+
801
+ @show_quadrics.on_update
802
+ def _(_):
803
+ for mesh in current_viz['meshes'].values():
804
+ mesh.visible = show_quadrics.value
805
+
806
+ @show_labels.on_update
807
+ def _(_):
808
+ for label in current_viz['labels'].values():
809
+ label.visible = show_labels.value
810
+
811
+ @point_size.on_update
812
+ def _(event):
813
+ if current_viz['scene_cloud'] is not None:
814
+ current_viz['scene_cloud'].point_size = event.target.value
815
+ for pc in current_viz['points'].values():
816
+ pc.point_size = event.target.value
817
+
818
+ @mesh_opacity.on_update
819
+ def _(event):
820
+ for mesh in current_viz['meshes'].values():
821
+ mesh.opacity = event.target.value
822
+
823
+ @match_camera_view.on_click
824
+ def _(event: viser.GuiEvent):
825
+ """Set the viewer camera to match the current frame's camera"""
826
+ if current_viz['current_result'] is None:
827
+ status_display.value = "**Status:** Please process a frame first"
828
+ return
829
+
830
+ result = current_viz['current_result']
831
+ if 'camera_position' not in result:
832
+ status_display.value = "**Status:** No camera data available"
833
+ return
834
+
835
+ # Get normalized camera position and orientation
836
+ cam_pos = result['camera_position']
837
+ cam_quat = result['camera_quaternion'] # [x, y, z, w] format
838
+
839
+ # Apply scene normalization to camera position
840
+ scale = result.get('scene_scale', 1.0)
841
+ center = result.get('scene_center', np.zeros(3))
842
+ cam_pos_normalized = (cam_pos - center) * scale
843
+
844
+ # Convert quaternion from xyzw to wxyz format for viser
845
+ wxyz = np.array([cam_quat[3], cam_quat[0], cam_quat[1], cam_quat[2]])
846
+
847
+ # Set camera position and orientation
848
+ client = event.client
849
+ if client is not None:
850
+ client.camera.position = cam_pos_normalized
851
+ client.camera.wxyz = wxyz
852
+
853
+ # Also update the up direction based on camera orientation
854
+ from scipy.spatial.transform import Rotation
855
+ rot = Rotation.from_quat(cam_quat) # xyzw format
856
+ # In MOVi, +Y is up in camera space, but viser might need -Y
857
+ # Try negative Y to fix the upside-down issue
858
+ camera_up = rot.apply([0, -1, 0])
859
+ client.camera.up_direction = camera_up
860
+
861
+ status_display.value = f"**Status:** Matched camera view for frame {result['frame']}"
862
+
863
+ # Initial setup
864
+ pause_button.disabled = True # Initially disabled
865
+
866
+ # Load first frame of first sample
867
+ if len(samples_info) > 0:
868
+ load_frame()
869
+ else:
870
+ info_display.value = "**Sample Info:**\n\nNo samples available"
871
+
872
+ # Set initial viewer camera to a reasonable position
873
+ # Look at the origin from a distance
874
+ server.scene.set_up_direction("+y")
875
+
876
+ # Keep server running
877
+ try:
878
+ while True:
879
+ time.sleep(0.1)
880
+ except KeyboardInterrupt:
881
+ print("\nShutting down server...")
882
+ server.stop()
883
+
884
+
885
+ if __name__ == "__main__":
886
+ main()
EMS-superquadric_fitting_inference/process_viser_hierarchical.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Hierarchical multi-superquadric fitting with viser visualization
4
+ Based on the hierarchical_ems algorithm from multiquadric_test.py
5
+ """
6
+
7
+ import numpy as np
8
+ import sys
9
+ import os
10
+ import time
11
+ import viser
12
+ from sklearn.cluster import DBSCAN
13
+
14
+ # Add the src directory to Python path
15
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
16
+
17
+ from EMS.EMS_recovery import EMS_recovery
18
+
19
+
20
+ def hierarchical_ems(
21
+ point,
22
+ OutlierRatio=0.5, # Reduced for better initial fit
23
+ MaxIterationEM=20,
24
+ ToleranceEM=1e-3,
25
+ RelativeToleranceEM=2e-1,
26
+ MaxOptiIterations=2,
27
+ Sigma=0.3,
28
+ MaxiSwitch=2,
29
+ AdaptiveUpperBound=True,
30
+ Rescale=False,
31
+ MaxLayer=3, # Reduced for faster processing
32
+ Eps=0.1, # Adjusted for normalized point clouds
33
+ MinPoints=50, # Minimum points to form a cluster
34
+ ):
35
+ """
36
+ Hierarchical EMS for extracting multiple superquadrics from a point cloud
37
+ """
38
+ point_seg = {key: [] for key in list(range(0, MaxLayer+1))}
39
+ point_outlier = {key: [] for key in list(range(0, MaxLayer+1))}
40
+ point_seg[0] = [point]
41
+ list_quadrics = []
42
+ quadric_info = [] # Store additional info about each quadric
43
+
44
+ for h in range(MaxLayer):
45
+ if len(point_seg[h]) == 0:
46
+ break
47
+
48
+ for c in range(len(point_seg[h])):
49
+ current_points = point_seg[h][c]
50
+ if len(current_points) < MinPoints * 2:
51
+ continue
52
+
53
+ print(f" Layer {h}, Segment {c}: Processing {len(current_points)} points")
54
+
55
+ try:
56
+ # Fit superquadric
57
+ x_raw, p_raw = EMS_recovery(
58
+ current_points,
59
+ OutlierRatio,
60
+ MaxIterationEM,
61
+ ToleranceEM,
62
+ RelativeToleranceEM,
63
+ MaxOptiIterations,
64
+ Sigma,
65
+ MaxiSwitch,
66
+ AdaptiveUpperBound,
67
+ Rescale,
68
+ )
69
+
70
+ # Calculate fitting quality
71
+ inlier_mask = p_raw > 0.5
72
+ inlier_ratio = np.sum(inlier_mask) / len(p_raw)
73
+
74
+ if inlier_ratio > 0.3: # Accept if at least 30% inliers
75
+ list_quadrics.append(x_raw)
76
+ quadric_info.append({
77
+ 'layer': h,
78
+ 'segment': c,
79
+ 'inlier_ratio': inlier_ratio,
80
+ 'num_points': len(current_points),
81
+ 'inlier_points': current_points[inlier_mask]
82
+ })
83
+ print(f" → Fitted superquadric with {inlier_ratio:.1%} inliers")
84
+
85
+ # Separate outliers for next layer
86
+ outlier_mask = p_raw < 0.1
87
+ outlier = current_points[outlier_mask]
88
+
89
+ # If many outliers and not last layer, try clustering
90
+ if len(outlier) > MinPoints * 2 and h < MaxLayer - 1:
91
+ clustering = DBSCAN(eps=Eps, min_samples=MinPoints).fit(outlier)
92
+ labels = list(set(clustering.labels_))
93
+ labels = [item for item in labels if item >= 0]
94
+
95
+ if len(labels) >= 1:
96
+ print(f" → Found {len(labels)} clusters in outliers")
97
+ for i in range(len(labels)):
98
+ cluster_points = outlier[clustering.labels_ == labels[i]]
99
+ if len(cluster_points) > MinPoints:
100
+ point_seg[h + 1].append(cluster_points)
101
+
102
+ except Exception as e:
103
+ print(f" → Error: {e}")
104
+ continue
105
+
106
+ return list_quadrics, quadric_info
107
+
108
+
109
+ def generate_superquadric_mesh(sq, num_samples=25):
110
+ """Generate mesh vertices and faces for superquadric surface"""
111
+ eta = np.linspace(-np.pi/2, np.pi/2, num_samples)
112
+ omega = np.linspace(-np.pi, np.pi, num_samples)
113
+
114
+ vertices = []
115
+ faces = []
116
+
117
+ # Generate vertices
118
+ for i, e in enumerate(eta):
119
+ for j, w in enumerate(omega):
120
+ # Superquadric parametric equations
121
+ cos_eta = np.sign(np.cos(e)) * np.abs(np.cos(e))**sq.shape[0]
122
+ sin_eta = np.sign(np.sin(e)) * np.abs(np.sin(e))**sq.shape[0]
123
+ cos_omega = np.sign(np.cos(w)) * np.abs(np.cos(w))**sq.shape[1]
124
+ sin_omega = np.sign(np.sin(w)) * np.abs(np.sin(w))**sq.shape[1]
125
+
126
+ # Local coordinates
127
+ x_local = sq.scale[0] * cos_eta * cos_omega
128
+ y_local = sq.scale[1] * cos_eta * sin_omega
129
+ z_local = sq.scale[2] * sin_eta
130
+
131
+ # Apply rotation and translation
132
+ point_local = np.array([x_local, y_local, z_local])
133
+ point_global = sq.RotM @ point_local + sq.translation
134
+
135
+ vertices.append(point_global)
136
+
137
+ vertices = np.array(vertices)
138
+
139
+ # Generate faces (triangles)
140
+ for i in range(num_samples - 1):
141
+ for j in range(num_samples - 1):
142
+ # Current vertex indices
143
+ idx1 = i * num_samples + j
144
+ idx2 = i * num_samples + (j + 1) % num_samples
145
+ idx3 = (i + 1) * num_samples + j
146
+ idx4 = (i + 1) * num_samples + (j + 1) % num_samples
147
+
148
+ # Two triangles per quad
149
+ faces.append([idx1, idx2, idx3])
150
+ faces.append([idx2, idx4, idx3])
151
+
152
+ return vertices, np.array(faces)
153
+
154
+
155
+ def main():
156
+ # Import utilities for reading PLY files
157
+ from EMS.utilities import read_ply
158
+
159
+ all_samples = []
160
+ sample_idx = 0
161
+
162
+ print("Loading and processing samples with hierarchical multi-quadric fitting...")
163
+
164
+ # 1. Load repository example PLY files
165
+ example_data_dir = "/research/cbim/vast/sf895/code/EMS-superquadric_fitting/MATLAB/example_scripts/data"
166
+
167
+ # Single superquadric examples
168
+ single_ply_files = [
169
+ "single_superquadric/noisy_pointCloud_example_1.ply",
170
+ "single_superquadric/noisy_pointCloud_example_2.ply",
171
+ "single_superquadric/partial_pointCloud_example_1.ply",
172
+ ]
173
+
174
+ # Multi superquadric examples
175
+ multi_ply_files = [
176
+ "multi_superquadrics/cat.ply",
177
+ "multi_superquadrics/dog.ply",
178
+ "multi_superquadrics/turtle.ply",
179
+ ]
180
+
181
+ # Process single superquadric files
182
+ for ply_file in single_ply_files:
183
+ file_path = os.path.join(example_data_dir, ply_file)
184
+ if os.path.exists(file_path):
185
+ print(f"\nProcessing {ply_file}...")
186
+ try:
187
+ # Load PLY data
188
+ point_cloud = read_ply(file_path)
189
+
190
+ # Single quadric fitting
191
+ from EMS.EMS_recovery import EMS_recovery
192
+ sq, p = EMS_recovery(point_cloud, OutlierRatio=0.2, AdaptiveUpperBound=True)
193
+
194
+ all_samples.append({
195
+ 'name': os.path.basename(ply_file),
196
+ 'idx': sample_idx,
197
+ 'points': point_cloud,
198
+ 'quadrics': [sq],
199
+ 'quadric_info': [{
200
+ 'layer': 0,
201
+ 'segment': 0,
202
+ 'inlier_ratio': np.sum(p > 0.5) / len(p),
203
+ 'num_points': len(point_cloud),
204
+ 'inlier_points': point_cloud[p > 0.5]
205
+ }]
206
+ })
207
+ sample_idx += 1
208
+
209
+ print(f" Success! Shape: {sq.shape}, Scale: {sq.scale}")
210
+
211
+ except Exception as e:
212
+ print(f" Failed: {e}")
213
+
214
+ # Process multi superquadric files
215
+ for ply_file in multi_ply_files:
216
+ file_path = os.path.join(example_data_dir, ply_file)
217
+ if os.path.exists(file_path):
218
+ print(f"\nProcessing {ply_file} (multi-quadric)...")
219
+ try:
220
+ # Load PLY data
221
+ point_cloud = read_ply(file_path)
222
+
223
+ # Hierarchical multi-quadric fitting
224
+ # Adjust parameters for these specific examples
225
+ quadrics, quadric_info = hierarchical_ems(
226
+ point_cloud,
227
+ OutlierRatio=0.9, # Higher for multi-object scenes
228
+ Eps=1.7, # Larger for non-normalized data
229
+ MinPoints=60, # Standard minimum
230
+ Rescale=True # Enable rescaling for raw PLY data
231
+ )
232
+
233
+ all_samples.append({
234
+ 'name': os.path.basename(ply_file),
235
+ 'idx': sample_idx,
236
+ 'points': point_cloud,
237
+ 'quadrics': quadrics,
238
+ 'quadric_info': quadric_info
239
+ })
240
+ sample_idx += 1
241
+
242
+ print(f"Summary: Found {len(quadrics)} superquadrics")
243
+ for j, (sq, info) in enumerate(zip(quadrics, quadric_info)):
244
+ print(f" SQ{j+1}: Shape={sq.shape}, Scale={sq.scale}, "
245
+ f"Inliers={info['inlier_ratio']:.1%}")
246
+
247
+ except Exception as e:
248
+ print(f" Failed: {e}")
249
+
250
+ # 2. Also load normalized point cloud samples if they exist
251
+ normalized_dir = "/research/cbim/vast/sf895/code/EMS-superquadric_fitting/20250811_231035_step10_stage0_waves1"
252
+ if os.path.exists(normalized_dir):
253
+ print("\n--- Processing normalized point cloud samples ---")
254
+ for i in range(2): # Just load first 2 samples
255
+ sample_name = f"sample_{i}_normalized_points.npz"
256
+ sample_path = os.path.join(normalized_dir, sample_name)
257
+
258
+ if os.path.exists(sample_path):
259
+ print(f"\nProcessing {sample_name}...")
260
+ try:
261
+ # Load data
262
+ data = np.load(sample_path)
263
+ point_cloud = data['points'][0] # First frame
264
+
265
+ # Hierarchical multi-quadric fitting
266
+ quadrics, quadric_info = hierarchical_ems(point_cloud)
267
+
268
+ all_samples.append({
269
+ 'name': sample_name,
270
+ 'idx': sample_idx,
271
+ 'points': point_cloud,
272
+ 'quadrics': quadrics,
273
+ 'quadric_info': quadric_info
274
+ })
275
+ sample_idx += 1
276
+
277
+ print(f"Summary: Found {len(quadrics)} superquadrics")
278
+
279
+ except Exception as e:
280
+ print(f" Failed: {e}")
281
+
282
+ # Start viser server
283
+ server = viser.ViserServer(port=8080)
284
+ print(f"\n{'='*60}")
285
+ print(f"Viser server started at http://localhost:8080")
286
+ print("Open this URL in your browser to view the 3D visualization")
287
+ print("Press Ctrl+C to stop the server")
288
+ print('='*60)
289
+
290
+ # Colors for different superquadrics
291
+ quadric_colors = [
292
+ (255, 0, 0), # Red
293
+ (0, 255, 0), # Green
294
+ (0, 0, 255), # Blue
295
+ (255, 255, 0), # Yellow
296
+ (255, 0, 255), # Magenta
297
+ (0, 255, 255), # Cyan
298
+ ]
299
+
300
+ # Create GUI
301
+ with server.gui.add_folder("Controls"):
302
+ # Sample selector
303
+ sample_names = [s['name'] for s in all_samples if s['points'] is not None]
304
+ current_sample = server.gui.add_dropdown(
305
+ "Select Sample",
306
+ options=sample_names,
307
+ initial_value=sample_names[0] if sample_names else None
308
+ )
309
+
310
+ # Visibility controls
311
+ show_points = server.gui.add_checkbox("Show Points", initial_value=True)
312
+ show_all_quadrics = server.gui.add_checkbox("Show All Quadrics", initial_value=True)
313
+ show_labels = server.gui.add_checkbox("Show Labels", initial_value=True)
314
+
315
+ # Individual quadric toggles will be added dynamically
316
+ quadric_toggles_folder = server.gui.add_folder("Individual Quadrics")
317
+
318
+ # Visual parameters
319
+ point_size = server.gui.add_slider(
320
+ "Point Size",
321
+ min=0.001,
322
+ max=0.02,
323
+ step=0.001,
324
+ initial_value=0.003
325
+ )
326
+
327
+ mesh_opacity = server.gui.add_slider(
328
+ "Mesh Opacity",
329
+ min=0.0,
330
+ max=1.0,
331
+ step=0.1,
332
+ initial_value=0.5
333
+ )
334
+
335
+ # Info display
336
+ info_display = server.gui.add_markdown("**Sample Info:**\n\nSelect a sample to view")
337
+
338
+ # Store current visualization handles
339
+ current_viz = {
340
+ 'points': None,
341
+ 'meshes': [],
342
+ 'labels': [],
343
+ 'quadric_toggles': []
344
+ }
345
+
346
+ def update_scene():
347
+ """Update the 3D scene based on current selection"""
348
+ # Clear existing visualization
349
+ if current_viz['points'] is not None:
350
+ current_viz['points'].remove()
351
+ current_viz['points'] = None
352
+
353
+ for mesh in current_viz['meshes']:
354
+ mesh.remove()
355
+ current_viz['meshes'] = []
356
+
357
+ for label in current_viz['labels']:
358
+ label.remove()
359
+ current_viz['labels'] = []
360
+
361
+ # Clear quadric toggles
362
+ for toggle in current_viz['quadric_toggles']:
363
+ toggle.remove()
364
+ current_viz['quadric_toggles'] = []
365
+
366
+ # Find selected sample
367
+ selected = None
368
+ for sample in all_samples:
369
+ if sample['name'] == current_sample.value:
370
+ selected = sample
371
+ break
372
+
373
+ if selected is None or selected['points'] is None:
374
+ info_display.value = "**No valid sample selected**"
375
+ return
376
+
377
+ # Update info
378
+ info_text = f"**{selected['name']}**\n\n"
379
+ info_text += f"Total points: {len(selected['points'])}\n"
380
+ info_text += f"Superquadrics found: {len(selected['quadrics'])}\n\n"
381
+
382
+ if len(selected['quadrics']) > 0:
383
+ info_text += "**Superquadric Details:**\n"
384
+ for i, (sq, info) in enumerate(zip(selected['quadrics'], selected['quadric_info'])):
385
+ info_text += f"\n**SQ{i+1}** (Layer {info['layer']}):\n"
386
+ info_text += f"- Shape: ε₁={sq.shape[0]:.3f}, ε₂={sq.shape[1]:.3f}\n"
387
+ info_text += f"- Scale: ({sq.scale[0]:.2f}, {sq.scale[1]:.2f}, {sq.scale[2]:.2f})\n"
388
+ info_text += f"- Inliers: {info['inlier_ratio']:.1%} ({info['num_points']} points)\n"
389
+
390
+ info_display.value = info_text
391
+
392
+ # Add point cloud
393
+ if show_points.value:
394
+ current_viz['points'] = server.scene.add_point_cloud(
395
+ "/current/points",
396
+ points=selected['points'],
397
+ colors=np.array([(128, 128, 128)] * len(selected['points']), dtype=np.uint8),
398
+ point_size=point_size.value,
399
+ )
400
+
401
+ # Add individual quadric toggles
402
+ with quadric_toggles_folder:
403
+ for i in range(len(selected['quadrics'])):
404
+ toggle = server.gui.add_checkbox(
405
+ f"Quadric {i+1}",
406
+ initial_value=True
407
+ )
408
+ current_viz['quadric_toggles'].append(toggle)
409
+
410
+ # Add superquadrics
411
+ for i, (sq, info) in enumerate(zip(selected['quadrics'], selected['quadric_info'])):
412
+ color = quadric_colors[i % len(quadric_colors)]
413
+
414
+ # Check if this quadric should be shown
415
+ show_this = show_all_quadrics.value
416
+ if i < len(current_viz['quadric_toggles']):
417
+ show_this = show_this and current_viz['quadric_toggles'][i].value
418
+
419
+ if show_this:
420
+ try:
421
+ vertices, faces = generate_superquadric_mesh(sq, num_samples=20)
422
+
423
+ mesh = server.scene.add_mesh_simple(
424
+ f"/current/mesh_{i}",
425
+ vertices=vertices,
426
+ faces=faces,
427
+ color=color,
428
+ opacity=mesh_opacity.value,
429
+ )
430
+ current_viz['meshes'].append(mesh)
431
+
432
+ if show_labels.value:
433
+ label = server.scene.add_label(
434
+ f"/current/label_{i}",
435
+ text=f"SQ{i+1}: ε₁={sq.shape[0]:.2f}, ε₂={sq.shape[1]:.2f}",
436
+ position=sq.translation,
437
+ )
438
+ current_viz['labels'].append(label)
439
+
440
+ except Exception as e:
441
+ print(f"Error visualizing quadric {i}: {e}")
442
+
443
+ # Set up callbacks
444
+ @current_sample.on_update
445
+ def _(_):
446
+ update_scene()
447
+
448
+ @show_points.on_update
449
+ def _(_):
450
+ if current_viz['points'] is not None:
451
+ current_viz['points'].visible = show_points.value
452
+
453
+ @show_all_quadrics.on_update
454
+ def _(_):
455
+ for mesh in current_viz['meshes']:
456
+ mesh.visible = show_all_quadrics.value
457
+
458
+ @show_labels.on_update
459
+ def _(_):
460
+ for label in current_viz['labels']:
461
+ label.visible = show_labels.value
462
+
463
+ @point_size.on_update
464
+ def _(event):
465
+ if current_viz['points'] is not None:
466
+ current_viz['points'].point_size = event.target.value
467
+
468
+ @mesh_opacity.on_update
469
+ def _(event):
470
+ for mesh in current_viz['meshes']:
471
+ mesh.opacity = event.target.value
472
+
473
+ # Initial scene
474
+ update_scene()
475
+
476
+ # Keep server running
477
+ try:
478
+ while True:
479
+ time.sleep(0.1)
480
+ except KeyboardInterrupt:
481
+ print("\nShutting down server...")
482
+ server.stop()
483
+
484
+
485
+ if __name__ == "__main__":
486
+ main()
EMS-superquadric_fitting_inference/process_viser_single.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple viser visualization for superquadric fitting results
4
+ """
5
+
6
+ import numpy as np
7
+ import sys
8
+ import os
9
+ import time
10
+ import viser
11
+
12
+ # Add the src directory to Python path
13
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
14
+
15
+ from EMS.EMS_recovery import EMS_recovery
16
+ from sklearn.cluster import DBSCAN
17
+
18
+
19
+ def generate_superquadric_mesh(sq, num_samples=30):
20
+ """Generate mesh vertices and faces for superquadric surface"""
21
+ eta = np.linspace(-np.pi/2, np.pi/2, num_samples)
22
+ omega = np.linspace(-np.pi, np.pi, num_samples)
23
+
24
+ vertices = []
25
+ faces = []
26
+
27
+ # Generate vertices
28
+ for i, e in enumerate(eta):
29
+ for j, w in enumerate(omega):
30
+ # Superquadric parametric equations
31
+ cos_eta = np.sign(np.cos(e)) * np.abs(np.cos(e))**sq.shape[0]
32
+ sin_eta = np.sign(np.sin(e)) * np.abs(np.sin(e))**sq.shape[0]
33
+ cos_omega = np.sign(np.cos(w)) * np.abs(np.cos(w))**sq.shape[1]
34
+ sin_omega = np.sign(np.sin(w)) * np.abs(np.sin(w))**sq.shape[1]
35
+
36
+ # Local coordinates
37
+ x_local = sq.scale[0] * cos_eta * cos_omega
38
+ y_local = sq.scale[1] * cos_eta * sin_omega
39
+ z_local = sq.scale[2] * sin_eta
40
+
41
+ # Apply rotation and translation
42
+ point_local = np.array([x_local, y_local, z_local])
43
+ point_global = sq.RotM @ point_local + sq.translation
44
+
45
+ vertices.append(point_global)
46
+
47
+ vertices = np.array(vertices)
48
+
49
+ # Generate faces (triangles)
50
+ for i in range(num_samples - 1):
51
+ for j in range(num_samples - 1):
52
+ # Current vertex indices
53
+ idx1 = i * num_samples + j
54
+ idx2 = i * num_samples + (j + 1) % num_samples
55
+ idx3 = (i + 1) * num_samples + j
56
+ idx4 = (i + 1) * num_samples + (j + 1) % num_samples
57
+
58
+ # Two triangles per quad
59
+ faces.append([idx1, idx2, idx3])
60
+ faces.append([idx2, idx4, idx3])
61
+
62
+ return vertices, np.array(faces)
63
+
64
+
65
+ def main():
66
+ # Base directory with samples
67
+ base_dir = "/research/cbim/vast/sf895/code/EMS-superquadric_fitting/20250811_231035_step10_stage0_waves1"
68
+
69
+ # Process samples
70
+ all_samples = []
71
+
72
+ print("Loading and processing samples...")
73
+ for i in range(5): # Process samples 0-4
74
+ sample_name = f"sample_{i}_normalized_points.npz"
75
+ sample_path = os.path.join(base_dir, sample_name)
76
+
77
+ if os.path.exists(sample_path):
78
+ try:
79
+ # Load data
80
+ data = np.load(sample_path)
81
+ point_cloud = data['points'][0] # First frame
82
+
83
+ # Try to fit superquadric
84
+ print(f"\nProcessing {sample_name}...")
85
+ sq, p = EMS_recovery(point_cloud, OutlierRatio=0.2, AdaptiveUpperBound=True)
86
+
87
+ all_samples.append({
88
+ 'name': sample_name,
89
+ 'idx': i,
90
+ 'points': point_cloud,
91
+ 'quadric': sq,
92
+ 'probs': p
93
+ })
94
+ print(f" Success! Shape: {sq.shape}, Scale: {sq.scale}")
95
+
96
+ except Exception as e:
97
+ print(f" Failed: {e}")
98
+ all_samples.append({
99
+ 'name': sample_name,
100
+ 'idx': i,
101
+ 'points': point_cloud if 'point_cloud' in locals() else None,
102
+ 'quadric': None,
103
+ 'probs': None
104
+ })
105
+
106
+ # Start viser server
107
+ server = viser.ViserServer(port=8080)
108
+ print(f"\n{'='*60}")
109
+ print(f"Viser server started at http://localhost:8080")
110
+ print("Open this URL in your browser to view the 3D visualization")
111
+ print("Press Ctrl+C to stop the server")
112
+ print('='*60)
113
+
114
+ # Colors for different samples
115
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
116
+
117
+ # Create GUI
118
+ with server.gui.add_folder("Controls"):
119
+ # Sample selector
120
+ sample_names = [s['name'] for s in all_samples if s['points'] is not None]
121
+ current_sample = server.gui.add_dropdown(
122
+ "Select Sample",
123
+ options=sample_names,
124
+ initial_value=sample_names[0] if sample_names else None
125
+ )
126
+
127
+ # Visibility controls
128
+ show_points = server.gui.add_checkbox("Show Points", initial_value=True)
129
+ show_quadric = server.gui.add_checkbox("Show Quadric", initial_value=True)
130
+
131
+ # Visual parameters
132
+ point_size = server.gui.add_slider(
133
+ "Point Size",
134
+ min=0.001,
135
+ max=0.02,
136
+ step=0.001,
137
+ initial_value=0.005
138
+ )
139
+
140
+ mesh_opacity = server.gui.add_slider(
141
+ "Mesh Opacity",
142
+ min=0.0,
143
+ max=1.0,
144
+ step=0.1,
145
+ initial_value=0.5
146
+ )
147
+
148
+ # Info display
149
+ info_display = server.gui.add_markdown("**Sample Info:**\n\nSelect a sample to view")
150
+
151
+ # Store current visualization handles
152
+ current_viz = {'points': None, 'mesh': None, 'label': None}
153
+
154
+ def update_scene():
155
+ """Update the 3D scene based on current selection"""
156
+ # Clear existing visualization
157
+ if current_viz['points'] is not None:
158
+ current_viz['points'].remove()
159
+ current_viz['points'] = None
160
+ if current_viz['mesh'] is not None:
161
+ current_viz['mesh'].remove()
162
+ current_viz['mesh'] = None
163
+ if current_viz['label'] is not None:
164
+ current_viz['label'].remove()
165
+ current_viz['label'] = None
166
+
167
+ # Find selected sample
168
+ selected = None
169
+ for sample in all_samples:
170
+ if sample['name'] == current_sample.value:
171
+ selected = sample
172
+ break
173
+
174
+ if selected is None or selected['points'] is None:
175
+ info_display.value = "**No valid sample selected**"
176
+ return
177
+
178
+ # Update info
179
+ info_text = f"**{selected['name']}**\n\n"
180
+ info_text += f"Points: {len(selected['points'])}\n\n"
181
+
182
+ if selected['quadric'] is not None:
183
+ sq = selected['quadric']
184
+ info_text += f"**Superquadric Parameters:**\n"
185
+ info_text += f"- Shape: ε₁={sq.shape[0]:.3f}, ε₂={sq.shape[1]:.3f}\n"
186
+ info_text += f"- Scale: ({sq.scale[0]:.2f}, {sq.scale[1]:.2f}, {sq.scale[2]:.2f})\n"
187
+ info_text += f"- Translation: ({sq.translation[0]:.2f}, {sq.translation[1]:.2f}, {sq.translation[2]:.2f})\n"
188
+ else:
189
+ info_text += "**No superquadric fitted**"
190
+
191
+ info_display.value = info_text
192
+
193
+ color = colors[selected['idx'] % len(colors)]
194
+
195
+ # Add point cloud
196
+ if show_points.value:
197
+ current_viz['points'] = server.scene.add_point_cloud(
198
+ "/current/points",
199
+ points=selected['points'],
200
+ colors=np.array([color] * len(selected['points']), dtype=np.uint8),
201
+ point_size=point_size.value,
202
+ )
203
+
204
+ # Add superquadric
205
+ if show_quadric.value and selected['quadric'] is not None:
206
+ try:
207
+ vertices, faces = generate_superquadric_mesh(selected['quadric'], num_samples=25)
208
+
209
+ current_viz['mesh'] = server.scene.add_mesh_simple(
210
+ "/current/mesh",
211
+ vertices=vertices,
212
+ faces=faces,
213
+ color=color,
214
+ opacity=mesh_opacity.value,
215
+ )
216
+
217
+ # Add label
218
+ current_viz['label'] = server.scene.add_label(
219
+ "/current/label",
220
+ text=f"ε₁={selected['quadric'].shape[0]:.2f}, ε₂={selected['quadric'].shape[1]:.2f}",
221
+ position=selected['quadric'].translation,
222
+ )
223
+
224
+ except Exception as e:
225
+ print(f"Error visualizing quadric: {e}")
226
+
227
+ # Set up callbacks
228
+ @current_sample.on_update
229
+ def _(_):
230
+ update_scene()
231
+
232
+ @show_points.on_update
233
+ def _(_):
234
+ update_scene()
235
+
236
+ @show_quadric.on_update
237
+ def _(_):
238
+ update_scene()
239
+
240
+ @point_size.on_update
241
+ def _(event):
242
+ if current_viz['points'] is not None:
243
+ current_viz['points'].point_size = event.target.value
244
+
245
+ @mesh_opacity.on_update
246
+ def _(event):
247
+ if current_viz['mesh'] is not None:
248
+ current_viz['mesh'].opacity = event.target.value
249
+
250
+ # Initial scene
251
+ update_scene()
252
+
253
+ # Keep server running
254
+ try:
255
+ while True:
256
+ time.sleep(0.1)
257
+ except KeyboardInterrupt:
258
+ print("\nShutting down server...")
259
+ server.stop()
260
+
261
+
262
+ if __name__ == "__main__":
263
+ main()
EMS-superquadric_fitting_inference/pyproject.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=42"]
3
+ build-backend = "setuptools.build_meta"
EMS-superquadric_fitting_inference/setup.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import setuptools
2
+
3
+ with open("README.md", "r", encoding="utf-8") as fh:
4
+ long_description = fh.read()
5
+
6
+ setuptools.setup(
7
+ name='EMS',
8
+ version='0.0.1',
9
+ description='EMS: a package for probabilistic recovery of superquadrics from point clouds',
10
+ url='https://github.com/bmlklwx/EMS-probabilistic_superquadric_fitting.git',
11
+ author='Weixiao Liu, Yuwei Wu, Sipu Ruan, Gregory Chirikjian',
12
+ author_email='wliu72@jhu.edu',
13
+ long_description=long_description,
14
+ long_description_content_type="text/markdown",
15
+
16
+ install_requires=[
17
+ 'numpy',
18
+ 'scipy',
19
+ 'plyfile',
20
+ 'mayavi',
21
+ 'numba'
22
+ ],
23
+ classifiers=[
24
+ "Programming Language :: Python :: 3",
25
+ "License :: OSI Approved :: MIT License",
26
+ "Operating System :: OS Independent",
27
+ ],
28
+
29
+ package_dir={"": "src"},
30
+ packages=setuptools.find_packages(where="src"),
31
+ python_requires='>=3.6'
32
+ )
EMS-superquadric_fitting_inference/src/EMS/EMS_recovery.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from numba import njit
3
+ from scipy.optimize import least_squares
4
+
5
+ from EMS.superquadrics import rotations, superquadric
6
+
7
+ def EMS_recovery(
8
+ point, OutlierRatio=0.1, MaxIterationEM=20,
9
+ ToleranceEM=1e-3, RelativeToleranceEM=1e-1,
10
+ MaxOptiIterations=3, Sigma=0, MaxiSwitch=2,
11
+ AdaptiveUpperBound=False, Rescale=True):
12
+ # The function conducting probabilistic superquadric recovery.
13
+ # Input: point - point cloud np array of N * 3
14
+ #
15
+
16
+ # ---------------------------------------INITIALIZATIONS--------------------------------------------
17
+ # translate the points to the center of mass
18
+ point = np.array(point, dtype=float)
19
+ t0 = np.mean(point, 0)
20
+ point = point - t0
21
+
22
+ # rescale
23
+ if Rescale is True:
24
+ max_length = np.max(point)
25
+ scale = max_length / 10
26
+ point = point / scale
27
+
28
+ # eigen analysis for rotation initialization
29
+ EigVec = EigenAnalysis(point)
30
+ R0 = rotations()
31
+ R0.RotM = np.array([-EigVec[:, 0], -EigVec[:, 2],
32
+ np.cross(EigVec[:, 0], EigVec[:, 2])]).T
33
+ euler0 = R0.euler
34
+
35
+ # scale initialization
36
+ point_rot0 = point @ R0.RotM
37
+ s0 = np.median(np.abs(point_rot0), 0)
38
+
39
+ # initialize configuration
40
+ x0 = np.array([1.0, 1.0, s0[0], s0[1], s0[2],
41
+ euler0[0], euler0[1], euler0[2], 0, 0, 0])
42
+
43
+ # set lower and upper bounds for the superquadrics
44
+ upper = 4 * np.max(np.abs(point))
45
+ lb = np.array([0, 0, 0.001, 0.001, 0.001, -2 * np.pi, -2 *
46
+ np.pi, -2 * np.pi, -upper, -upper, -upper])
47
+ ub = np.array([2.0, 2.0, upper, upper, upper, 2 * np.pi,
48
+ 2 * np.pi, 2 * np.pi, upper, upper, upper])
49
+
50
+ # calculate bounding volume of outlier space
51
+ V = BoundVolume(point_rot0)
52
+
53
+ # set prior outlier density
54
+ p0 = 1 / V
55
+
56
+ # initialize variance
57
+ if Sigma == 0:
58
+ sigma2 = V ** (1 / 3) / 10
59
+ else:
60
+ sigma2 = Sigma
61
+
62
+ # initialize EMS
63
+ x = x0
64
+ cost = 0.0
65
+ num_switch = int(0)
66
+ p = np.ones(point.shape[0])
67
+
68
+ # ---------------------------------------EMS ALGORITHM--------------------------------------------
69
+ for iterEM in range(MaxIterationEM):
70
+ # evaluating distance from points to superquadric
71
+ dist = Distance(point, x)
72
+
73
+ # inferring the posterior outlier probability (E-step)
74
+ if OutlierRatio != 0:
75
+ p = OutlierProb(dist, sigma2, OutlierRatio, p0)
76
+
77
+ # calculate adaptive upper bound
78
+ if AdaptiveUpperBound is True:
79
+ R_cur = Euler2RotM(x[5: 8])
80
+ point_cur = point @ R_cur - x[8: 11] @ R_cur
81
+ ub_a = 1.1 * np.max(np.abs(point_cur), 0)
82
+ ub[2: 5] = ub_a
83
+ ub[8: 11] = ub_a
84
+ lb[8: 11] = -ub_a
85
+
86
+ # Optimize the superquadric configuration (M-step)
87
+ optfunc = least_squares(CostFunc, x, bounds=(
88
+ lb, ub), max_nfev=MaxOptiIterations, args=(point, p, sigma2))
89
+ x_n = optfunc.x
90
+ cost_n = 2 * optfunc.cost
91
+
92
+ # update sigma
93
+ sigma2_n = cost_n / (3 * np.sum(p))
94
+
95
+ # evaluate relative decrease of cost
96
+ relative_cost = (cost - cost_n) / cost_n
97
+
98
+ # check optimality for termination
99
+ if (cost_n < ToleranceEM and iterEM > 0) or \
100
+ (relative_cost < RelativeToleranceEM and num_switch >= MaxiSwitch and iterEM > 4):
101
+ x = x_n
102
+ break
103
+
104
+ # check for entering similarity switch
105
+ if relative_cost < RelativeToleranceEM and iterEM > 0:
106
+ # entering similarity switch (S-step)
107
+ # initialize switch success flag
108
+ switch_success = False
109
+
110
+ # search for similarity candidates
111
+ x_candidate = SimilarityCandidates(x)
112
+
113
+ # evaluating switch (S-step)
114
+ x, cost, sigma2, switch_success = Switch(
115
+ x_candidate, point, p, AdaptiveUpperBound, ub, lb, MaxOptiIterations, \
116
+ sigma2, sigma2_n, cost, cost_n, x_n, switch_success
117
+ )
118
+
119
+ num_switch = num_switch + 1
120
+
121
+ else:
122
+ # update parameter and prepare for the next EM iteration
123
+ cost = cost_n
124
+ sigma2 = sigma2_n
125
+ x = x_n
126
+
127
+ if Rescale is True:
128
+ x[2 : 5] = x[2 : 5] * scale
129
+ x[8 : 11] = x[8 : 11] * scale
130
+
131
+ x[8 : 11] = x[8 : 11] + t0
132
+
133
+ sq = superquadric(x[0 : 2], x[2 : 5], x[5 : 8], x[8 : 11])
134
+
135
+ return sq, p
136
+
137
+ # ---------------------------------------UTILITIES-------------------------------------------
138
@njit(cache=True)
def SimilarityCandidates(x):
    # Build alternative superquadric parameterizations that describe (nearly)
    # the same surface as x = [eps1, eps2, a1, a2, a3, euler_z, euler_y,
    # euler_x, t1, t2, t3], so the S-step can escape local minima caused by
    # the ambiguity of the representation.

    # axis mismatch similarity: the same shape re-expressed with the body
    # axes cyclically permuted (exponents and scales permuted accordingly)
    axis_0 = Euler2RotM(x[5: 8])
    axis_1 = axis_0[:, np.array([1, 2, 0])]   # axes rolled by one
    axis_2 = axis_0[:, np.array([2, 0, 1])]   # axes rolled by two
    eul_1 = RotM2Euler(axis_1)
    eul_2 = RotM2Euler(axis_2)
    x_axis = np.array(
        [[x[1], x[0], x[3], x[4], x[2], eul_1[0], eul_1[1], eul_1[2], x[8], x[9], x[10]],
         [x[1], x[0], x[4], x[2], x[3], eul_2[0], eul_2[1], eul_2[2], x[8], x[9], x[10]]]
    )

    # duality similarities: when two semi-axes are close in length
    # (ratio within (0.6, 1.4)), a 45-degree-rotated dual shape with the
    # complementary exponent (2 - eps) approximates the same surface
    scale_ratio = x[np.array([3, 4, 2])] / x[2 : 5]
    scale_idx = np.argwhere(np.logical_and(scale_ratio > 0.6, scale_ratio < 1.4))
    x_rot = np.zeros((scale_idx.shape[0], 11))

    for idx in range(scale_idx.shape[0]):
        if scale_idx[idx, 0] == 0:
            # a1 ~ a2: rotate 45 degrees about the original z-axis
            eul_rot = RotM2Euler(axis_0 @ Euler2RotM(np.array([np.pi / 4, 0.0, 0.0])))
            if x[1] <= 1:
                # eps2 <= 1: dual scale interpolates between sqrt(2) and 1
                x_rot[idx, :] = np.array(
                    [x[0], 2 - x[1],
                     ((1 - np.sqrt(2)) * x[1] + np.sqrt(2)) * min(x[2], x[3]),
                     ((1 - np.sqrt(2)) * x[1] + np.sqrt(2)) * min(x[2], x[3]),
                     x[4], eul_rot[0], eul_rot[1], eul_rot[2],
                     x[8], x[9], x[10]]
                )
            else:
                # eps2 > 1: dual scale interpolates between 1 and sqrt(2)/2
                x_rot[idx, :] = np.array(
                    [x[0], 2 - x[1],
                     ((np.sqrt(2)/2 - 1) * x[1] + 2 - np.sqrt(2) / 2) * min(x[2], x[3]),
                     ((np.sqrt(2)/2 - 1) * x[1] + 2 - np.sqrt(2) / 2) * min(x[2], x[3]),
                     x[4], eul_rot[0], eul_rot[1], eul_rot[2],
                     x[8], x[9], x[10]]
                )

        elif scale_idx[idx, 0] == 1:
            # a2 ~ a3: rotate 45 degrees about the permuted (axis_1) z-axis
            eul_rot = RotM2Euler(axis_1 @ Euler2RotM(np.array([np.pi / 4, 0.0, 0.0])))
            if x[0] <= 1:
                x_rot[idx, :] = np.array(
                    [x[1], 2 - x[0],
                     ((1 - np.sqrt(2)) * x[0] + np.sqrt(2)) * min(x[3], x[4]),
                     ((1 - np.sqrt(2)) * x[0] + np.sqrt(2)) * min(x[3], x[4]),
                     x[2], eul_rot[0], eul_rot[1], eul_rot[2],
                     x[8], x[9], x[10]]
                )
            else:
                x_rot[idx, :] = np.array(
                    [x[1], 2 - x[0],
                     ((np.sqrt(2)/2 - 1) * x[0] + 2 - np.sqrt(2)/2) * min(x[3], x[4]),
                     ((np.sqrt(2)/2 - 1) * x[0] + 2 - np.sqrt(2)/2) * min(x[3], x[4]),
                     x[2], eul_rot[0], eul_rot[1], eul_rot[2],
                     x[8], x[9], x[10]]
                )

        elif scale_idx[idx, 0] == 2:
            # a3 ~ a1: rotate 45 degrees about the permuted (axis_2) z-axis
            # NOTE(review): branches 0 and 1 keep the unmerged semi-axis
            # (x[4] and x[2] respectively); by that pattern this branch's
            # retained axis should be x[3], yet x[2] is used below — confirm
            # against the upstream EMS implementation before changing.
            eul_rot = RotM2Euler(axis_2 @ Euler2RotM(np.array([np.pi / 4, 0.0, 0.0])))
            if x[0] <= 1:
                x_rot[idx, :] = np.array(
                    [x[1], 2 - x[0],
                     ((1 - np.sqrt(2)) * x[0] + np.sqrt(2)) * min(x[4], x[2]),
                     ((1 - np.sqrt(2)) * x[0] + np.sqrt(2)) * min(x[4], x[2]),
                     x[2], eul_rot[0], eul_rot[1], eul_rot[2],
                     x[8], x[9], x[10]]
                )
            else:
                x_rot[idx, :] = np.array(
                    [x[1], 2 - x[0],
                     ((np.sqrt(2)/2 - 1) * x[0] + 2 - np.sqrt(2)/2) * min(x[4], x[2]),
                     ((np.sqrt(2)/2 - 1) * x[0] + 2 - np.sqrt(2)/2) * min(x[4], x[2]),
                     x[2], eul_rot[0], eul_rot[1], eul_rot[2],
                     x[8], x[9], x[10]]
                )

    # stack: 2 axis-permutation candidates first, then any duality candidates
    x_candidate = np.zeros((2 + x_rot.shape[0], 11))
    x_candidate[0 : 2] = x_axis
    if scale_idx.shape[0] > 0:
        x_candidate[2 : 2 + scale_idx.shape[0]] = x_rot

    return x_candidate
220
+
221
def Switch(
    x_candidate, point, p, AdaptiveUpperBound, ub, lb, MaxOptiIterations, \
    sigma2, sigma2_n, cost, cost_n, x_n, switch_success
):
    """Similarity switch (S-step): try candidate reparameterizations.

    Evaluates each candidate's weighted cost, re-optimizes the cheapest
    valid candidates with least_squares, and accepts the first one that
    beats both the current cost and the fresh EM cost `cost_n`. If none
    improves, falls back to the plain EM update (x_n, cost_n, sigma2_n).

    Note: when AdaptiveUpperBound is True, `ub` and `lb` are updated in
    place, mirroring the bound adaptation done in the EM loop.

    Returns:
        (x, cost, sigma2, switch_success)
    """
    cost_candidate = SwitchCost(x_candidate, point, p)

    # keep only candidates whose cost is finite
    idx_valid = np.argwhere(
        np.logical_and(~np.isnan(cost_candidate), ~np.isinf(cost_candidate))
    ).reshape(1, -1)[0]
    cost_candidate = cost_candidate[idx_valid]

    # Visit candidates from cheapest to most expensive. Map the sort order
    # of the FILTERED costs back to rows of the ORIGINAL candidate array;
    # the previous code indexed the unfiltered array with filtered-array
    # positions, selecting wrong candidates whenever a NaN/Inf was dropped.
    for i in idx_valid[np.argsort(cost_candidate)]:
        if AdaptiveUpperBound is True:
            # tighten scale/translation bounds around the candidate's frame
            R_cur = Euler2RotM(x_candidate[i, 5: 8])
            point_cur = point @ R_cur - x_candidate[i, 8: 11] @ R_cur
            ub_a = 1.1 * np.max(np.abs(point_cur), 0)
            ub[2: 5] = ub_a
            ub[8: 11] = ub_a
            lb[8: 11] = -ub_a

        # clamp the candidate into the (possibly updated) bounds so that
        # least_squares starts from a feasible point
        x_candidate[i] = np.minimum(x_candidate[i], ub)
        x_candidate[i] = np.maximum(x_candidate[i], lb)

        optfunc = least_squares(CostFunc, x_candidate[i], bounds=(
            lb, ub), max_nfev=MaxOptiIterations, args=(point, p, sigma2))
        x_switch = optfunc.x
        cost_switch = 2 * optfunc.cost

        # accept the first candidate beating both current and EM costs
        if cost_switch < min(cost_n, cost):
            x = x_switch
            cost = cost_switch
            # keep the sigma2 update consistent with the EM loop (np.sum)
            sigma2 = cost_switch / (3 * np.sum(p))
            switch_success = True
            break

    if not switch_success:
        # no candidate improved: fall back to the plain EM update
        cost = cost_n
        sigma2 = sigma2_n
        x = x_n

    return x, cost, sigma2, switch_success
264
+
265
@njit(cache=True)
def SwitchCost(x_candidate, point, p):
    # Weighted squared-distance cost of every candidate configuration:
    # cost_k = sum_j p_j * Distance(point_j, candidate_k)^2
    n_cand = x_candidate.shape[0]
    costs = np.zeros(n_cand)
    for k in range(n_cand):
        d = Distance(point, x_candidate[k])
        costs[k] = np.sum(p * (d ** 2))
    return costs
271
+
272
@njit(cache=True)
def EigenAnalysis(point):
    # Principal axes of the (assumed centered) point cloud: eigenvectors of
    # the second-moment matrix, ordered by descending eigenvalue.
    moment = point.T @ point / point.shape[0]
    eig_val, eig_vec = np.linalg.eig(moment)
    order = np.argsort(eig_val)[::-1]
    return eig_vec[:, order]
278
+
279
@njit(cache=True)
def BoundVolume(point):
    # Volume of the axis-aligned bounding box of the point cloud.
    volume = 1.0
    for axis in range(3):
        coords = point[:, axis]
        volume *= np.max(coords) - np.min(coords)
    return volume
285
+
286
@njit(cache=True)
def Distance(point, x):
    # Approximate distance of each point to the superquadric surface defined
    # by x = [eps1, eps2, a1, a2, a3, euler(3), translation(3)]: the radial
    # distance to the origin, scaled by the implicit-function value.

    # pose of the superquadric
    rot = Euler2RotM(x[5: 8])
    trans = x[8: 11]

    # express the points in the superquadric's canonical frame
    canonical = point @ rot - trans @ rot

    # radial distance of each canonical point from the origin
    radial = np.sqrt(np.sum(canonical ** 2, 1))

    # inside-outside function of the superquadric, evaluated per point
    implicit = ((((canonical[:, 0] / x[2]) ** 2) ** (1 / x[1]) +
                 ((canonical[:, 1] / x[3]) ** 2) ** (1 / x[1])) ** (x[1] / x[0]) +
                ((canonical[:, 2] / x[4]) ** 2) ** (1 / x[0]))

    return radial * np.abs(implicit ** (-x[0] / 2) - 1)
306
+
307
@njit(cache=True)
def CostFunc(x, point, p, sigma2):
    # Residual vector for scipy.optimize.least_squares in the M-step:
    # per-point square-root of the weighted squared distance, optionally
    # with a surface-area term.
    d = Distance(point, x)
    if sigma2 > 1e-10:
        # plain weighted distance residuals
        return p ** 0.5 * d
    # NOTE(review): the area-regularized form is only taken for tiny sigma2,
    # where 2*sigma2*log(area) is near zero — confirm the gate's intent.
    return np.abs((p * d ** 2 + 2 *
                   sigma2 * np.log(SurfaceArea(x)))) ** 0.5
315
+
316
@njit(cache=True)
def OutlierProb(dist, sigma2, w, p0):
    # Posterior inlier probability for each point under a Gaussian inlier
    # model (variance sigma2) mixed with a uniform outlier component of
    # weight w and density p0.
    gauss_norm = (2 * np.pi * sigma2) ** (- 3 / 2)
    outlier_ratio = (w * p0) / (gauss_norm * (1 - w))
    likelihood = np.exp(-1 / (2 * sigma2) * dist ** 2)
    return likelihood / (outlier_ratio + likelihood)
323
+
324
@njit(cache=True)
def SurfaceArea(x):
    # Approximate superquadric surface area by bilinear interpolation, in
    # the exponents (eps1, eps2), between four closed-form limit areas.

    # corner areas: a00 is the cuboid limit; a22 uses Heron's formula on
    # the triangle with the three diagonal edge lengths
    a00 = 8 * (x[2] * x[3] + x[3] * x[4] + x[2] * x[4])
    a02 = 8 * (x[2] ** 2 + x[3] ** 2) ** 0.5 * x[4] + 4 * x[2] * x[3]
    a20 = 4 * (x[2] * (x[3] ** 2 + x[4] ** 2) ** 0.5 +
               x[3] * (x[2] ** 2 + x[4] ** 2) ** 0.5)
    a = (x[2] ** 2 + x[3] ** 2) ** 0.5
    b = (x[3] ** 2 + x[4] ** 2) ** 0.5
    c = (x[2] ** 2 + x[4] ** 2) ** 0.5
    s = (a + b + c) / 2
    a22 = 8 * (s * (s - a) * (s - b) * (s - c)) ** (1/2)

    # bilinear weights in the two shape exponents (scalar form of the
    # original 1x2 @ 2x2 @ 2x1 matrix product, same operation order)
    w0 = 1 - x[0] / 2
    w1 = x[0] / 2
    v0 = 1 - x[1] / 2
    v1 = x[1] / 2
    return (w0 * a00 + w1 * a20) * v0 + (w0 * a02 + w1 * a22) * v1
338
+
339
@njit(cache=True)
def Euler2RotM(euler):
    # Rotation matrix from intrinsic Z-Y-X Euler angles, euler = [z, y, x].
    cz = np.cos(euler[0])
    sz = np.sin(euler[0])
    cy = np.cos(euler[1])
    sy = np.sin(euler[1])
    cx = np.cos(euler[2])
    sx = np.sin(euler[2])

    rot_z = np.array(
        [[cz, -sz, 0.0],
         [sz, cz, 0.0],
         [0.0, 0.0, 1.0]]
    )
    rot_y = np.array(
        [[cy, 0.0, sy],
         [0.0, 1.0, 0.0],
         [-sy, 0.0, cy]]
    )
    rot_x = np.array(
        [[1.0, 0.0, 0.0],
         [0.0, cx, -sx],
         [0.0, sx, cx]]
    )

    return rot_z @ rot_y @ rot_x
362
+
363
@njit(cache=True)
def RotM2Euler(R):
    # Recover intrinsic Z-Y-X Euler angles [z, y, x] from rotation matrix R.
    # (Inverse of Euler2RotM, up to the usual Euler-angle ambiguities.)
    sy = np.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0])

    if sy < 1e-6:
        # gimbal lock: pitch near +/-90 degrees, yaw pinned to zero
        ang_x = np.arctan2(-R[1, 2], R[1, 1])
        ang_y = np.arctan2(-R[2, 0], sy)
        ang_z = 0.0
    else:
        ang_x = np.arctan2(R[2, 1], R[2, 2])
        ang_y = np.arctan2(-R[2, 0], sy)
        ang_z = np.arctan2(R[1, 0], R[0, 0])

    return np.array([ang_z, ang_y, ang_x])
EMS-superquadric_fitting_inference/src/EMS/__init__.py ADDED
File without changes
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.BoundVolume-279.py311.1.nbc ADDED
Binary file (44.5 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.BoundVolume-279.py311.nbi ADDED
Binary file (1.59 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.BoundVolume-279.py312.1.nbc ADDED
Binary file (46.5 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.BoundVolume-279.py312.nbi ADDED
Binary file (1.33 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py311.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:575177a225271df742d51f983c9267f0ff29006fac837a569792fecf97e09396
3
+ size 268721
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py311.nbi ADDED
Binary file (1.66 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py312.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd12b1b67c67b3ee8c3af1db109a8a96a6071a0bdbbbf051dabbe3fc95e89025
3
+ size 195624
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.CostFunc-307.py312.nbi ADDED
Binary file (1.39 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py311.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a38d3fe9a0f6a6853e9ab94bb66c3319aa0f2bdcf9b1995165f5ef8fc5d3a676
3
+ size 218890
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py311.nbi ADDED
Binary file (1.65 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py312.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49a3e75f660f91481c1410a7388b106a817a348d2ca3093c9635cf93a4dc7b59
3
+ size 175972
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Distance-286.py312.nbi ADDED
Binary file (1.39 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py311.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:350b381765ffc2a6d02a6380dec17deb4fbf73636116b9208bd800ffef369dfb
3
+ size 136198
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py311.nbi ADDED
Binary file (1.59 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py312.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:748f9d90b283eaf074de14248d68a842979537c268b1a7277597b620ee161f71
3
+ size 147907
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.EigenAnalysis-272.py312.nbi ADDED
Binary file (1.33 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Euler2RotM-339.py311.1.nbc ADDED
Binary file (88.4 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Euler2RotM-339.py311.nbi ADDED
Binary file (1.59 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Euler2RotM-339.py312.1.nbc ADDED
Binary file (91.6 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.Euler2RotM-339.py312.nbi ADDED
Binary file (1.33 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.OutlierProb-316.py311.1.nbc ADDED
Binary file (36.3 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.OutlierProb-316.py311.nbi ADDED
Binary file (1.6 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.OutlierProb-316.py312.1.nbc ADDED
Binary file (34.2 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.OutlierProb-316.py312.nbi ADDED
Binary file (1.34 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.RotM2Euler-363.py311.1.nbc ADDED
Binary file (38 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.RotM2Euler-363.py311.nbi ADDED
Binary file (1.59 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.RotM2Euler-363.py312.1.nbc ADDED
Binary file (38.9 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.RotM2Euler-363.py312.nbi ADDED
Binary file (1.33 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py311.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2451c48b493ecdcf4fe2b870a7dc7f96e4160c6011eed929c6e8fff5a480353
3
+ size 488688
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py311.nbi ADDED
Binary file (1.6 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py312.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05a72a147c53aa27bd81fc144108c22d9fb3aa30221ee4c3f9c069116d0a5d0e
3
+ size 454977
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SimilarityCandidates-138.py312.nbi ADDED
Binary file (1.34 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SurfaceArea-324.py311.1.nbc ADDED
Binary file (84.4 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SurfaceArea-324.py311.nbi ADDED
Binary file (1.59 kB). View file
 
EMS-superquadric_fitting_inference/src/EMS/__pycache__/EMS_recovery.SurfaceArea-324.py312.1.nbc ADDED
Binary file (87.6 kB). View file