Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +3 -0
- 40w_2000-100000edge_2000-75000active.txt +3 -0
- 40w_2000-200000edge_2000-100000active.txt +3 -0
- MERGED_DATASET_count_200_2000_10000_train_2000min_100000max.txt +0 -0
- MERGED_DATASET_filtered_2000-100000edge_2000-32678active.txt +0 -0
- __pycache__/dataset_triposf.cpython-310.pyc +0 -0
- __pycache__/dataset_triposf_head.cpython-310.pyc +0 -0
- __pycache__/query_point.cpython-310.pyc +0 -0
- __pycache__/utils.bresenham_3d_array-192.py310.1.nbc +3 -0
- __pycache__/utils.bresenham_3d_array-192.py310.nbi +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- __pycache__/vertex_encoder.cpython-310.pyc +0 -0
- config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_addhead.yaml +82 -0
- config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_head_woca.yaml +97 -0
- config_edge_1024_error_8enc_8dec_woself_finetune_128to512.yaml +101 -0
- config_slat_flow_128to256_pointnet_test.yaml +124 -0
- dataset_triposf.py +924 -0
- dataset_triposf_head.py +1000 -0
- debug_viz/step_0_batch_0.ply +0 -0
- debug_viz/step_0_batch_1.ply +0 -0
- filter_active_voxels.py +106 -0
- generate_npz.py +118 -0
- mesh_augment.py +79 -0
- metric.py +300 -0
- metric_cd.py +190 -0
- query_point.py +259 -0
- test_slat_flow_128to1024_pointnet.py +403 -0
- test_slat_flow_128to256_pointnet.py +403 -0
- test_slat_vae_128to1024_pointnet.py +0 -0
- test_slat_vae_128to1024_pointnet_vae.py +0 -0
- test_slat_vae_128to1024_pointnet_vae_addhead.py +0 -0
- test_slat_vae_128to1024_pointnet_vae_head.py +1339 -0
- test_slat_vae_128to1024_pointnet_vae_head_woca.py +0 -0
- test_slat_vae_128to256_pointnet_vae_head.py +1349 -0
- test_slat_vae_128to512_pointnet_vae_head.py +1636 -0
- train_slat_flow_128to1024_pointnet.py +484 -0
- train_slat_vae_512_128to1024_pointnet.py +682 -0
- train_slat_vae_512_128to1024_pointnet_addhead.py +788 -0
- train_slat_vae_512_128to1024_pointnet_head.py +930 -0
- train_slat_vae_512_128to256_pointnet_head.py +917 -0
- train_slat_vae_512_128to512_pointnet_head.py +1090 -0
- trellis/__init__.py +6 -0
- trellis/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/datasets/__init__.py +58 -0
- trellis/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/datasets/__pycache__/components.cpython-310.pyc +0 -0
- trellis/datasets/__pycache__/sparse_structure_latent.cpython-310.pyc +0 -0
- trellis/datasets/components.py +137 -0
- trellis/datasets/sparse_feat2render.py +134 -0
- trellis/datasets/sparse_structure.py +107 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
40w_2000-100000edge_2000-75000active.txt filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
40w_2000-200000edge_2000-100000active.txt filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
__pycache__/utils.bresenham_3d_array-192.py310.1.nbc filter=lfs diff=lfs merge=lfs -text
|
40w_2000-100000edge_2000-75000active.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e283bf011c720c1300f7c4621c15af8360df05f0bcfbd71faf984e2d61f6e17
|
| 3 |
+
size 38231752
|
40w_2000-200000edge_2000-100000active.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46aae1f5d9bce421cb402a69b3b112daf586ed8f470c7de4b1e6d73adc82ef56
|
| 3 |
+
size 44315087
|
MERGED_DATASET_count_200_2000_10000_train_2000min_100000max.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
MERGED_DATASET_filtered_2000-100000edge_2000-32678active.txt
ADDED
|
File without changes
|
__pycache__/dataset_triposf.cpython-310.pyc
ADDED
|
Binary file (24.8 kB). View file
|
|
|
__pycache__/dataset_triposf_head.cpython-310.pyc
ADDED
|
Binary file (26.1 kB). View file
|
|
|
__pycache__/query_point.cpython-310.pyc
ADDED
|
Binary file (8.77 kB). View file
|
|
|
__pycache__/utils.bresenham_3d_array-192.py310.1.nbc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d783005c5692d0ddfb1fc45298776ac767bde1326be014a6e65e76e41aeef898
|
| 3 |
+
size 113601
|
__pycache__/utils.bresenham_3d_array-192.py310.nbi
ADDED
|
Binary file (1.59 kB). View file
|
|
|
__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (24.5 kB). View file
|
|
|
__pycache__/vertex_encoder.cpython-310.pyc
ADDED
|
Binary file (19.2 kB). View file
|
|
|
config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_addhead.yaml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"dataset":
|
| 2 |
+
"base_resolution": 1024
|
| 3 |
+
"cache_dir": "/gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to1024_819200_head"
|
| 4 |
+
"cache_filter_path": "/gemini/user/private/zhaotianhao/Triposf/MERGED_DATASET_filtered_2000-75000edge_2000-326780active.txt"
|
| 5 |
+
"filter_active_voxels": true
|
| 6 |
+
"min_resolution": 128
|
| 7 |
+
"n_train_samples": 1024
|
| 8 |
+
"path": "/gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/train"
|
| 9 |
+
"renders_dir": "None"
|
| 10 |
+
"sample_type": "dora"
|
| 11 |
+
"experiment":
|
| 12 |
+
"save_dir": "/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs1_128to1024_wolabel_dir_sorted_dora_bigger_addhead"
|
| 13 |
+
"model":
|
| 14 |
+
"add_block_embed": true
|
| 15 |
+
"add_direction": false
|
| 16 |
+
"add_edge_glb_feats": true
|
| 17 |
+
"attn_first": false
|
| 18 |
+
"block_size": 16
|
| 19 |
+
"decoder_blocks_edge":
|
| 20 |
+
- "in_channels": 768
|
| 21 |
+
"model_channels": 768
|
| 22 |
+
"num_blocks": 0
|
| 23 |
+
"num_heads": 0
|
| 24 |
+
"out_channels": 128
|
| 25 |
+
"resolution": 128
|
| 26 |
+
- "in_channels": 128
|
| 27 |
+
"model_channels": 128
|
| 28 |
+
"num_blocks": 0
|
| 29 |
+
"num_heads": 0
|
| 30 |
+
"out_channels": 64
|
| 31 |
+
"resolution": 256
|
| 32 |
+
- "in_channels": 64
|
| 33 |
+
"model_channels": 64
|
| 34 |
+
"num_blocks": 0
|
| 35 |
+
"num_heads": 0
|
| 36 |
+
"out_channels": 32
|
| 37 |
+
"resolution": 512
|
| 38 |
+
"decoder_blocks_vtx":
|
| 39 |
+
- "in_channels": 768
|
| 40 |
+
"model_channels": 768
|
| 41 |
+
"num_blocks": 0
|
| 42 |
+
"num_heads": 0
|
| 43 |
+
"out_channels": 128
|
| 44 |
+
"resolution": 128
|
| 45 |
+
- "in_channels": 128
|
| 46 |
+
"model_channels": 128
|
| 47 |
+
"num_blocks": 0
|
| 48 |
+
"num_heads": 0
|
| 49 |
+
"out_channels": 64
|
| 50 |
+
"resolution": 256
|
| 51 |
+
- "in_channels": 64
|
| 52 |
+
"model_channels": 64
|
| 53 |
+
"num_blocks": 0
|
| 54 |
+
"num_heads": 0
|
| 55 |
+
"out_channels": 32
|
| 56 |
+
"resolution": 512
|
| 57 |
+
"embed_dim": 1024
|
| 58 |
+
"encoder_blocks":
|
| 59 |
+
- "in_channels": 1024
|
| 60 |
+
"model_channels": 768
|
| 61 |
+
"num_blocks": 12
|
| 62 |
+
"num_heads": 12
|
| 63 |
+
"out_channels": 768
|
| 64 |
+
"in_channels": 1024
|
| 65 |
+
"latent_dim": 16
|
| 66 |
+
"model_channels": 384
|
| 67 |
+
"multires": 12
|
| 68 |
+
"pos_encoding": "nerf"
|
| 69 |
+
"pred_direction": true
|
| 70 |
+
"relative_embed": true
|
| 71 |
+
"using_attn": false
|
| 72 |
+
"training":
|
| 73 |
+
"batch_size": 1
|
| 74 |
+
"checkpoint_path": "/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs1_128to1024_wolabel_dir_sorted_dora_bigger/checkpoint_epoch14_batch10432_loss0.1438.pt"
|
| 75 |
+
"from_pretrained": true
|
| 76 |
+
"gamma": 0.95
|
| 77 |
+
"lr": 0.0001
|
| 78 |
+
"max_epochs": 100
|
| 79 |
+
"num_workers": 12
|
| 80 |
+
"save_every": 1
|
| 81 |
+
"start_epoch": 1
|
| 82 |
+
"step_size": 1
|
config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_head_woca.yaml
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
path: /gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/train
|
| 3 |
+
cache_dir: /gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to1024_819200_head
|
| 4 |
+
|
| 5 |
+
renders_dir: None
|
| 6 |
+
|
| 7 |
+
filter_active_voxels: true
|
| 8 |
+
cache_filter_path: /gemini/user/private/zhaotianhao/Triposf/MERGED_DATASET_filtered_2000-75000edge_2000-326780active.txt
|
| 9 |
+
|
| 10 |
+
base_resolution: 1024
|
| 11 |
+
min_resolution: 128
|
| 12 |
+
|
| 13 |
+
n_train_samples: 1024
|
| 14 |
+
sample_type: dora
|
| 15 |
+
|
| 16 |
+
model:
|
| 17 |
+
pred_direction: false
|
| 18 |
+
relative_embed: true
|
| 19 |
+
using_attn: false
|
| 20 |
+
add_block_embed: true
|
| 21 |
+
multires: 12
|
| 22 |
+
|
| 23 |
+
embed_dim: 1024 #64
|
| 24 |
+
in_channels: 1024 #64
|
| 25 |
+
model_channels: 384
|
| 26 |
+
latent_dim: 16
|
| 27 |
+
|
| 28 |
+
block_size: 16
|
| 29 |
+
pos_encoding: 'nerf'
|
| 30 |
+
# pos_encoding: 'embedding'
|
| 31 |
+
attn_first: false
|
| 32 |
+
|
| 33 |
+
add_edge_glb_feats: true
|
| 34 |
+
add_direction: false
|
| 35 |
+
|
| 36 |
+
encoder_blocks:
|
| 37 |
+
- in_channels: 1024
|
| 38 |
+
model_channels: 512
|
| 39 |
+
num_blocks: 8
|
| 40 |
+
num_heads: 8
|
| 41 |
+
out_channels: 512
|
| 42 |
+
|
| 43 |
+
decoder_blocks_edge:
|
| 44 |
+
- in_channels: 512
|
| 45 |
+
model_channels: 512
|
| 46 |
+
num_blocks: 0
|
| 47 |
+
num_heads: 0
|
| 48 |
+
out_channels: 256
|
| 49 |
+
resolution: 128
|
| 50 |
+
- in_channels: 256
|
| 51 |
+
model_channels: 256
|
| 52 |
+
num_blocks: 0
|
| 53 |
+
num_heads: 0
|
| 54 |
+
out_channels: 128
|
| 55 |
+
resolution: 256
|
| 56 |
+
- in_channels: 128
|
| 57 |
+
model_channels: 128
|
| 58 |
+
num_blocks: 0
|
| 59 |
+
num_heads: 0
|
| 60 |
+
out_channels: 64
|
| 61 |
+
resolution: 512
|
| 62 |
+
|
| 63 |
+
decoder_blocks_vtx:
|
| 64 |
+
- in_channels: 512
|
| 65 |
+
model_channels: 512
|
| 66 |
+
num_blocks: 0
|
| 67 |
+
num_heads: 0
|
| 68 |
+
out_channels: 256
|
| 69 |
+
resolution: 128
|
| 70 |
+
- in_channels: 256
|
| 71 |
+
model_channels: 256
|
| 72 |
+
num_blocks: 0
|
| 73 |
+
num_heads: 0
|
| 74 |
+
out_channels: 128
|
| 75 |
+
resolution: 256
|
| 76 |
+
- in_channels: 128
|
| 77 |
+
model_channels: 128
|
| 78 |
+
num_blocks: 0
|
| 79 |
+
num_heads: 0
|
| 80 |
+
out_channels: 64
|
| 81 |
+
resolution: 512
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
training:
|
| 85 |
+
batch_size: 2
|
| 86 |
+
lr: 1.e-4
|
| 87 |
+
step_size: 1
|
| 88 |
+
gamma: 0.95
|
| 89 |
+
save_every: 5
|
| 90 |
+
start_epoch: 0
|
| 91 |
+
max_epochs: 200
|
| 92 |
+
num_workers: 4
|
| 93 |
+
from_pretrained: false
|
| 94 |
+
checkpoint_path: /root/Trisf/experiments_edge/vae/train_9w_200_2000face/9w_128to1024/checkpoint_epoch14_batch5216_loss0.2745.pt
|
| 95 |
+
|
| 96 |
+
experiment:
|
| 97 |
+
save_dir: "/root/Trisf/experiments_edge/vae/{dataset_name}_9w_200_2000face/shapenet_bs{batch_size}_128to1024_wolabel_dir_sorted_dora_small_allasyloss"
|
config_edge_1024_error_8enc_8dec_woself_finetune_128to512.yaml
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
# path: /gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/train
|
| 3 |
+
# cache_dir: /gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to1024_819200_head
|
| 4 |
+
|
| 5 |
+
path: /home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01
|
| 6 |
+
cache_dir: /home/tiger/yy/src/dataset_cache/unique_files_glb_under6000face_2degree_30ratio_0.01
|
| 7 |
+
|
| 8 |
+
renders_dir: None
|
| 9 |
+
|
| 10 |
+
filter_active_voxels: true
|
| 11 |
+
cache_filter_path: /home/tiger/yy/src/Michelangelo-master/40w_2000-200000edge_2000-100000active.txt # 5epoch
|
| 12 |
+
# cache_filter_path: /home/tiger/yy/src/Michelangelo-master/40w_2000-100000edge_2000-75000active.txt # 0-4 epoch
|
| 13 |
+
|
| 14 |
+
base_resolution: 512
|
| 15 |
+
min_resolution: 128
|
| 16 |
+
|
| 17 |
+
n_train_samples: 1024
|
| 18 |
+
sample_type: dora
|
| 19 |
+
|
| 20 |
+
model:
|
| 21 |
+
pred_direction: false
|
| 22 |
+
relative_embed: true
|
| 23 |
+
using_attn: false
|
| 24 |
+
add_block_embed: true
|
| 25 |
+
multires: 12
|
| 26 |
+
|
| 27 |
+
embed_dim: 1024 #64
|
| 28 |
+
in_channels: 1024 #64
|
| 29 |
+
model_channels: 384
|
| 30 |
+
latent_dim: 16
|
| 31 |
+
|
| 32 |
+
block_size: 16
|
| 33 |
+
pos_encoding: 'nerf'
|
| 34 |
+
attn_first: false
|
| 35 |
+
|
| 36 |
+
add_edge_glb_feats: true
|
| 37 |
+
add_direction: false
|
| 38 |
+
|
| 39 |
+
encoder_blocks:
|
| 40 |
+
- in_channels: 1024
|
| 41 |
+
model_channels: 512
|
| 42 |
+
num_blocks: 8
|
| 43 |
+
num_heads: 8
|
| 44 |
+
out_channels: 512
|
| 45 |
+
|
| 46 |
+
decoder_blocks_edge:
|
| 47 |
+
- in_channels: 512
|
| 48 |
+
model_channels: 512
|
| 49 |
+
num_blocks: 0
|
| 50 |
+
num_heads: 0
|
| 51 |
+
out_channels: 256
|
| 52 |
+
resolution: 128
|
| 53 |
+
- in_channels: 256
|
| 54 |
+
model_channels: 256
|
| 55 |
+
num_blocks: 0
|
| 56 |
+
num_heads: 0
|
| 57 |
+
out_channels: 128
|
| 58 |
+
resolution: 256
|
| 59 |
+
# - in_channels: 128
|
| 60 |
+
# model_channels: 128
|
| 61 |
+
# num_blocks: 0
|
| 62 |
+
# num_heads: 0
|
| 63 |
+
# out_channels: 64
|
| 64 |
+
# resolution: 512
|
| 65 |
+
|
| 66 |
+
decoder_blocks_vtx:
|
| 67 |
+
- in_channels: 512
|
| 68 |
+
model_channels: 512
|
| 69 |
+
num_blocks: 0
|
| 70 |
+
num_heads: 0
|
| 71 |
+
out_channels: 256
|
| 72 |
+
resolution: 128
|
| 73 |
+
- in_channels: 256
|
| 74 |
+
model_channels: 256
|
| 75 |
+
num_blocks: 0
|
| 76 |
+
num_heads: 0
|
| 77 |
+
out_channels: 128
|
| 78 |
+
resolution: 256
|
| 79 |
+
# - in_channels: 128
|
| 80 |
+
# model_channels: 128
|
| 81 |
+
# num_blocks: 0
|
| 82 |
+
# num_heads: 0
|
| 83 |
+
# out_channels: 64
|
| 84 |
+
# resolution: 512
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
training:
|
| 88 |
+
batch_size: 2
|
| 89 |
+
lr: 4.e-5
|
| 90 |
+
step_size: 1
|
| 91 |
+
gamma: 0.95
|
| 92 |
+
save_every: 1
|
| 93 |
+
start_epoch: 0
|
| 94 |
+
max_epochs: 10
|
| 95 |
+
num_workers: 32
|
| 96 |
+
from_pretrained: true
|
| 97 |
+
# checkpoint_path: /home/tiger/yy/src/checkpoint_epoch2_batch30000_loss0.1829.pt
|
| 98 |
+
checkpoint_path: /home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch7_batch10000_loss0.1175.pt
|
| 99 |
+
|
| 100 |
+
experiment:
|
| 101 |
+
save_dir: "/home/tiger/yy/src/checkpoints/vae/{dataset_name}/shapenet_bs{batch_size}_128to512_wolabel_dir_sorted_dora_small_lowlr"
|
config_slat_flow_128to256_pointnet_test.yaml
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
pred_direction: true
|
| 3 |
+
relative_embed: true
|
| 4 |
+
using_attn: false
|
| 5 |
+
add_block_embed: true
|
| 6 |
+
multires: 12
|
| 7 |
+
|
| 8 |
+
embed_dim: 1024
|
| 9 |
+
in_channels: 1024
|
| 10 |
+
model_channels: 384
|
| 11 |
+
latent_dim: 16
|
| 12 |
+
|
| 13 |
+
block_size: 16
|
| 14 |
+
pos_encoding: 'nerf'
|
| 15 |
+
attn_first: false
|
| 16 |
+
|
| 17 |
+
add_edge_glb_feats: true
|
| 18 |
+
add_direction: false
|
| 19 |
+
|
| 20 |
+
encoder_blocks:
|
| 21 |
+
- in_channels: 1024
|
| 22 |
+
model_channels: 512
|
| 23 |
+
num_blocks: 8
|
| 24 |
+
num_heads: 8
|
| 25 |
+
out_channels: 512
|
| 26 |
+
|
| 27 |
+
decoder_blocks_edge:
|
| 28 |
+
- in_channels: 512
|
| 29 |
+
model_channels: 512
|
| 30 |
+
num_blocks: 0
|
| 31 |
+
num_heads: 0
|
| 32 |
+
out_channels: 128
|
| 33 |
+
resolution: 128
|
| 34 |
+
# - in_channels: 128
|
| 35 |
+
# model_channels: 128
|
| 36 |
+
# num_blocks: 0
|
| 37 |
+
# num_heads: 0
|
| 38 |
+
# out_channels: 64
|
| 39 |
+
# resolution: 256
|
| 40 |
+
# - in_channels: 64
|
| 41 |
+
# model_channels: 64
|
| 42 |
+
# num_blocks: 0
|
| 43 |
+
# num_heads: 0
|
| 44 |
+
# out_channels: 32
|
| 45 |
+
# resolution: 512
|
| 46 |
+
|
| 47 |
+
decoder_blocks_vtx:
|
| 48 |
+
- in_channels: 512
|
| 49 |
+
model_channels: 512
|
| 50 |
+
num_blocks: 0
|
| 51 |
+
num_heads: 0
|
| 52 |
+
out_channels: 128
|
| 53 |
+
resolution: 128
|
| 54 |
+
# - in_channels: 128
|
| 55 |
+
# model_channels: 128
|
| 56 |
+
# num_blocks: 0
|
| 57 |
+
# num_heads: 0
|
| 58 |
+
# out_channels: 64
|
| 59 |
+
# resolution: 256
|
| 60 |
+
# - in_channels: 64
|
| 61 |
+
# model_channels: 64
|
| 62 |
+
# num_blocks: 0
|
| 63 |
+
# num_heads: 0
|
| 64 |
+
# out_channels: 32
|
| 65 |
+
# resolution: 512
|
| 66 |
+
|
| 67 |
+
"t_schedule":
|
| 68 |
+
"name": "logitNormal"
|
| 69 |
+
"args":
|
| 70 |
+
"mean": 1.0
|
| 71 |
+
"std": 1.0
|
| 72 |
+
|
| 73 |
+
"sigma_min": 1.e-5
|
| 74 |
+
|
| 75 |
+
training:
|
| 76 |
+
batch_size: 1
|
| 77 |
+
lr: 1.e-4
|
| 78 |
+
step_size: 20
|
| 79 |
+
gamma: 0.95
|
| 80 |
+
save_every: 2000
|
| 81 |
+
start_epoch: 0
|
| 82 |
+
max_epochs: 300000
|
| 83 |
+
num_workers: 4
|
| 84 |
+
|
| 85 |
+
output_dir: /gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope
|
| 86 |
+
clip_model_path: "/gemini/user/private/zhaotianhao/clip-vit-large-patch14"
|
| 87 |
+
dinov2_model_path: "/gemini/user/private/zhaotianhao/dinov2-large"
|
| 88 |
+
|
| 89 |
+
vae_path: /gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to256_dir_sorted_dora_head_small/checkpoint_epoch13_batch6000_loss0.1381.pt
|
| 90 |
+
denoiser_checkpoint_path: false
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
dataset:
|
| 94 |
+
path: /gemini/user/private/zhaotianhao/data/trellis_clean_mesh
|
| 95 |
+
path: /gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/train
|
| 96 |
+
|
| 97 |
+
cache_dir: /gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to256_819200_head
|
| 98 |
+
|
| 99 |
+
renders_dir: None
|
| 100 |
+
filter_active_voxels: false
|
| 101 |
+
cache_filter_path: None
|
| 102 |
+
|
| 103 |
+
base_resolution: 1024
|
| 104 |
+
min_resolution: 128
|
| 105 |
+
|
| 106 |
+
n_train_samples: 1024
|
| 107 |
+
sample_type: dora
|
| 108 |
+
|
| 109 |
+
flow:
|
| 110 |
+
"resolution": 128
|
| 111 |
+
"in_channels": 16
|
| 112 |
+
"out_channels": 16
|
| 113 |
+
"model_channels": 768
|
| 114 |
+
"cond_channels": 1024
|
| 115 |
+
"num_blocks": 8
|
| 116 |
+
"num_heads": 8
|
| 117 |
+
"mlp_ratio": 4
|
| 118 |
+
"patch_size": 2
|
| 119 |
+
"num_io_res_blocks": 2
|
| 120 |
+
"io_block_channels": [128]
|
| 121 |
+
"pe_mode": "rope"
|
| 122 |
+
"qk_rms_norm": true
|
| 123 |
+
"qk_rms_norm_cross": false
|
| 124 |
+
"use_fp16": false
|
dataset_triposf.py
ADDED
|
@@ -0,0 +1,924 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from typing import *
|
| 4 |
+
import trimesh
|
| 5 |
+
import numpy as np
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
# from utils import quantize_vertices
|
| 9 |
+
from utils import get_voxel_line
|
| 10 |
+
import random
|
| 11 |
+
import hashlib
|
| 12 |
+
import json
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import rembg
|
| 17 |
+
import open3d as o3d
|
| 18 |
+
from trimesh import grouping
|
| 19 |
+
|
| 20 |
+
def normalize_mesh(mesh_path):
|
| 21 |
+
scene = trimesh.load(mesh_path, process=False, force='scene')
|
| 22 |
+
meshes = []
|
| 23 |
+
for node_name in scene.graph.nodes_geometry:
|
| 24 |
+
geom_name = scene.graph[node_name][1]
|
| 25 |
+
geometry = scene.geometry[geom_name]
|
| 26 |
+
transform = scene.graph[node_name][0]
|
| 27 |
+
if isinstance(geometry, trimesh.Trimesh):
|
| 28 |
+
geometry.apply_transform(transform)
|
| 29 |
+
meshes.append(geometry)
|
| 30 |
+
|
| 31 |
+
mesh = trimesh.util.concatenate(meshes)
|
| 32 |
+
|
| 33 |
+
center = mesh.bounding_box.centroid
|
| 34 |
+
mesh.apply_translation(-center)
|
| 35 |
+
scale = max(mesh.bounding_box.extents)
|
| 36 |
+
mesh.apply_scale(2.0 / scale * 0.5)
|
| 37 |
+
|
| 38 |
+
return mesh
|
| 39 |
+
|
| 40 |
+
def quantize_vertices(vertices: torch.Tensor, res: int):
|
| 41 |
+
"""
|
| 42 |
+
Quantize normalized vertices (range approx [-0.5, 0.5]) to integer grid [0, res-1].
|
| 43 |
+
"""
|
| 44 |
+
normalized = vertices + 0.5
|
| 45 |
+
scaled = normalized * res
|
| 46 |
+
quantized = torch.floor(scaled).clamp(0, res - 1).int()
|
| 47 |
+
return quantized
|
| 48 |
+
|
| 49 |
+
def sample_edges_dora(tm_mesh, n_samples):
|
| 50 |
+
adj_faces = tm_mesh.face_adjacency
|
| 51 |
+
adj_edges = tm_mesh.face_adjacency_edges
|
| 52 |
+
|
| 53 |
+
internal_data = None
|
| 54 |
+
if len(adj_faces) > 0:
|
| 55 |
+
n0 = tm_mesh.face_normals[adj_faces[:, 0]]
|
| 56 |
+
n1 = tm_mesh.face_normals[adj_faces[:, 1]]
|
| 57 |
+
sum_normals = n0 + n1
|
| 58 |
+
norms = np.linalg.norm(sum_normals, axis=1, keepdims=True)
|
| 59 |
+
norms[norms < 1e-6] = 1.0
|
| 60 |
+
int_normals = sum_normals / norms
|
| 61 |
+
|
| 62 |
+
int_v_start = tm_mesh.vertices[adj_edges[:, 0]]
|
| 63 |
+
int_v_end = tm_mesh.vertices[adj_edges[:, 1]]
|
| 64 |
+
|
| 65 |
+
faces_pair = tm_mesh.faces[adj_faces]
|
| 66 |
+
sum_face_indices = np.sum(faces_pair, axis=2)
|
| 67 |
+
sum_edge_indices = np.sum(adj_edges, axis=1)
|
| 68 |
+
|
| 69 |
+
unique_idx_0 = sum_face_indices[:, 0] - sum_edge_indices
|
| 70 |
+
unique_idx_1 = sum_face_indices[:, 1] - sum_edge_indices
|
| 71 |
+
|
| 72 |
+
v_unique_0 = tm_mesh.vertices[unique_idx_0]
|
| 73 |
+
v_unique_1 = tm_mesh.vertices[unique_idx_1]
|
| 74 |
+
int_v_virtual = (v_unique_0 + v_unique_1) * 0.5
|
| 75 |
+
|
| 76 |
+
internal_data = (int_v_start, int_v_end, int_normals, int_v_virtual)
|
| 77 |
+
|
| 78 |
+
edges_sorted = tm_mesh.edges_sorted
|
| 79 |
+
|
| 80 |
+
if len(edges_sorted) == 0:
|
| 81 |
+
boundary_data = None
|
| 82 |
+
else:
|
| 83 |
+
boundary_group = grouping.group_rows(edges_sorted, require_count=1)
|
| 84 |
+
|
| 85 |
+
boundary_data = None
|
| 86 |
+
if len(boundary_group) > 0:
|
| 87 |
+
boundary_indices = np.concatenate([np.atleast_1d(g) for g in boundary_group])
|
| 88 |
+
|
| 89 |
+
boundary_face_indices = boundary_indices // 3
|
| 90 |
+
|
| 91 |
+
bnd_normals = tm_mesh.face_normals[boundary_face_indices]
|
| 92 |
+
|
| 93 |
+
bnd_edge_v_indices = edges_sorted[boundary_indices]
|
| 94 |
+
bnd_v_start = tm_mesh.vertices[bnd_edge_v_indices[:, 0]]
|
| 95 |
+
bnd_v_end = tm_mesh.vertices[bnd_edge_v_indices[:, 1]]
|
| 96 |
+
|
| 97 |
+
boundary_face_v_indices = tm_mesh.faces[boundary_face_indices]
|
| 98 |
+
|
| 99 |
+
sum_face = np.sum(boundary_face_v_indices, axis=1)
|
| 100 |
+
sum_edge = np.sum(bnd_edge_v_indices, axis=1)
|
| 101 |
+
unique_idx = sum_face - sum_edge
|
| 102 |
+
|
| 103 |
+
bnd_v_virtual = tm_mesh.vertices[unique_idx]
|
| 104 |
+
|
| 105 |
+
boundary_data = (bnd_v_start, bnd_v_end, bnd_normals, bnd_v_virtual)
|
| 106 |
+
|
| 107 |
+
if internal_data is None and boundary_data is None:
|
| 108 |
+
return None, None, None
|
| 109 |
+
|
| 110 |
+
parts_start, parts_end, parts_norm, parts_virt = [], [], [], []
|
| 111 |
+
|
| 112 |
+
if internal_data is not None:
|
| 113 |
+
parts_start.append(internal_data[0])
|
| 114 |
+
parts_end.append(internal_data[1])
|
| 115 |
+
parts_norm.append(internal_data[2])
|
| 116 |
+
parts_virt.append(internal_data[3])
|
| 117 |
+
|
| 118 |
+
if boundary_data is not None:
|
| 119 |
+
parts_start.append(boundary_data[0])
|
| 120 |
+
parts_end.append(boundary_data[1])
|
| 121 |
+
parts_norm.append(boundary_data[2])
|
| 122 |
+
parts_virt.append(boundary_data[3])
|
| 123 |
+
|
| 124 |
+
if not parts_start:
|
| 125 |
+
return None, None, None
|
| 126 |
+
|
| 127 |
+
all_v_start = np.concatenate(parts_start, axis=0)
|
| 128 |
+
all_v_end = np.concatenate(parts_end, axis=0)
|
| 129 |
+
all_normals = np.concatenate(parts_norm, axis=0)
|
| 130 |
+
all_v_virtual = np.concatenate(parts_virt, axis=0)
|
| 131 |
+
|
| 132 |
+
edge_vectors = all_v_end - all_v_start
|
| 133 |
+
edge_lengths = np.linalg.norm(edge_vectors, axis=1)
|
| 134 |
+
total_length = np.sum(edge_lengths)
|
| 135 |
+
|
| 136 |
+
if total_length < 1e-9:
|
| 137 |
+
probs = np.ones(len(edge_lengths)) / len(edge_lengths)
|
| 138 |
+
else:
|
| 139 |
+
probs = edge_lengths / total_length
|
| 140 |
+
probs = probs / probs.sum()
|
| 141 |
+
|
| 142 |
+
chosen_indices = np.random.choice(len(edge_lengths), size=n_samples, p=probs)
|
| 143 |
+
|
| 144 |
+
t = np.random.rand(n_samples, 1)
|
| 145 |
+
|
| 146 |
+
sel_v_start = all_v_start[chosen_indices]
|
| 147 |
+
sel_v_end = all_v_end[chosen_indices]
|
| 148 |
+
sel_normals = all_normals[chosen_indices]
|
| 149 |
+
sel_v_virtual = all_v_virtual[chosen_indices]
|
| 150 |
+
|
| 151 |
+
sampled_points = sel_v_start + (sel_v_end - sel_v_start) * t
|
| 152 |
+
|
| 153 |
+
vertex_triplets = np.stack([sel_v_start, sel_v_end, sel_v_virtual], axis=1).astype(np.float32)
|
| 154 |
+
|
| 155 |
+
return sampled_points.astype(np.float32), sel_normals.astype(np.float32), vertex_triplets
|
| 156 |
+
|
| 157 |
+
def load_quantized_mesh_dora(
|
| 158 |
+
mesh_path,
|
| 159 |
+
mesh_load=None,
|
| 160 |
+
volume_resolution=256,
|
| 161 |
+
use_normals=True,
|
| 162 |
+
pc_sample_number=4096000,
|
| 163 |
+
edge_sample_ratio=0.2
|
| 164 |
+
):
|
| 165 |
+
cube_dilate = np.array([
|
| 166 |
+
[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, -1], [0, -1, 0], [0, 1, 1], [0, -1, 1], [0, 1, -1], [0, -1, -1],
|
| 167 |
+
[1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 0, -1], [1, -1, 0], [1, 1, 1], [1, -1, 1], [1, 1, -1], [1, -1, -1],
|
| 168 |
+
[-1, 0, 0], [-1, 0, 1], [-1, 1, 0], [-1, 0, -1], [-1, -1, 0], [-1, 1, 1], [-1, -1, 1], [-1, 1, -1], [-1, -1, -1]
|
| 169 |
+
]) / (volume_resolution * 4 - 1)
|
| 170 |
+
|
| 171 |
+
if mesh_load is None:
|
| 172 |
+
mesh_o3d = o3d.io.read_triangle_mesh(mesh_path)
|
| 173 |
+
vertices = np.clip(np.asarray(mesh_o3d.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
|
| 174 |
+
faces = np.asarray(mesh_o3d.triangles)
|
| 175 |
+
mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices)
|
| 176 |
+
else:
|
| 177 |
+
vertices = np.clip(np.asarray(mesh_load.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
|
| 178 |
+
faces = np.asarray(mesh_load.faces)
|
| 179 |
+
mesh_o3d = o3d.geometry.TriangleMesh()
|
| 180 |
+
mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices)
|
| 181 |
+
mesh_o3d.triangles = o3d.utility.Vector3iVector(faces)
|
| 182 |
+
|
| 183 |
+
tm_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
|
| 184 |
+
|
| 185 |
+
n_edge_samples = int(pc_sample_number * edge_sample_ratio)
|
| 186 |
+
p_edge, n_edge, triplets_edge = sample_edges_dora(tm_mesh, n_edge_samples)
|
| 187 |
+
|
| 188 |
+
if p_edge is None:
|
| 189 |
+
# print('p_edge is none!')
|
| 190 |
+
n_surface_samples = pc_sample_number
|
| 191 |
+
else:
|
| 192 |
+
# print('p_edge is right!')
|
| 193 |
+
n_surface_samples = pc_sample_number - n_edge_samples
|
| 194 |
+
|
| 195 |
+
p_surf, idx_surf = tm_mesh.sample(n_surface_samples, return_index=True)
|
| 196 |
+
p_surf = p_surf.astype(np.float32)
|
| 197 |
+
n_surf = tm_mesh.face_normals[idx_surf].astype(np.float32)
|
| 198 |
+
|
| 199 |
+
v_indices_surf = faces[idx_surf]
|
| 200 |
+
triplets_surf = vertices[v_indices_surf]
|
| 201 |
+
|
| 202 |
+
if p_edge is None:
|
| 203 |
+
final_points = p_surf
|
| 204 |
+
final_normals = n_surf
|
| 205 |
+
final_triplets = triplets_surf
|
| 206 |
+
else:
|
| 207 |
+
final_points = np.concatenate([p_surf, p_edge.astype(np.float32)], axis=0)
|
| 208 |
+
if use_normals:
|
| 209 |
+
final_normals = np.concatenate([n_surf, n_edge.astype(np.float32)], axis=0)
|
| 210 |
+
|
| 211 |
+
final_triplets = np.concatenate([triplets_surf, triplets_edge], axis=0)
|
| 212 |
+
|
| 213 |
+
voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
|
| 214 |
+
mesh_o3d,
|
| 215 |
+
voxel_size=1. / volume_resolution,
|
| 216 |
+
min_bound=[-0.5, -0.5, -0.5],
|
| 217 |
+
max_bound=[0.5, 0.5, 0.5]
|
| 218 |
+
)
|
| 219 |
+
voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
|
| 220 |
+
|
| 221 |
+
voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
|
| 222 |
+
o3d.geometry.PointCloud(
|
| 223 |
+
o3d.utility.Vector3dVector(
|
| 224 |
+
np.clip(
|
| 225 |
+
(final_points[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
|
| 226 |
+
-0.5 + 1e-6, 0.5 - 1e-6)
|
| 227 |
+
)
|
| 228 |
+
),
|
| 229 |
+
voxel_size=1. / volume_resolution,
|
| 230 |
+
min_bound=[-0.5, -0.5, -0.5],
|
| 231 |
+
max_bound=[0.5, 0.5, 0.5]
|
| 232 |
+
)
|
| 233 |
+
voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
|
| 234 |
+
voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
|
| 235 |
+
|
| 236 |
+
features_list = [torch.from_numpy(final_points)]
|
| 237 |
+
|
| 238 |
+
if use_normals:
|
| 239 |
+
features_list.append(torch.from_numpy(final_normals))
|
| 240 |
+
|
| 241 |
+
view_dtype = np.dtype((np.void, final_triplets.dtype.itemsize * final_triplets.shape[-1]))
|
| 242 |
+
v_view = final_triplets.view(view_dtype).squeeze(-1)
|
| 243 |
+
|
| 244 |
+
sort_idx = np.argsort(v_view, axis=1)
|
| 245 |
+
|
| 246 |
+
batch_indices = np.arange(final_triplets.shape[0])[:, None]
|
| 247 |
+
v_sorted = final_triplets[batch_indices, sort_idx]
|
| 248 |
+
|
| 249 |
+
v1 = v_sorted[:, 0, :]
|
| 250 |
+
v2 = v_sorted[:, 1, :]
|
| 251 |
+
v3 = v_sorted[:, 2, :]
|
| 252 |
+
|
| 253 |
+
dir1 = v1 - final_points
|
| 254 |
+
dir2 = v2 - final_points
|
| 255 |
+
dir3 = v3 - final_points
|
| 256 |
+
|
| 257 |
+
features_list.append(torch.Tensor(dir1.astype(np.float32)))
|
| 258 |
+
features_list.append(torch.Tensor(dir2.astype(np.float32)))
|
| 259 |
+
features_list.append(torch.Tensor(dir3.astype(np.float32)))
|
| 260 |
+
|
| 261 |
+
points_sample = torch.cat(features_list, axis=-1)
|
| 262 |
+
|
| 263 |
+
return voxels, points_sample
|
| 264 |
+
|
| 265 |
+
def load_quantized_mesh_original(
|
| 266 |
+
mesh_path,
|
| 267 |
+
mesh_load=None,
|
| 268 |
+
volume_resolution=256,
|
| 269 |
+
use_normals=True,
|
| 270 |
+
pc_sample_number=4096000,
|
| 271 |
+
):
|
| 272 |
+
cube_dilate = np.array(
|
| 273 |
+
[
|
| 274 |
+
[0, 0, 0],
|
| 275 |
+
[0, 0, 1],
|
| 276 |
+
[0, 1, 0],
|
| 277 |
+
[0, 0, -1],
|
| 278 |
+
[0, -1, 0],
|
| 279 |
+
[0, 1, 1],
|
| 280 |
+
[0, -1, 1],
|
| 281 |
+
[0, 1, -1],
|
| 282 |
+
[0, -1, -1],
|
| 283 |
+
|
| 284 |
+
[1, 0, 0],
|
| 285 |
+
[1, 0, 1],
|
| 286 |
+
[1, 1, 0],
|
| 287 |
+
[1, 0, -1],
|
| 288 |
+
[1, -1, 0],
|
| 289 |
+
[1, 1, 1],
|
| 290 |
+
[1, -1, 1],
|
| 291 |
+
[1, 1, -1],
|
| 292 |
+
[1, -1, -1],
|
| 293 |
+
|
| 294 |
+
[-1, 0, 0],
|
| 295 |
+
[-1, 0, 1],
|
| 296 |
+
[-1, 1, 0],
|
| 297 |
+
[-1, 0, -1],
|
| 298 |
+
[-1, -1, 0],
|
| 299 |
+
[-1, 1, 1],
|
| 300 |
+
[-1, -1, 1],
|
| 301 |
+
[-1, 1, -1],
|
| 302 |
+
[-1, -1, -1],
|
| 303 |
+
]
|
| 304 |
+
) / (volume_resolution * 4 - 1)
|
| 305 |
+
|
| 306 |
+
if mesh_load is None:
|
| 307 |
+
mesh = o3d.io.read_triangle_mesh(mesh_path)
|
| 308 |
+
vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
|
| 309 |
+
faces = np.asarray(mesh.triangles)
|
| 310 |
+
mesh.vertices = o3d.utility.Vector3dVector(vertices)
|
| 311 |
+
else:
|
| 312 |
+
vertices = np.clip(np.asarray(mesh_load.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
|
| 313 |
+
faces = np.asarray(mesh_load.faces)
|
| 314 |
+
|
| 315 |
+
mesh = o3d.geometry.TriangleMesh()
|
| 316 |
+
mesh.vertices = o3d.utility.Vector3dVector(vertices)
|
| 317 |
+
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
|
| 321 |
+
mesh,
|
| 322 |
+
voxel_size=1. / volume_resolution,
|
| 323 |
+
min_bound=[-0.5, -0.5, -0.5],
|
| 324 |
+
max_bound=[0.5, 0.5, 0.5]
|
| 325 |
+
)
|
| 326 |
+
voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
|
| 327 |
+
|
| 328 |
+
points_normals_sample = trimesh.Trimesh(vertices=vertices, faces=faces).sample(count=pc_sample_number, return_index=True)
|
| 329 |
+
points_xyz_np = points_normals_sample[0].astype(np.float32)
|
| 330 |
+
points_sample = points_normals_sample[0].astype(np.float32)
|
| 331 |
+
face_indices = points_normals_sample[1]
|
| 332 |
+
|
| 333 |
+
voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
|
| 334 |
+
o3d.geometry.PointCloud(
|
| 335 |
+
o3d.utility.Vector3dVector(
|
| 336 |
+
np.clip(
|
| 337 |
+
(points_sample[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
|
| 338 |
+
-0.5 + 1e-6, 0.5 - 1e-6)
|
| 339 |
+
)
|
| 340 |
+
),
|
| 341 |
+
voxel_size=1. / volume_resolution,
|
| 342 |
+
min_bound=[-0.5, -0.5, -0.5],
|
| 343 |
+
max_bound=[0.5, 0.5, 0.5]
|
| 344 |
+
)
|
| 345 |
+
voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
|
| 346 |
+
voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
|
| 347 |
+
|
| 348 |
+
features_list = [torch.from_numpy(points_xyz_np)]
|
| 349 |
+
|
| 350 |
+
if use_normals:
|
| 351 |
+
mesh.compute_triangle_normals()
|
| 352 |
+
normals_sample = np.asarray(
|
| 353 |
+
mesh.triangle_normals
|
| 354 |
+
)[points_normals_sample[1]].astype(np.float32)
|
| 355 |
+
# points_sample = torch.cat((torch.Tensor(points_sample), torch.Tensor(normals_sample)), axis=-1)
|
| 356 |
+
features_list.append(torch.from_numpy(normals_sample))
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
########################################
|
| 360 |
+
# add direction to three vtx
|
| 361 |
+
########################################
|
| 362 |
+
|
| 363 |
+
## wo sort
|
| 364 |
+
# sampled_face_v_indices = faces[face_indices]
|
| 365 |
+
# v1 = vertices[sampled_face_v_indices[:, 0]]
|
| 366 |
+
# v2 = vertices[sampled_face_v_indices[:, 1]]
|
| 367 |
+
# v3 = vertices[sampled_face_v_indices[:, 2]]
|
| 368 |
+
|
| 369 |
+
# w sort
|
| 370 |
+
sampled_face_v_indices = faces[face_indices]
|
| 371 |
+
v_batch = np.stack([
|
| 372 |
+
vertices[sampled_face_v_indices[:, 0]],
|
| 373 |
+
vertices[sampled_face_v_indices[:, 1]],
|
| 374 |
+
vertices[sampled_face_v_indices[:, 2]]
|
| 375 |
+
], axis=1)
|
| 376 |
+
|
| 377 |
+
view_dtype = np.dtype((np.void, v_batch.dtype.itemsize * v_batch.shape[-1]))
|
| 378 |
+
v_view = v_batch.view(view_dtype).squeeze(-1) # 变成 (N, 3) 的 void
|
| 379 |
+
|
| 380 |
+
sort_idx = np.argsort(v_view, axis=1) # (N, 3)
|
| 381 |
+
|
| 382 |
+
batch_indices = np.arange(v_batch.shape[0])[:, None]
|
| 383 |
+
v_sorted = v_batch[batch_indices, sort_idx] # (N, 3, 3)
|
| 384 |
+
|
| 385 |
+
v1 = v_sorted[:, 0, :]
|
| 386 |
+
v2 = v_sorted[:, 1, :]
|
| 387 |
+
v3 = v_sorted[:, 2, :]
|
| 388 |
+
# --------------------
|
| 389 |
+
|
| 390 |
+
dir1 = v1 - points_xyz_np
|
| 391 |
+
dir2 = v2 - points_xyz_np
|
| 392 |
+
dir3 = v3 - points_xyz_np
|
| 393 |
+
|
| 394 |
+
features_list.append(torch.Tensor(dir1.astype(np.float32)))
|
| 395 |
+
features_list.append(torch.Tensor(dir2.astype(np.float32)))
|
| 396 |
+
features_list.append(torch.Tensor(dir3.astype(np.float32)))
|
| 397 |
+
|
| 398 |
+
points_sample = torch.cat(features_list, axis=-1)
|
| 399 |
+
########################################
|
| 400 |
+
# add direction to three vtx
|
| 401 |
+
########################################
|
| 402 |
+
|
| 403 |
+
return voxels, points_sample
|
| 404 |
+
|
| 405 |
+
def get_sha256(filepath: str) -> str:
|
| 406 |
+
sha256_hash = hashlib.sha256()
|
| 407 |
+
with open(filepath, "rb") as f:
|
| 408 |
+
for byte_block in iter(lambda: f.read(4096), b""):
|
| 409 |
+
sha256_hash.update(byte_block)
|
| 410 |
+
return sha256_hash.hexdigest()
|
| 411 |
+
|
| 412 |
+
class VoxelVertexDataset_edge(Dataset):
|
| 413 |
+
def __init__(self,
|
| 414 |
+
root_dir: str,
|
| 415 |
+
base_resolution: int = 256,
|
| 416 |
+
min_resolution: int = 128,
|
| 417 |
+
img_res: int = 518,
|
| 418 |
+
cache_dir: str = "dataset_cache_test",
|
| 419 |
+
renders_dir: str = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond',
|
| 420 |
+
process_img: bool = False,
|
| 421 |
+
n_pre_samples: int = 1024,
|
| 422 |
+
|
| 423 |
+
active_voxel_res: int = 64,
|
| 424 |
+
pc_sample_number: int = 409600,
|
| 425 |
+
|
| 426 |
+
filter_active_voxels: bool = False, #####
|
| 427 |
+
min_active_voxels: int = 2000,
|
| 428 |
+
max_active_voxels: int = 40000,
|
| 429 |
+
|
| 430 |
+
cache_filter_path: str = "/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/data/filter_name/objaverse_200_2000_2000min_25000max.txt",
|
| 431 |
+
|
| 432 |
+
sample_type: str = 'uniform',
|
| 433 |
+
):
|
| 434 |
+
self.root_dir = root_dir
|
| 435 |
+
self.cache_dir = cache_dir
|
| 436 |
+
self.img_res = img_res
|
| 437 |
+
self.renders_dir = renders_dir
|
| 438 |
+
self.process_img = process_img
|
| 439 |
+
self.filter_active_voxels=filter_active_voxels
|
| 440 |
+
self.min_active_voxels=min_active_voxels
|
| 441 |
+
self.max_active_voxels=max_active_voxels
|
| 442 |
+
|
| 443 |
+
self.active_voxel_res = active_voxel_res
|
| 444 |
+
self.pc_sample_number = pc_sample_number
|
| 445 |
+
|
| 446 |
+
self.sample_type = sample_type
|
| 447 |
+
|
| 448 |
+
# self.image_transform = transforms.ToTensor()
|
| 449 |
+
self.image_transform = transforms.Compose([
|
| 450 |
+
transforms.ToTensor(),
|
| 451 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 452 |
+
])
|
| 453 |
+
|
| 454 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 455 |
+
assert (base_resolution & (base_resolution - 1)) == 0, "Resolution must be power of 2"
|
| 456 |
+
assert (min_resolution & (min_resolution - 1)) == 0, "Resolution must be power of 2"
|
| 457 |
+
self.res_levels = [
|
| 458 |
+
2**i for i in range(
|
| 459 |
+
int(np.log2(min_resolution)),
|
| 460 |
+
int(np.log2(base_resolution)) + 1
|
| 461 |
+
)
|
| 462 |
+
]
|
| 463 |
+
|
| 464 |
+
if self.active_voxel_res is not None and self.active_voxel_res not in self.res_levels:
|
| 465 |
+
self.res_levels.append(active_voxel_res)
|
| 466 |
+
self.res_levels.sort()
|
| 467 |
+
|
| 468 |
+
all_obj_files = sorted([f for f in os.listdir(root_dir) if f.endswith(('.obj', '.ply', '.glb'))])
|
| 469 |
+
if not all_obj_files:
|
| 470 |
+
raise ValueError(f"No OBJ files found in {root_dir}")
|
| 471 |
+
|
| 472 |
+
if self.process_img:
|
| 473 |
+
map_file_path = os.path.join(os.path.dirname(self.renders_dir), 'map.json')
|
| 474 |
+
if os.path.exists(map_file_path):
|
| 475 |
+
print(f"Loading pre-computed hash map from {map_file_path}")
|
| 476 |
+
with open(map_file_path, 'r') as f:
|
| 477 |
+
file_map = json.load(f)
|
| 478 |
+
filename_to_hash = {item['filename']: item['sha256'] for item in file_map}
|
| 479 |
+
all_obj_hashes = [filename_to_hash.get(fname) for fname in all_obj_files]
|
| 480 |
+
else:
|
| 481 |
+
print("No hash map found. Calculating SHA256 hashes on the fly... (This may take a moment)")
|
| 482 |
+
all_obj_hashes = []
|
| 483 |
+
for fname in tqdm(all_obj_files, desc="Hashing .obj files"):
|
| 484 |
+
fpath = os.path.join(self.root_dir, fname)
|
| 485 |
+
all_obj_hashes.append(get_sha256(fpath))
|
| 486 |
+
|
| 487 |
+
else:
|
| 488 |
+
print("process_img is False, skipping SHA256 hash calculation.")
|
| 489 |
+
all_obj_hashes = [None] * len(all_obj_files)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
if self.filter_active_voxels and cache_filter_path:
|
| 493 |
+
filtered_list_cache_path = cache_filter_path
|
| 494 |
+
|
| 495 |
+
if os.path.exists(filtered_list_cache_path):
|
| 496 |
+
print(f"Loading filtered BASENAMES from: {filtered_list_cache_path}")
|
| 497 |
+
basename_to_fullname_map = {os.path.splitext(f)[0]: f for f in all_obj_files}
|
| 498 |
+
|
| 499 |
+
with open(filtered_list_cache_path, 'r') as f:
|
| 500 |
+
filtered_basenames = [line.strip() for line in f if line.strip()]
|
| 501 |
+
|
| 502 |
+
self.obj_files = []
|
| 503 |
+
for basename in filtered_basenames:
|
| 504 |
+
if basename in basename_to_fullname_map:
|
| 505 |
+
self.obj_files.append(basename_to_fullname_map[basename])
|
| 506 |
+
else:
|
| 507 |
+
print(f"[WARN] Basename '{basename}' from filter list not found in directory '{self.root_dir}'. Skipping.")
|
| 508 |
+
|
| 509 |
+
file_to_hash_map = dict(zip(all_obj_files, all_obj_hashes))
|
| 510 |
+
self.obj_hashes = [file_to_hash_map.get(fname) for fname in self.obj_files] # 使用 .get 更安全
|
| 511 |
+
|
| 512 |
+
print(f"Loaded and matched {len(self.obj_files)} samples from the filter list.")
|
| 513 |
+
|
| 514 |
+
else:
|
| 515 |
+
print(f"Cache filter file not found: {filtered_list_cache_path}. Proceeding with on-the-fly filtering...")
|
| 516 |
+
|
| 517 |
+
else:
|
| 518 |
+
self.obj_files = all_obj_files
|
| 519 |
+
self.obj_hashes = all_obj_hashes
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
if not self.obj_files:
|
| 523 |
+
raise ValueError(f"No OBJ files found in {root_dir}")
|
| 524 |
+
|
| 525 |
+
self.rembg_session = None
|
| 526 |
+
|
| 527 |
+
def _init_rembg_session_if_needed(self):
|
| 528 |
+
if self.rembg_session is None:
|
| 529 |
+
print(f"Initializing rembg session for worker {os.getpid()}...")
|
| 530 |
+
self.rembg_session = rembg.new_session(model_name='u2net')
|
| 531 |
+
|
| 532 |
+
def preprocess_image(self, input: Image.Image) -> Image.Image:
|
| 533 |
+
self._init_rembg_session_if_needed()
|
| 534 |
+
has_alpha = False
|
| 535 |
+
if input.mode == 'RGBA':
|
| 536 |
+
alpha = np.array(input)[:, :, 3]
|
| 537 |
+
if not np.all(alpha == 255):
|
| 538 |
+
has_alpha = True
|
| 539 |
+
if has_alpha:
|
| 540 |
+
output = input
|
| 541 |
+
else:
|
| 542 |
+
input = input.convert('RGB')
|
| 543 |
+
max_size = max(input.size)
|
| 544 |
+
scale = min(1, 1024 / max_size)
|
| 545 |
+
if scale < 1:
|
| 546 |
+
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
|
| 547 |
+
if getattr(self, 'rembg_session', None) is None:
|
| 548 |
+
self.rembg_session = rembg.new_session('u2net')
|
| 549 |
+
output = rembg.remove(input, session=self.rembg_session)
|
| 550 |
+
output_np = np.array(output)
|
| 551 |
+
alpha = output_np[:, :, 3]
|
| 552 |
+
bbox = np.argwhere(alpha > 0.8 * 255)
|
| 553 |
+
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
|
| 554 |
+
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
|
| 555 |
+
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
|
| 556 |
+
size = int(size * 1.2)
|
| 557 |
+
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
|
| 558 |
+
output = output.crop(bbox) # type: ignore
|
| 559 |
+
output = output.resize((518, 518), Image.Resampling.LANCZOS)
|
| 560 |
+
output = np.array(output).astype(np.float32) / 255
|
| 561 |
+
output = output[:, :, :3] * output[:, :, 3:4]
|
| 562 |
+
output = Image.fromarray((output * 255).astype(np.uint8))
|
| 563 |
+
return output
|
| 564 |
+
|
| 565 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 566 |
+
name = os.path.splitext(self.obj_files[idx])[0]
|
| 567 |
+
cache_path = os.path.join(self.cache_dir, f"{name}_precombined.npz")
|
| 568 |
+
|
| 569 |
+
sha256_hash = self.obj_hashes[idx]
|
| 570 |
+
mesh_render_dir = os.path.join(self.renders_dir, sha256_hash) if sha256_hash else ""
|
| 571 |
+
|
| 572 |
+
image_path = ''
|
| 573 |
+
if mesh_render_dir and os.path.isdir(mesh_render_dir):
|
| 574 |
+
try:
|
| 575 |
+
render_files = [f for f in os.listdir(mesh_render_dir) if f.endswith('.png')]
|
| 576 |
+
if render_files:
|
| 577 |
+
image_path = os.path.join(mesh_render_dir, random.choice(render_files))
|
| 578 |
+
except OSError as e:
|
| 579 |
+
print(f"[WARN] Could not access render directory {mesh_render_dir}: {e}")
|
| 580 |
+
|
| 581 |
+
if self.process_img:
|
| 582 |
+
try:
|
| 583 |
+
if image_path and os.path.exists(image_path):
|
| 584 |
+
image_obj = self.image_transform(self.preprocess_image(Image.open(image_path)).convert('RGB'))
|
| 585 |
+
else:
|
| 586 |
+
image_obj = self.image_transform(Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)).convert('RGB'))
|
| 587 |
+
except Exception as e:
|
| 588 |
+
image_obj = self.image_transform(Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)).convert('RGB'))
|
| 589 |
+
print(f'Error processing image {image_path}: {e}')
|
| 590 |
+
|
| 591 |
+
if os.path.exists(cache_path):
|
| 592 |
+
try:
|
| 593 |
+
loaded = np.load(cache_path, allow_pickle=True)
|
| 594 |
+
data = {
|
| 595 |
+
'original_faces': torch.from_numpy(loaded['original_faces']),
|
| 596 |
+
'original_vertices': torch.from_numpy(loaded['original_vertices']),
|
| 597 |
+
}
|
| 598 |
+
for res in self.res_levels:
|
| 599 |
+
# Load standard voxel data
|
| 600 |
+
if f'combined_voxels_{res}' in loaded:
|
| 601 |
+
data[f'combined_voxels_{res}'] = torch.from_numpy(loaded[f'combined_voxels_{res}'])
|
| 602 |
+
data[f'combined_voxel_labels_{res}'] = torch.from_numpy(loaded[f'combined_voxel_labels_{res}'])
|
| 603 |
+
data[f'gt_combined_endpoints_{res}'] = torch.from_numpy(loaded[f'gt_combined_endpoints_{res}'])
|
| 604 |
+
|
| 605 |
+
data[f'gt_vertex_voxels_{res}'] = torch.from_numpy(loaded[f'gt_vertex_voxels_{res}'])
|
| 606 |
+
data[f'gt_edge_voxels_{res}'] = torch.from_numpy(loaded[f'gt_edge_voxels_{res}'])
|
| 607 |
+
data[f'gt_edge_endpoints_{res}'] = torch.from_numpy(loaded[f'gt_edge_endpoints_{res}'])
|
| 608 |
+
data[f'gt_edge_errors_{res}'] = torch.from_numpy(loaded[f'gt_edge_errors_{res}'])
|
| 609 |
+
|
| 610 |
+
# Load Active Voxels and Point Cloud for Local Pooling
|
| 611 |
+
if res == self.active_voxel_res:
|
| 612 |
+
if f'active_voxels_{res}' in loaded:
|
| 613 |
+
data[f'active_voxels_{res}'] = torch.from_numpy(loaded[f'active_voxels_{res}'])
|
| 614 |
+
if f'point_cloud_{res}' in loaded:
|
| 615 |
+
data[f'point_cloud_{res}'] = torch.from_numpy(loaded[f'point_cloud_{res}'])
|
| 616 |
+
|
| 617 |
+
if self.process_img:
|
| 618 |
+
data['image'] = image_obj
|
| 619 |
+
data['image_path'] = image_path
|
| 620 |
+
return data
|
| 621 |
+
|
| 622 |
+
except Exception as e:
|
| 623 |
+
print(f"[WARN] Corrupted NPZ cache {cache_path}, regenerating... {e}")
|
| 624 |
+
os.remove(cache_path)
|
| 625 |
+
|
| 626 |
+
try:
|
| 627 |
+
mesh_path = os.path.join(self.root_dir, self.obj_files[idx])
|
| 628 |
+
mesh = normalize_mesh(mesh_path)
|
| 629 |
+
if mesh.is_empty or not hasattr(mesh.vertices, 'shape') or mesh.vertices.shape[0] < 3 or not hasattr(mesh.faces, 'shape') or mesh.faces.shape[0] < 1:
|
| 630 |
+
raise ValueError("Invalid or empty mesh")
|
| 631 |
+
except Exception as e:
|
| 632 |
+
print(f"[ERROR] Failed to load mesh: {self.obj_files[idx]} | {e}")
|
| 633 |
+
return self.__getitem__((idx + 1) % len(self))
|
| 634 |
+
|
| 635 |
+
vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
|
| 636 |
+
faces = torch.tensor(mesh.faces, dtype=torch.long)
|
| 637 |
+
|
| 638 |
+
data = {'original_faces': faces.clone(), 'original_vertices': vertices.clone()}
|
| 639 |
+
|
| 640 |
+
for res in self.res_levels:
|
| 641 |
+
quantized = quantize_vertices(vertices, res)
|
| 642 |
+
tmesh = trimesh.Trimesh(vertices=quantized.numpy(), faces=faces.numpy())
|
| 643 |
+
tmesh.merge_vertices()
|
| 644 |
+
|
| 645 |
+
vertex_voxels_raw = torch.from_numpy(tmesh.vertices.astype(np.int32))
|
| 646 |
+
edges_raw = tmesh.edges_unique
|
| 647 |
+
|
| 648 |
+
vertex_labels_raw = torch.zeros(vertex_voxels_raw.shape[0], dtype=torch.long)
|
| 649 |
+
|
| 650 |
+
all_edge_voxels = []
|
| 651 |
+
edge_endpoints = []
|
| 652 |
+
edge_errors = []
|
| 653 |
+
|
| 654 |
+
for u_idx, v_idx in edges_raw:
|
| 655 |
+
p1_grid, p2_grid = vertex_voxels_raw[u_idx].float(), vertex_voxels_raw[v_idx].float()
|
| 656 |
+
v, ep, err = get_voxel_line(p1_grid, p2_grid)
|
| 657 |
+
all_edge_voxels.extend(v)
|
| 658 |
+
edge_endpoints.extend(ep)
|
| 659 |
+
edge_errors.extend(err)
|
| 660 |
+
|
| 661 |
+
if all_edge_voxels:
|
| 662 |
+
edge_voxels_np = np.array(all_edge_voxels, dtype=np.int32)
|
| 663 |
+
edge_endpoints_np = np.array([np.stack(pair) for pair in edge_endpoints], dtype=np.float32)
|
| 664 |
+
edge_errors_np = np.array(edge_errors, dtype=np.float32)
|
| 665 |
+
|
| 666 |
+
unique_edge_voxels_np, first_indices = np.unique(edge_voxels_np, axis=0, return_index=True)
|
| 667 |
+
edge_voxels_raw = torch.from_numpy(unique_edge_voxels_np)
|
| 668 |
+
edge_labels_raw = torch.ones(len(edge_voxels_raw), dtype=torch.long)
|
| 669 |
+
edge_endpoints_raw = torch.from_numpy(edge_endpoints_np[first_indices])
|
| 670 |
+
edge_errors_raw = torch.from_numpy(edge_errors_np[first_indices])
|
| 671 |
+
else:
|
| 672 |
+
edge_voxels_raw = torch.empty(0, 3, dtype=torch.int32)
|
| 673 |
+
edge_labels_raw = torch.empty(0, dtype=torch.long)
|
| 674 |
+
edge_endpoints_raw = torch.empty(0, 2, 3, dtype=torch.float32)
|
| 675 |
+
edge_errors_raw = torch.empty(0, 3, dtype=torch.float32)
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
if res == self.active_voxel_res:
|
| 679 |
+
try:
|
| 680 |
+
if self.sample_type == 'uniform':
|
| 681 |
+
# triposf-style, normilize wrong
|
| 682 |
+
ts_voxels, ts_points = load_quantized_mesh_original(
|
| 683 |
+
mesh_path=os.path.join(self.root_dir, self.obj_files[idx]),
|
| 684 |
+
mesh_load=mesh,
|
| 685 |
+
volume_resolution=res,
|
| 686 |
+
use_normals=True,
|
| 687 |
+
pc_sample_number=self.pc_sample_number,
|
| 688 |
+
)
|
| 689 |
+
else:
|
| 690 |
+
ts_voxels, ts_points = load_quantized_mesh_dora(
|
| 691 |
+
mesh_path=os.path.join(self.root_dir, self.obj_files[idx]),
|
| 692 |
+
mesh_load=mesh,
|
| 693 |
+
volume_resolution=res,
|
| 694 |
+
use_normals=True,
|
| 695 |
+
pc_sample_number=self.pc_sample_number,
|
| 696 |
+
edge_sample_ratio=0.5,
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
# Convert types
|
| 700 |
+
# Voxels from TripoSF are float Tensor (N, 3), convert to int32
|
| 701 |
+
data[f'active_voxels_{res}'] = ts_voxels.int()
|
| 702 |
+
data[f'point_cloud_{res}'] = ts_points
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
except Exception as e:
|
| 706 |
+
print(f"[ERROR] Failed to compute active voxels/points for {name} at res {res}: {e}")
|
| 707 |
+
data[f'active_voxels_{res}'] = torch.empty(0, 3, dtype=torch.int32)
|
| 708 |
+
data[f'point_cloud_{res}'] = torch.empty(0, 6, dtype=torch.float32)
|
| 709 |
+
|
| 710 |
+
combined_voxels = torch.cat([vertex_voxels_raw, edge_voxels_raw], dim=0)
|
| 711 |
+
combined_labels = torch.cat([vertex_labels_raw, edge_labels_raw], dim=0)
|
| 712 |
+
|
| 713 |
+
if combined_voxels.numel() > 0:
|
| 714 |
+
unique_voxels, inverse_indices = torch.unique(combined_voxels, dim=0, return_inverse=True)
|
| 715 |
+
|
| 716 |
+
zero_mask = (combined_labels == 0)
|
| 717 |
+
if zero_mask.any():
|
| 718 |
+
zero_per_unique = torch.zeros(len(unique_voxels), dtype=torch.bool)
|
| 719 |
+
zero_per_unique.scatter_(0, inverse_indices[zero_mask], True)
|
| 720 |
+
final_combined_labels = torch.where(zero_per_unique, 0, 1).long()
|
| 721 |
+
else:
|
| 722 |
+
final_combined_labels = torch.ones(len(unique_voxels), dtype=torch.long)
|
| 723 |
+
|
| 724 |
+
if edge_voxels_raw.numel() > 0:
|
| 725 |
+
edge_endpoint_map = {tuple(coord): ep for coord, ep in zip(edge_voxels_raw.numpy(), edge_endpoints_raw.numpy())}
|
| 726 |
+
|
| 727 |
+
endpoints_arr = np.empty((len(unique_voxels), 2, 3), dtype=np.float32)
|
| 728 |
+
unique_voxels_np = unique_voxels.numpy()
|
| 729 |
+
|
| 730 |
+
for j, coord in enumerate(unique_voxels_np):
|
| 731 |
+
coord_tuple = tuple(coord)
|
| 732 |
+
if coord_tuple in edge_endpoint_map:
|
| 733 |
+
endpoints_arr[j] = edge_endpoint_map[coord_tuple]
|
| 734 |
+
else:
|
| 735 |
+
endpoints_arr[j, 0, :] = coord
|
| 736 |
+
endpoints_arr[j, 1, :] = coord
|
| 737 |
+
final_combined_endpoints = torch.from_numpy(endpoints_arr)
|
| 738 |
+
else:
|
| 739 |
+
final_combined_endpoints = unique_voxels.float().unsqueeze(1).repeat(1, 2, 1)
|
| 740 |
+
else:
|
| 741 |
+
unique_voxels = torch.empty(0, 3, dtype=torch.int32)
|
| 742 |
+
final_combined_labels = torch.empty(0, dtype=torch.long)
|
| 743 |
+
final_combined_endpoints = torch.empty(0, 2, 3, dtype=torch.float32)
|
| 744 |
+
|
| 745 |
+
data[f'combined_voxels_{res}'] = unique_voxels
|
| 746 |
+
data[f'combined_voxel_labels_{res}'] = final_combined_labels
|
| 747 |
+
data[f'gt_combined_endpoints_{res}'] = final_combined_endpoints.reshape(-1, 6)
|
| 748 |
+
|
| 749 |
+
data[f'gt_vertex_voxels_{res}'] = vertex_voxels_raw
|
| 750 |
+
data[f'gt_edge_voxels_{res}'] = edge_voxels_raw
|
| 751 |
+
data[f'gt_edge_endpoints_{res}'] = edge_endpoints_raw.reshape(-1, 6)
|
| 752 |
+
data[f'gt_edge_errors_{res}'] = edge_errors_raw
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
save_dict = {
|
| 756 |
+
'original_faces': data['original_faces'].numpy(),
|
| 757 |
+
'original_vertices': data['original_vertices'].numpy(),
|
| 758 |
+
}
|
| 759 |
+
for res in self.res_levels:
|
| 760 |
+
for key_suffix in [
|
| 761 |
+
'combined_voxels', 'combined_voxel_labels', 'gt_combined_endpoints',
|
| 762 |
+
'gt_vertex_voxels', 'gt_edge_voxels', 'gt_edge_endpoints', 'gt_edge_errors',
|
| 763 |
+
]:
|
| 764 |
+
full_key = f'{key_suffix}_{res}'
|
| 765 |
+
if full_key in data:
|
| 766 |
+
save_dict[full_key] = data[full_key].numpy()
|
| 767 |
+
|
| 768 |
+
if f'active_voxels_{res}' in data:
|
| 769 |
+
save_dict[f'active_voxels_{res}'] = data[f'active_voxels_{res}'].numpy()
|
| 770 |
+
|
| 771 |
+
if f'point_cloud_{res}' in data:
|
| 772 |
+
save_dict[f'point_cloud_{res}'] = data[f'point_cloud_{res}'].numpy()
|
| 773 |
+
|
| 774 |
+
try:
|
| 775 |
+
np.savez_compressed(cache_path, **save_dict)
|
| 776 |
+
except Exception as e:
|
| 777 |
+
print(f"[ERROR] Failed to save cache {cache_path}: {e}")
|
| 778 |
+
if os.path.exists(cache_path): os.remove(cache_path)
|
| 779 |
+
|
| 780 |
+
if self.process_img:
|
| 781 |
+
data['image'] = image_obj
|
| 782 |
+
data['image_path'] = image_path
|
| 783 |
+
|
| 784 |
+
return data
|
| 785 |
+
|
| 786 |
+
def __len__(self) -> int:
|
| 787 |
+
return len(self.obj_files)
|
| 788 |
+
|
| 789 |
+
def collate_fn_pointnet(
|
| 790 |
+
batch: List[Dict[str, torch.Tensor]],
|
| 791 |
+
) -> Dict[str, torch.Tensor]:
|
| 792 |
+
|
| 793 |
+
if not batch:
|
| 794 |
+
return {}
|
| 795 |
+
|
| 796 |
+
batch = [b for b in batch if b is not None]
|
| 797 |
+
if not batch:
|
| 798 |
+
return {}
|
| 799 |
+
|
| 800 |
+
collated = {
|
| 801 |
+
'original_faces': [b['original_faces'] for b in batch],
|
| 802 |
+
'original_vertices': [b['original_vertices'] for b in batch],
|
| 803 |
+
'image_path': [b['image_path'] for b in batch],
|
| 804 |
+
}
|
| 805 |
+
|
| 806 |
+
if 'image' in batch[0] and batch[0]['image'] is not None:
|
| 807 |
+
collated['image'] = torch.stack([b['image'] for b in batch])
|
| 808 |
+
|
| 809 |
+
res_levels = []
|
| 810 |
+
for k in batch[0].keys():
|
| 811 |
+
if k.startswith('gt_vertex_voxels_'):
|
| 812 |
+
try:
|
| 813 |
+
res_levels.append(int(k.split('_')[-1]))
|
| 814 |
+
except ValueError:
|
| 815 |
+
pass
|
| 816 |
+
res_levels.sort()
|
| 817 |
+
|
| 818 |
+
for res in res_levels:
|
| 819 |
+
all_active_voxels_list = []
|
| 820 |
+
all_point_clouds_list = []
|
| 821 |
+
|
| 822 |
+
all_combined_voxels_list = []
|
| 823 |
+
all_combined_labels_list = []
|
| 824 |
+
all_vertex_voxels_only = []
|
| 825 |
+
all_edge_voxels_only = []
|
| 826 |
+
all_edge_endpoints_only = []
|
| 827 |
+
all_combined_endpoints = []
|
| 828 |
+
all_combined_errors_list = []
|
| 829 |
+
layout = []
|
| 830 |
+
|
| 831 |
+
vtx_offset = 0
|
| 832 |
+
adj_flat_offset = 0
|
| 833 |
+
start_idx = 0
|
| 834 |
+
|
| 835 |
+
# Attempt to find device from first tensor
|
| 836 |
+
device = torch.device('cpu')
|
| 837 |
+
for v in batch[0].values():
|
| 838 |
+
if isinstance(v, torch.Tensor):
|
| 839 |
+
device = v.device
|
| 840 |
+
break
|
| 841 |
+
|
| 842 |
+
for i, sample in enumerate(batch):
|
| 843 |
+
vertex_voxels = sample.get(f'gt_vertex_voxels_{res}', torch.empty(0,3,dtype=torch.int32)).to(device)
|
| 844 |
+
vertex_labels = torch.zeros(vertex_voxels.shape[0], dtype=torch.long, device=device)
|
| 845 |
+
edge_voxels = sample.get(f'gt_edge_voxels_{res}', torch.empty(0,3,dtype=torch.int32)).to(device)
|
| 846 |
+
edge_labels = torch.ones(edge_voxels.shape[0], dtype=torch.long, device=device)
|
| 847 |
+
edge_endpoints= sample.get(f'gt_edge_endpoints_{res}', torch.empty(0,6,dtype=torch.float32)).to(device)
|
| 848 |
+
edge_errors = sample.get(f'gt_edge_errors_{res}', torch.empty(0,3,dtype=torch.float32)).to(device)
|
| 849 |
+
|
| 850 |
+
vertex_errors = sample.get(f'gt_vertex_errors_{res}', torch.zeros_like(vertex_voxels, dtype=torch.float32)).to(device)
|
| 851 |
+
|
| 852 |
+
if vertex_voxels.numel() > 0:
|
| 853 |
+
idx = torch.full((vertex_voxels.shape[0],1), i, dtype=torch.int32, device=device)
|
| 854 |
+
all_vertex_voxels_only.append(torch.cat([idx, vertex_voxels], dim=1))
|
| 855 |
+
|
| 856 |
+
if edge_voxels.numel() > 0:
|
| 857 |
+
idx = torch.full((edge_voxels.shape[0],1), i, dtype=torch.int32, device=device)
|
| 858 |
+
all_edge_voxels_only.append(torch.cat([idx, edge_voxels], dim=1))
|
| 859 |
+
all_edge_endpoints_only.append(
|
| 860 |
+
torch.cat([idx.to(torch.float32), edge_endpoints], dim=1))
|
| 861 |
+
|
| 862 |
+
if vertex_voxels.numel() + edge_voxels.numel() > 0:
|
| 863 |
+
combined_voxels = torch.cat([vertex_voxels, edge_voxels], dim=0)
|
| 864 |
+
combined_labels = torch.cat([vertex_labels, edge_labels], dim=0)
|
| 865 |
+
|
| 866 |
+
endpoints = torch.zeros(combined_voxels.size(0), 6, dtype=torch.float32, device=device)
|
| 867 |
+
if edge_voxels.numel() > 0:
|
| 868 |
+
endpoints[-edge_voxels.size(0):] = edge_endpoints
|
| 869 |
+
if vertex_voxels.numel() > 0:
|
| 870 |
+
endpoints[:vertex_voxels.size(0)] = vertex_voxels.repeat(1,2).float()
|
| 871 |
+
|
| 872 |
+
combined_errors = torch.cat([vertex_errors, edge_errors], dim=0)
|
| 873 |
+
|
| 874 |
+
batch_idx_int = torch.full((combined_voxels.shape[0],1), i, dtype=torch.int32, device=device)
|
| 875 |
+
all_combined_voxels_list.append(torch.cat([batch_idx_int, combined_voxels], dim=1))
|
| 876 |
+
all_combined_labels_list.append(combined_labels)
|
| 877 |
+
|
| 878 |
+
batch_idx_float = batch_idx_int.to(torch.float32)
|
| 879 |
+
all_combined_endpoints.append(torch.cat([batch_idx_float, endpoints], dim=1))
|
| 880 |
+
all_combined_errors_list.append(torch.cat([batch_idx_float, combined_errors], dim=1))
|
| 881 |
+
|
| 882 |
+
layout.append(slice(start_idx, start_idx + combined_voxels.shape[0]))
|
| 883 |
+
start_idx += combined_voxels.shape[0]
|
| 884 |
+
else:
|
| 885 |
+
layout.append(slice(start_idx, start_idx))
|
| 886 |
+
|
| 887 |
+
# Active Voxels (Sparse Coords)
|
| 888 |
+
active_voxels = sample.get(f'active_voxels_{res}', torch.empty(0, 3, dtype=torch.int32)).to(device)
|
| 889 |
+
if active_voxels.numel() > 0:
|
| 890 |
+
idx = torch.full((active_voxels.shape[0], 1), i, dtype=torch.int32, device=device)
|
| 891 |
+
all_active_voxels_list.append(torch.cat([idx, active_voxels], dim=1))
|
| 892 |
+
|
| 893 |
+
# ==========================================
|
| 894 |
+
# Modified Section: Collect Point Clouds
|
| 895 |
+
# ==========================================
|
| 896 |
+
# pc = sample.get(f'point_cloud_{res}', torch.empty(0, 6, dtype=torch.float32)).to(device)
|
| 897 |
+
pc = sample.get(f'point_cloud_{res}', torch.empty(0, 15, dtype=torch.float32)).to(device)
|
| 898 |
+
# We expect all samples to have point clouds if res == active_voxel_res
|
| 899 |
+
if pc.numel() > 0:
|
| 900 |
+
all_point_clouds_list.append(pc)
|
| 901 |
+
|
| 902 |
+
collated[f'layout_{res}'] = layout
|
| 903 |
+
|
| 904 |
+
def cat_or_empty(lst, shape, dtype):
|
| 905 |
+
return torch.cat(lst, dim=0) if lst else torch.empty(shape, dtype=dtype, device=device)
|
| 906 |
+
|
| 907 |
+
collated[f'combined_voxels_{res}'] = cat_or_empty(all_combined_voxels_list,(0,4),torch.int32)
|
| 908 |
+
collated[f'combined_voxel_labels_{res}'] = cat_or_empty(all_combined_labels_list,(0,),torch.long)
|
| 909 |
+
collated[f'gt_vertex_voxels_{res}'] = cat_or_empty(all_vertex_voxels_only,(0,4),torch.int32)
|
| 910 |
+
collated[f'gt_edge_voxels_{res}'] = cat_or_empty(all_edge_voxels_only,(0,4),torch.int32)
|
| 911 |
+
collated[f'gt_edge_endpoints_{res}'] = cat_or_empty(all_edge_endpoints_only,(0,7),torch.float32)
|
| 912 |
+
collated[f'gt_combined_endpoints_{res}'] = cat_or_empty(all_combined_endpoints,(0,7),torch.float32)
|
| 913 |
+
collated[f'gt_combined_errors_{res}'] = cat_or_empty(all_combined_errors_list,(0,4),torch.float32)
|
| 914 |
+
|
| 915 |
+
collated[f'active_voxels_{res}'] = cat_or_empty(all_active_voxels_list, (0, 4), torch.int32)
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
if all_point_clouds_list:
|
| 919 |
+
collated[f'point_cloud_{res}'] = torch.stack(all_point_clouds_list, dim=0)
|
| 920 |
+
else:
|
| 921 |
+
# collated[f'point_cloud_{res}'] = torch.empty((0, 6), dtype=torch.float32, device=device)
|
| 922 |
+
collated[f'point_cloud_{res}'] = torch.empty((0, 15), dtype=torch.float32, device=device)
|
| 923 |
+
|
| 924 |
+
return collated
|
dataset_triposf_head.py
ADDED
|
@@ -0,0 +1,1000 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from typing import *
|
| 4 |
+
import trimesh
|
| 5 |
+
import numpy as np
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
# from utils import quantize_vertices
|
| 9 |
+
from utils import get_voxel_line
|
| 10 |
+
import random
|
| 11 |
+
import hashlib
|
| 12 |
+
import json
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import rembg
|
| 17 |
+
import open3d as o3d
|
| 18 |
+
from trimesh import grouping
|
| 19 |
+
|
| 20 |
+
def normalize_mesh(mesh_path):
|
| 21 |
+
scene = trimesh.load(mesh_path, process=False, force='scene')
|
| 22 |
+
meshes = []
|
| 23 |
+
for node_name in scene.graph.nodes_geometry:
|
| 24 |
+
geom_name = scene.graph[node_name][1]
|
| 25 |
+
geometry = scene.geometry[geom_name]
|
| 26 |
+
transform = scene.graph[node_name][0]
|
| 27 |
+
if isinstance(geometry, trimesh.Trimesh):
|
| 28 |
+
geometry.apply_transform(transform)
|
| 29 |
+
meshes.append(geometry)
|
| 30 |
+
|
| 31 |
+
mesh = trimesh.util.concatenate(meshes)
|
| 32 |
+
|
| 33 |
+
center = mesh.bounding_box.centroid
|
| 34 |
+
mesh.apply_translation(-center)
|
| 35 |
+
scale = max(mesh.bounding_box.extents)
|
| 36 |
+
mesh.apply_scale(2.0 / scale * 0.5)
|
| 37 |
+
|
| 38 |
+
return mesh
|
| 39 |
+
|
| 40 |
+
def quantize_vertices(vertices: torch.Tensor, res: int):
|
| 41 |
+
"""
|
| 42 |
+
Quantize normalized vertices (range approx [-0.5, 0.5]) to integer grid [0, res-1].
|
| 43 |
+
"""
|
| 44 |
+
normalized = vertices + 0.5
|
| 45 |
+
scaled = normalized * res
|
| 46 |
+
quantized = torch.floor(scaled).clamp(0, res - 1).int()
|
| 47 |
+
return quantized
|
| 48 |
+
|
| 49 |
+
def sample_edges_dora(tm_mesh, n_samples):
|
| 50 |
+
adj_faces = tm_mesh.face_adjacency
|
| 51 |
+
adj_edges = tm_mesh.face_adjacency_edges
|
| 52 |
+
|
| 53 |
+
internal_data = None
|
| 54 |
+
if len(adj_faces) > 0:
|
| 55 |
+
n0 = tm_mesh.face_normals[adj_faces[:, 0]]
|
| 56 |
+
n1 = tm_mesh.face_normals[adj_faces[:, 1]]
|
| 57 |
+
sum_normals = n0 + n1
|
| 58 |
+
norms = np.linalg.norm(sum_normals, axis=1, keepdims=True)
|
| 59 |
+
norms[norms < 1e-6] = 1.0
|
| 60 |
+
int_normals = sum_normals / norms
|
| 61 |
+
|
| 62 |
+
int_v_start = tm_mesh.vertices[adj_edges[:, 0]]
|
| 63 |
+
int_v_end = tm_mesh.vertices[adj_edges[:, 1]]
|
| 64 |
+
|
| 65 |
+
faces_pair = tm_mesh.faces[adj_faces]
|
| 66 |
+
sum_face_indices = np.sum(faces_pair, axis=2)
|
| 67 |
+
sum_edge_indices = np.sum(adj_edges, axis=1)
|
| 68 |
+
|
| 69 |
+
unique_idx_0 = sum_face_indices[:, 0] - sum_edge_indices
|
| 70 |
+
unique_idx_1 = sum_face_indices[:, 1] - sum_edge_indices
|
| 71 |
+
|
| 72 |
+
v_unique_0 = tm_mesh.vertices[unique_idx_0]
|
| 73 |
+
v_unique_1 = tm_mesh.vertices[unique_idx_1]
|
| 74 |
+
int_v_virtual = (v_unique_0 + v_unique_1) * 0.5
|
| 75 |
+
|
| 76 |
+
internal_data = (int_v_start, int_v_end, int_normals, int_v_virtual)
|
| 77 |
+
|
| 78 |
+
edges_sorted = tm_mesh.edges_sorted
|
| 79 |
+
|
| 80 |
+
if len(edges_sorted) == 0:
|
| 81 |
+
boundary_data = None
|
| 82 |
+
else:
|
| 83 |
+
boundary_group = grouping.group_rows(edges_sorted, require_count=1)
|
| 84 |
+
|
| 85 |
+
boundary_data = None
|
| 86 |
+
if len(boundary_group) > 0:
|
| 87 |
+
boundary_indices = np.concatenate([np.atleast_1d(g) for g in boundary_group])
|
| 88 |
+
|
| 89 |
+
boundary_face_indices = boundary_indices // 3
|
| 90 |
+
|
| 91 |
+
bnd_normals = tm_mesh.face_normals[boundary_face_indices]
|
| 92 |
+
|
| 93 |
+
bnd_edge_v_indices = edges_sorted[boundary_indices]
|
| 94 |
+
bnd_v_start = tm_mesh.vertices[bnd_edge_v_indices[:, 0]]
|
| 95 |
+
bnd_v_end = tm_mesh.vertices[bnd_edge_v_indices[:, 1]]
|
| 96 |
+
|
| 97 |
+
boundary_face_v_indices = tm_mesh.faces[boundary_face_indices]
|
| 98 |
+
|
| 99 |
+
sum_face = np.sum(boundary_face_v_indices, axis=1)
|
| 100 |
+
sum_edge = np.sum(bnd_edge_v_indices, axis=1)
|
| 101 |
+
unique_idx = sum_face - sum_edge
|
| 102 |
+
|
| 103 |
+
bnd_v_virtual = tm_mesh.vertices[unique_idx]
|
| 104 |
+
|
| 105 |
+
boundary_data = (bnd_v_start, bnd_v_end, bnd_normals, bnd_v_virtual)
|
| 106 |
+
|
| 107 |
+
if internal_data is None and boundary_data is None:
|
| 108 |
+
return None, None, None
|
| 109 |
+
|
| 110 |
+
parts_start, parts_end, parts_norm, parts_virt = [], [], [], []
|
| 111 |
+
|
| 112 |
+
if internal_data is not None:
|
| 113 |
+
parts_start.append(internal_data[0])
|
| 114 |
+
parts_end.append(internal_data[1])
|
| 115 |
+
parts_norm.append(internal_data[2])
|
| 116 |
+
parts_virt.append(internal_data[3])
|
| 117 |
+
|
| 118 |
+
if boundary_data is not None:
|
| 119 |
+
parts_start.append(boundary_data[0])
|
| 120 |
+
parts_end.append(boundary_data[1])
|
| 121 |
+
parts_norm.append(boundary_data[2])
|
| 122 |
+
parts_virt.append(boundary_data[3])
|
| 123 |
+
|
| 124 |
+
if not parts_start:
|
| 125 |
+
return None, None, None
|
| 126 |
+
|
| 127 |
+
all_v_start = np.concatenate(parts_start, axis=0)
|
| 128 |
+
all_v_end = np.concatenate(parts_end, axis=0)
|
| 129 |
+
all_normals = np.concatenate(parts_norm, axis=0)
|
| 130 |
+
all_v_virtual = np.concatenate(parts_virt, axis=0)
|
| 131 |
+
|
| 132 |
+
edge_vectors = all_v_end - all_v_start
|
| 133 |
+
edge_lengths = np.linalg.norm(edge_vectors, axis=1)
|
| 134 |
+
total_length = np.sum(edge_lengths)
|
| 135 |
+
|
| 136 |
+
if total_length < 1e-9:
|
| 137 |
+
probs = np.ones(len(edge_lengths)) / len(edge_lengths)
|
| 138 |
+
else:
|
| 139 |
+
probs = edge_lengths / total_length
|
| 140 |
+
probs = probs / probs.sum()
|
| 141 |
+
|
| 142 |
+
chosen_indices = np.random.choice(len(edge_lengths), size=n_samples, p=probs)
|
| 143 |
+
|
| 144 |
+
t = np.random.rand(n_samples, 1)
|
| 145 |
+
|
| 146 |
+
sel_v_start = all_v_start[chosen_indices]
|
| 147 |
+
sel_v_end = all_v_end[chosen_indices]
|
| 148 |
+
sel_normals = all_normals[chosen_indices]
|
| 149 |
+
sel_v_virtual = all_v_virtual[chosen_indices]
|
| 150 |
+
|
| 151 |
+
sampled_points = sel_v_start + (sel_v_end - sel_v_start) * t
|
| 152 |
+
|
| 153 |
+
vertex_triplets = np.stack([sel_v_start, sel_v_end, sel_v_virtual], axis=1).astype(np.float32)
|
| 154 |
+
|
| 155 |
+
return sampled_points.astype(np.float32), sel_normals.astype(np.float32), vertex_triplets
|
| 156 |
+
|
| 157 |
+
def sample_edges_hunyuan(tm_mesh, n_samples):
|
| 158 |
+
if tm_mesh.vertex_normals is None:
|
| 159 |
+
tm_mesh.compute_vertex_normals()
|
| 160 |
+
|
| 161 |
+
V = tm_mesh.vertices
|
| 162 |
+
F = tm_mesh.faces
|
| 163 |
+
VN = tm_mesh.vertex_normals
|
| 164 |
+
|
| 165 |
+
# Edge 0: v0 -> v1 (Virtual: v2)
|
| 166 |
+
# Edge 1: v1 -> v2 (Virtual: v0)
|
| 167 |
+
# Edge 2: v2 -> v0 (Virtual: v1)
|
| 168 |
+
|
| 169 |
+
idx_start = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
|
| 170 |
+
idx_end = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
|
| 171 |
+
idx_virt = np.concatenate([F[:, 2], F[:, 0], F[:, 1]])
|
| 172 |
+
|
| 173 |
+
v_start = V[idx_start]
|
| 174 |
+
v_end = V[idx_end]
|
| 175 |
+
|
| 176 |
+
edge_vectors = v_end - v_start
|
| 177 |
+
edge_lengths = np.linalg.norm(edge_vectors, axis=1)
|
| 178 |
+
|
| 179 |
+
total_length = np.sum(edge_lengths)
|
| 180 |
+
|
| 181 |
+
if total_length < 1e-9:
|
| 182 |
+
probs = np.ones(len(edge_lengths)) / len(edge_lengths)
|
| 183 |
+
else:
|
| 184 |
+
probs = edge_lengths / total_length
|
| 185 |
+
|
| 186 |
+
chosen_indices = np.random.choice(len(edge_lengths), size=n_samples, p=probs)
|
| 187 |
+
|
| 188 |
+
sel_v_start = v_start[chosen_indices]
|
| 189 |
+
sel_v_end = v_end[chosen_indices]
|
| 190 |
+
sel_v_virt = V[idx_virt[chosen_indices]]
|
| 191 |
+
|
| 192 |
+
sel_vn_start = VN[idx_start[chosen_indices]]
|
| 193 |
+
sel_vn_end = VN[idx_end[chosen_indices]]
|
| 194 |
+
|
| 195 |
+
t = np.random.rand(n_samples, 1)
|
| 196 |
+
|
| 197 |
+
sampled_points = sel_v_start + (sel_v_end - sel_v_start) * t
|
| 198 |
+
|
| 199 |
+
sampled_normals = sel_vn_start * (1 - t) + sel_vn_end * t
|
| 200 |
+
|
| 201 |
+
norm_vals = np.linalg.norm(sampled_normals, axis=1, keepdims=True)
|
| 202 |
+
norm_vals[norm_vals < 1e-6] = 1.0
|
| 203 |
+
sampled_normals = sampled_normals / norm_vals
|
| 204 |
+
|
| 205 |
+
vertex_triplets = np.stack([sel_v_start, sel_v_end, sel_v_virt], axis=1)
|
| 206 |
+
|
| 207 |
+
return sampled_points.astype(np.float32), sampled_normals.astype(np.float32), vertex_triplets.astype(np.float32)
|
| 208 |
+
|
| 209 |
+
def load_quantized_mesh_dora(
|
| 210 |
+
mesh_path,
|
| 211 |
+
mesh_load=None,
|
| 212 |
+
volume_resolution=256,
|
| 213 |
+
use_normals=True,
|
| 214 |
+
pc_sample_number=4096000,
|
| 215 |
+
edge_sample_ratio=0.2
|
| 216 |
+
):
|
| 217 |
+
cube_dilate = np.array([
|
| 218 |
+
[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, -1], [0, -1, 0], [0, 1, 1], [0, -1, 1], [0, 1, -1], [0, -1, -1],
|
| 219 |
+
[1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 0, -1], [1, -1, 0], [1, 1, 1], [1, -1, 1], [1, 1, -1], [1, -1, -1],
|
| 220 |
+
[-1, 0, 0], [-1, 0, 1], [-1, 1, 0], [-1, 0, -1], [-1, -1, 0], [-1, 1, 1], [-1, -1, 1], [-1, 1, -1], [-1, -1, -1]
|
| 221 |
+
]) / (volume_resolution * 4 - 1)
|
| 222 |
+
|
| 223 |
+
if mesh_load is None:
|
| 224 |
+
mesh_o3d = o3d.io.read_triangle_mesh(mesh_path)
|
| 225 |
+
vertices = np.clip(np.asarray(mesh_o3d.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
|
| 226 |
+
faces = np.asarray(mesh_o3d.triangles)
|
| 227 |
+
mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices)
|
| 228 |
+
else:
|
| 229 |
+
vertices = np.clip(np.asarray(mesh_load.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
|
| 230 |
+
faces = np.asarray(mesh_load.faces)
|
| 231 |
+
mesh_o3d = o3d.geometry.TriangleMesh()
|
| 232 |
+
mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices)
|
| 233 |
+
mesh_o3d.triangles = o3d.utility.Vector3iVector(faces)
|
| 234 |
+
|
| 235 |
+
tm_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
|
| 236 |
+
|
| 237 |
+
n_edge_samples = int(pc_sample_number * edge_sample_ratio)
|
| 238 |
+
p_edge, n_edge, triplets_edge = sample_edges_dora(tm_mesh, n_edge_samples)
|
| 239 |
+
|
| 240 |
+
if p_edge is None:
|
| 241 |
+
# print('p_edge is none!')
|
| 242 |
+
n_surface_samples = pc_sample_number
|
| 243 |
+
else:
|
| 244 |
+
# print('p_edge is right!')
|
| 245 |
+
n_surface_samples = pc_sample_number - n_edge_samples
|
| 246 |
+
|
| 247 |
+
p_surf, idx_surf = tm_mesh.sample(n_surface_samples, return_index=True)
|
| 248 |
+
p_surf = p_surf.astype(np.float32)
|
| 249 |
+
n_surf = tm_mesh.face_normals[idx_surf].astype(np.float32)
|
| 250 |
+
|
| 251 |
+
v_indices_surf = faces[idx_surf]
|
| 252 |
+
triplets_surf = vertices[v_indices_surf]
|
| 253 |
+
|
| 254 |
+
if p_edge is None:
|
| 255 |
+
final_points = p_surf
|
| 256 |
+
final_normals = n_surf
|
| 257 |
+
final_triplets = triplets_surf
|
| 258 |
+
else:
|
| 259 |
+
final_points = np.concatenate([p_surf, p_edge.astype(np.float32)], axis=0)
|
| 260 |
+
if use_normals:
|
| 261 |
+
final_normals = np.concatenate([n_surf, n_edge.astype(np.float32)], axis=0)
|
| 262 |
+
|
| 263 |
+
final_triplets = np.concatenate([triplets_surf, triplets_edge], axis=0)
|
| 264 |
+
|
| 265 |
+
voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
|
| 266 |
+
mesh_o3d,
|
| 267 |
+
voxel_size=1. / volume_resolution,
|
| 268 |
+
min_bound=[-0.5, -0.5, -0.5],
|
| 269 |
+
max_bound=[0.5, 0.5, 0.5]
|
| 270 |
+
)
|
| 271 |
+
voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
|
| 272 |
+
|
| 273 |
+
voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
|
| 274 |
+
o3d.geometry.PointCloud(
|
| 275 |
+
o3d.utility.Vector3dVector(
|
| 276 |
+
np.clip(
|
| 277 |
+
(final_points[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
|
| 278 |
+
-0.5 + 1e-6, 0.5 - 1e-6)
|
| 279 |
+
)
|
| 280 |
+
),
|
| 281 |
+
voxel_size=1. / volume_resolution,
|
| 282 |
+
min_bound=[-0.5, -0.5, -0.5],
|
| 283 |
+
max_bound=[0.5, 0.5, 0.5]
|
| 284 |
+
)
|
| 285 |
+
voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
|
| 286 |
+
voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
|
| 287 |
+
|
| 288 |
+
features_list = [torch.from_numpy(final_points)]
|
| 289 |
+
|
| 290 |
+
if use_normals:
|
| 291 |
+
features_list.append(torch.from_numpy(final_normals))
|
| 292 |
+
|
| 293 |
+
view_dtype = np.dtype((np.void, final_triplets.dtype.itemsize * final_triplets.shape[-1]))
|
| 294 |
+
v_view = final_triplets.view(view_dtype).squeeze(-1)
|
| 295 |
+
|
| 296 |
+
sort_idx = np.argsort(v_view, axis=1)
|
| 297 |
+
|
| 298 |
+
batch_indices = np.arange(final_triplets.shape[0])[:, None]
|
| 299 |
+
v_sorted = final_triplets[batch_indices, sort_idx]
|
| 300 |
+
|
| 301 |
+
v1 = v_sorted[:, 0, :]
|
| 302 |
+
v2 = v_sorted[:, 1, :]
|
| 303 |
+
v3 = v_sorted[:, 2, :]
|
| 304 |
+
|
| 305 |
+
dir1 = v1 - final_points
|
| 306 |
+
dir2 = v2 - final_points
|
| 307 |
+
dir3 = v3 - final_points
|
| 308 |
+
|
| 309 |
+
features_list.append(torch.Tensor(dir1.astype(np.float32)))
|
| 310 |
+
features_list.append(torch.Tensor(dir2.astype(np.float32)))
|
| 311 |
+
features_list.append(torch.Tensor(dir3.astype(np.float32)))
|
| 312 |
+
|
| 313 |
+
points_sample = torch.cat(features_list, axis=-1)
|
| 314 |
+
|
| 315 |
+
return voxels, points_sample
|
| 316 |
+
|
| 317 |
+
def load_quantized_mesh_original(
|
| 318 |
+
mesh_path,
|
| 319 |
+
mesh_load=None,
|
| 320 |
+
volume_resolution=256,
|
| 321 |
+
use_normals=True,
|
| 322 |
+
pc_sample_number=4096000,
|
| 323 |
+
):
|
| 324 |
+
cube_dilate = np.array(
|
| 325 |
+
[
|
| 326 |
+
[0, 0, 0],
|
| 327 |
+
[0, 0, 1],
|
| 328 |
+
[0, 1, 0],
|
| 329 |
+
[0, 0, -1],
|
| 330 |
+
[0, -1, 0],
|
| 331 |
+
[0, 1, 1],
|
| 332 |
+
[0, -1, 1],
|
| 333 |
+
[0, 1, -1],
|
| 334 |
+
[0, -1, -1],
|
| 335 |
+
|
| 336 |
+
[1, 0, 0],
|
| 337 |
+
[1, 0, 1],
|
| 338 |
+
[1, 1, 0],
|
| 339 |
+
[1, 0, -1],
|
| 340 |
+
[1, -1, 0],
|
| 341 |
+
[1, 1, 1],
|
| 342 |
+
[1, -1, 1],
|
| 343 |
+
[1, 1, -1],
|
| 344 |
+
[1, -1, -1],
|
| 345 |
+
|
| 346 |
+
[-1, 0, 0],
|
| 347 |
+
[-1, 0, 1],
|
| 348 |
+
[-1, 1, 0],
|
| 349 |
+
[-1, 0, -1],
|
| 350 |
+
[-1, -1, 0],
|
| 351 |
+
[-1, 1, 1],
|
| 352 |
+
[-1, -1, 1],
|
| 353 |
+
[-1, 1, -1],
|
| 354 |
+
[-1, -1, -1],
|
| 355 |
+
]
|
| 356 |
+
) / (volume_resolution * 4 - 1)
|
| 357 |
+
|
| 358 |
+
if mesh_load is None:
|
| 359 |
+
mesh = o3d.io.read_triangle_mesh(mesh_path)
|
| 360 |
+
vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
|
| 361 |
+
faces = np.asarray(mesh.triangles)
|
| 362 |
+
mesh.vertices = o3d.utility.Vector3dVector(vertices)
|
| 363 |
+
else:
|
| 364 |
+
vertices = np.clip(np.asarray(mesh_load.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
|
| 365 |
+
faces = np.asarray(mesh_load.faces)
|
| 366 |
+
|
| 367 |
+
mesh = o3d.geometry.TriangleMesh()
|
| 368 |
+
mesh.vertices = o3d.utility.Vector3dVector(vertices)
|
| 369 |
+
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
|
| 373 |
+
mesh,
|
| 374 |
+
voxel_size=1. / volume_resolution,
|
| 375 |
+
min_bound=[-0.5, -0.5, -0.5],
|
| 376 |
+
max_bound=[0.5, 0.5, 0.5]
|
| 377 |
+
)
|
| 378 |
+
voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
|
| 379 |
+
|
| 380 |
+
points_normals_sample = trimesh.Trimesh(vertices=vertices, faces=faces).sample(count=pc_sample_number, return_index=True)
|
| 381 |
+
points_xyz_np = points_normals_sample[0].astype(np.float32)
|
| 382 |
+
points_sample = points_normals_sample[0].astype(np.float32)
|
| 383 |
+
face_indices = points_normals_sample[1]
|
| 384 |
+
|
| 385 |
+
voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
|
| 386 |
+
o3d.geometry.PointCloud(
|
| 387 |
+
o3d.utility.Vector3dVector(
|
| 388 |
+
np.clip(
|
| 389 |
+
(points_sample[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
|
| 390 |
+
-0.5 + 1e-6, 0.5 - 1e-6)
|
| 391 |
+
)
|
| 392 |
+
),
|
| 393 |
+
voxel_size=1. / volume_resolution,
|
| 394 |
+
min_bound=[-0.5, -0.5, -0.5],
|
| 395 |
+
max_bound=[0.5, 0.5, 0.5]
|
| 396 |
+
)
|
| 397 |
+
voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
|
| 398 |
+
voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
|
| 399 |
+
|
| 400 |
+
features_list = [torch.from_numpy(points_xyz_np)]
|
| 401 |
+
|
| 402 |
+
if use_normals:
|
| 403 |
+
mesh.compute_triangle_normals()
|
| 404 |
+
normals_sample = np.asarray(
|
| 405 |
+
mesh.triangle_normals
|
| 406 |
+
)[points_normals_sample[1]].astype(np.float32)
|
| 407 |
+
# points_sample = torch.cat((torch.Tensor(points_sample), torch.Tensor(normals_sample)), axis=-1)
|
| 408 |
+
features_list.append(torch.from_numpy(normals_sample))
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
########################################
|
| 412 |
+
# add direction to three vtx
|
| 413 |
+
########################################
|
| 414 |
+
|
| 415 |
+
## wo sort
|
| 416 |
+
# sampled_face_v_indices = faces[face_indices]
|
| 417 |
+
# v1 = vertices[sampled_face_v_indices[:, 0]]
|
| 418 |
+
# v2 = vertices[sampled_face_v_indices[:, 1]]
|
| 419 |
+
# v3 = vertices[sampled_face_v_indices[:, 2]]
|
| 420 |
+
|
| 421 |
+
# w sort
|
| 422 |
+
sampled_face_v_indices = faces[face_indices]
|
| 423 |
+
v_batch = np.stack([
|
| 424 |
+
vertices[sampled_face_v_indices[:, 0]],
|
| 425 |
+
vertices[sampled_face_v_indices[:, 1]],
|
| 426 |
+
vertices[sampled_face_v_indices[:, 2]]
|
| 427 |
+
], axis=1)
|
| 428 |
+
|
| 429 |
+
view_dtype = np.dtype((np.void, v_batch.dtype.itemsize * v_batch.shape[-1]))
|
| 430 |
+
v_view = v_batch.view(view_dtype).squeeze(-1) # 变成 (N, 3) 的 void
|
| 431 |
+
|
| 432 |
+
sort_idx = np.argsort(v_view, axis=1) # (N, 3)
|
| 433 |
+
|
| 434 |
+
batch_indices = np.arange(v_batch.shape[0])[:, None]
|
| 435 |
+
v_sorted = v_batch[batch_indices, sort_idx] # (N, 3, 3)
|
| 436 |
+
|
| 437 |
+
v1 = v_sorted[:, 0, :]
|
| 438 |
+
v2 = v_sorted[:, 1, :]
|
| 439 |
+
v3 = v_sorted[:, 2, :]
|
| 440 |
+
# --------------------
|
| 441 |
+
|
| 442 |
+
dir1 = v1 - points_xyz_np
|
| 443 |
+
dir2 = v2 - points_xyz_np
|
| 444 |
+
dir3 = v3 - points_xyz_np
|
| 445 |
+
|
| 446 |
+
features_list.append(torch.Tensor(dir1.astype(np.float32)))
|
| 447 |
+
features_list.append(torch.Tensor(dir2.astype(np.float32)))
|
| 448 |
+
features_list.append(torch.Tensor(dir3.astype(np.float32)))
|
| 449 |
+
|
| 450 |
+
points_sample = torch.cat(features_list, axis=-1)
|
| 451 |
+
########################################
|
| 452 |
+
# add direction to three vtx
|
| 453 |
+
########################################
|
| 454 |
+
|
| 455 |
+
return voxels, points_sample
|
| 456 |
+
|
| 457 |
+
def get_sha256(filepath: str) -> str:
|
| 458 |
+
sha256_hash = hashlib.sha256()
|
| 459 |
+
with open(filepath, "rb") as f:
|
| 460 |
+
for byte_block in iter(lambda: f.read(4096), b""):
|
| 461 |
+
sha256_hash.update(byte_block)
|
| 462 |
+
return sha256_hash.hexdigest()
|
| 463 |
+
|
| 464 |
+
class VoxelVertexDataset_edge(Dataset):
|
| 465 |
+
def __init__(self,
|
| 466 |
+
root_dir: str,
|
| 467 |
+
base_resolution: int = 256,
|
| 468 |
+
min_resolution: int = 128,
|
| 469 |
+
img_res: int = 518,
|
| 470 |
+
cache_dir: str = "dataset_cache_test",
|
| 471 |
+
renders_dir: str = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond',
|
| 472 |
+
process_img: bool = False,
|
| 473 |
+
n_pre_samples: int = 1024,
|
| 474 |
+
|
| 475 |
+
active_voxel_res: int = 64,
|
| 476 |
+
pc_sample_number: int = 409600,
|
| 477 |
+
|
| 478 |
+
filter_active_voxels: bool = False, #####
|
| 479 |
+
min_active_voxels: int = 2000,
|
| 480 |
+
max_active_voxels: int = 40000,
|
| 481 |
+
|
| 482 |
+
cache_filter_path: str = "/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/data/filter_name/objaverse_200_2000_2000min_25000max.txt",
|
| 483 |
+
|
| 484 |
+
sample_type: str = 'uniform',
|
| 485 |
+
):
|
| 486 |
+
self.root_dir = root_dir
|
| 487 |
+
self.cache_dir = cache_dir
|
| 488 |
+
self.img_res = img_res
|
| 489 |
+
self.renders_dir = renders_dir
|
| 490 |
+
self.process_img = process_img
|
| 491 |
+
self.filter_active_voxels=filter_active_voxels
|
| 492 |
+
self.min_active_voxels=min_active_voxels
|
| 493 |
+
self.max_active_voxels=max_active_voxels
|
| 494 |
+
|
| 495 |
+
self.active_voxel_res = active_voxel_res
|
| 496 |
+
self.pc_sample_number = pc_sample_number
|
| 497 |
+
|
| 498 |
+
self.sample_type = sample_type
|
| 499 |
+
|
| 500 |
+
# self.image_transform = transforms.ToTensor()
|
| 501 |
+
self.image_transform = transforms.Compose([
|
| 502 |
+
transforms.ToTensor(),
|
| 503 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 504 |
+
])
|
| 505 |
+
|
| 506 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 507 |
+
assert (base_resolution & (base_resolution - 1)) == 0, "Resolution must be power of 2"
|
| 508 |
+
assert (min_resolution & (min_resolution - 1)) == 0, "Resolution must be power of 2"
|
| 509 |
+
self.res_levels = [
|
| 510 |
+
2**i for i in range(
|
| 511 |
+
int(np.log2(min_resolution)),
|
| 512 |
+
int(np.log2(base_resolution)) + 1
|
| 513 |
+
)
|
| 514 |
+
]
|
| 515 |
+
|
| 516 |
+
if self.active_voxel_res is not None and self.active_voxel_res not in self.res_levels:
|
| 517 |
+
self.res_levels.append(active_voxel_res)
|
| 518 |
+
self.res_levels.sort()
|
| 519 |
+
|
| 520 |
+
all_obj_files = sorted([f for f in os.listdir(root_dir) if f.endswith(('.obj', '.ply', '.glb'))])
|
| 521 |
+
if not all_obj_files:
|
| 522 |
+
raise ValueError(f"No OBJ files found in {root_dir}")
|
| 523 |
+
|
| 524 |
+
if self.process_img:
|
| 525 |
+
map_file_path = os.path.join(os.path.dirname(self.renders_dir), 'map.json')
|
| 526 |
+
if os.path.exists(map_file_path):
|
| 527 |
+
print(f"Loading pre-computed hash map from {map_file_path}")
|
| 528 |
+
with open(map_file_path, 'r') as f:
|
| 529 |
+
file_map = json.load(f)
|
| 530 |
+
filename_to_hash = {item['filename']: item['sha256'] for item in file_map}
|
| 531 |
+
all_obj_hashes = [filename_to_hash.get(fname) for fname in all_obj_files]
|
| 532 |
+
else:
|
| 533 |
+
print("No hash map found. Calculating SHA256 hashes on the fly... (This may take a moment)")
|
| 534 |
+
all_obj_hashes = []
|
| 535 |
+
for fname in tqdm(all_obj_files, desc="Hashing .obj files"):
|
| 536 |
+
fpath = os.path.join(self.root_dir, fname)
|
| 537 |
+
all_obj_hashes.append(get_sha256(fpath))
|
| 538 |
+
|
| 539 |
+
else:
|
| 540 |
+
print("process_img is False, skipping SHA256 hash calculation.")
|
| 541 |
+
all_obj_hashes = [None] * len(all_obj_files)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
if self.filter_active_voxels and cache_filter_path:
|
| 545 |
+
filtered_list_cache_path = cache_filter_path
|
| 546 |
+
|
| 547 |
+
if os.path.exists(filtered_list_cache_path):
|
| 548 |
+
print(f"Loading filtered BASENAMES from: {filtered_list_cache_path}")
|
| 549 |
+
basename_to_fullname_map = {os.path.splitext(f)[0]: f for f in all_obj_files}
|
| 550 |
+
|
| 551 |
+
with open(filtered_list_cache_path, 'r') as f:
|
| 552 |
+
filtered_basenames = [line.strip() for line in f if line.strip()]
|
| 553 |
+
|
| 554 |
+
self.obj_files = []
|
| 555 |
+
for basename in filtered_basenames:
|
| 556 |
+
if basename in basename_to_fullname_map:
|
| 557 |
+
self.obj_files.append(basename_to_fullname_map[basename])
|
| 558 |
+
else:
|
| 559 |
+
print(f"[WARN] Basename '{basename}' from filter list not found in directory '{self.root_dir}'. Skipping.")
|
| 560 |
+
|
| 561 |
+
file_to_hash_map = dict(zip(all_obj_files, all_obj_hashes))
|
| 562 |
+
self.obj_hashes = [file_to_hash_map.get(fname) for fname in self.obj_files] # 使用 .get 更安全
|
| 563 |
+
|
| 564 |
+
print(f"Loaded and matched {len(self.obj_files)} samples from the filter list.")
|
| 565 |
+
|
| 566 |
+
else:
|
| 567 |
+
print(f"Cache filter file not found: {filtered_list_cache_path}. Proceeding with on-the-fly filtering...")
|
| 568 |
+
|
| 569 |
+
else:
|
| 570 |
+
self.obj_files = all_obj_files
|
| 571 |
+
self.obj_hashes = all_obj_hashes
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
if not self.obj_files:
|
| 575 |
+
raise ValueError(f"No OBJ files found in {root_dir}")
|
| 576 |
+
|
| 577 |
+
self.rembg_session = None
|
| 578 |
+
|
| 579 |
+
def _init_rembg_session_if_needed(self):
|
| 580 |
+
if self.rembg_session is None:
|
| 581 |
+
print(f"Initializing rembg session for worker {os.getpid()}...")
|
| 582 |
+
self.rembg_session = rembg.new_session(model_name='u2net')
|
| 583 |
+
|
| 584 |
+
def preprocess_image(self, input: Image.Image) -> Image.Image:
|
| 585 |
+
self._init_rembg_session_if_needed()
|
| 586 |
+
has_alpha = False
|
| 587 |
+
if input.mode == 'RGBA':
|
| 588 |
+
alpha = np.array(input)[:, :, 3]
|
| 589 |
+
if not np.all(alpha == 255):
|
| 590 |
+
has_alpha = True
|
| 591 |
+
if has_alpha:
|
| 592 |
+
output = input
|
| 593 |
+
else:
|
| 594 |
+
input = input.convert('RGB')
|
| 595 |
+
max_size = max(input.size)
|
| 596 |
+
scale = min(1, 1024 / max_size)
|
| 597 |
+
if scale < 1:
|
| 598 |
+
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
|
| 599 |
+
if getattr(self, 'rembg_session', None) is None:
|
| 600 |
+
self.rembg_session = rembg.new_session('u2net')
|
| 601 |
+
output = rembg.remove(input, session=self.rembg_session)
|
| 602 |
+
output_np = np.array(output)
|
| 603 |
+
alpha = output_np[:, :, 3]
|
| 604 |
+
bbox = np.argwhere(alpha > 0.8 * 255)
|
| 605 |
+
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
|
| 606 |
+
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
|
| 607 |
+
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
|
| 608 |
+
size = int(size * 1.2)
|
| 609 |
+
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
|
| 610 |
+
output = output.crop(bbox) # type: ignore
|
| 611 |
+
output = output.resize((518, 518), Image.Resampling.LANCZOS)
|
| 612 |
+
output = np.array(output).astype(np.float32) / 255
|
| 613 |
+
output = output[:, :, :3] * output[:, :, 3:4]
|
| 614 |
+
output = Image.fromarray((output * 255).astype(np.uint8))
|
| 615 |
+
return output
|
| 616 |
+
|
| 617 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 618 |
+
name = os.path.splitext(self.obj_files[idx])[0]
|
| 619 |
+
cache_path = os.path.join(self.cache_dir, f"{name}_precombined.npz")
|
| 620 |
+
|
| 621 |
+
sha256_hash = self.obj_hashes[idx]
|
| 622 |
+
mesh_render_dir = os.path.join(self.renders_dir, sha256_hash) if sha256_hash else ""
|
| 623 |
+
|
| 624 |
+
image_path = ''
|
| 625 |
+
if mesh_render_dir and os.path.isdir(mesh_render_dir):
|
| 626 |
+
try:
|
| 627 |
+
render_files = [f for f in os.listdir(mesh_render_dir) if f.endswith('.png')]
|
| 628 |
+
if render_files:
|
| 629 |
+
image_path = os.path.join(mesh_render_dir, random.choice(render_files))
|
| 630 |
+
except OSError as e:
|
| 631 |
+
print(f"[WARN] Could not access render directory {mesh_render_dir}: {e}")
|
| 632 |
+
|
| 633 |
+
if self.process_img:
|
| 634 |
+
try:
|
| 635 |
+
if image_path and os.path.exists(image_path):
|
| 636 |
+
image_obj = self.image_transform(self.preprocess_image(Image.open(image_path)).convert('RGB'))
|
| 637 |
+
else:
|
| 638 |
+
image_obj = self.image_transform(Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)).convert('RGB'))
|
| 639 |
+
except Exception as e:
|
| 640 |
+
image_obj = self.image_transform(Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)).convert('RGB'))
|
| 641 |
+
print(f'Error processing image {image_path}: {e}')
|
| 642 |
+
|
| 643 |
+
if os.path.exists(cache_path):
|
| 644 |
+
try:
|
| 645 |
+
loaded = np.load(cache_path, allow_pickle=True)
|
| 646 |
+
data = {
|
| 647 |
+
'original_faces': torch.from_numpy(loaded['original_faces']),
|
| 648 |
+
'original_vertices': torch.from_numpy(loaded['original_vertices']),
|
| 649 |
+
}
|
| 650 |
+
for res in self.res_levels:
|
| 651 |
+
# Load standard voxel data
|
| 652 |
+
if f'combined_voxels_{res}' in loaded:
|
| 653 |
+
data[f'combined_voxels_{res}'] = torch.from_numpy(loaded[f'combined_voxels_{res}'])
|
| 654 |
+
data[f'combined_voxel_labels_{res}'] = torch.from_numpy(loaded[f'combined_voxel_labels_{res}'])
|
| 655 |
+
data[f'gt_combined_endpoints_{res}'] = torch.from_numpy(loaded[f'gt_combined_endpoints_{res}'])
|
| 656 |
+
|
| 657 |
+
data[f'gt_vertex_voxels_{res}'] = torch.from_numpy(loaded[f'gt_vertex_voxels_{res}'])
|
| 658 |
+
data[f'gt_edge_voxels_{res}'] = torch.from_numpy(loaded[f'gt_edge_voxels_{res}'])
|
| 659 |
+
data[f'gt_edge_endpoints_{res}'] = torch.from_numpy(loaded[f'gt_edge_endpoints_{res}'])
|
| 660 |
+
data[f'gt_edge_errors_{res}'] = torch.from_numpy(loaded[f'gt_edge_errors_{res}'])
|
| 661 |
+
|
| 662 |
+
# Load Active Voxels and Point Cloud for Local Pooling
|
| 663 |
+
if res == self.active_voxel_res:
|
| 664 |
+
if f'active_voxels_{res}' in loaded:
|
| 665 |
+
data[f'active_voxels_{res}'] = torch.from_numpy(loaded[f'active_voxels_{res}'])
|
| 666 |
+
if f'point_cloud_{res}' in loaded:
|
| 667 |
+
data[f'point_cloud_{res}'] = torch.from_numpy(loaded[f'point_cloud_{res}'])
|
| 668 |
+
|
| 669 |
+
if f'gt_vertex_edge_indices_{res}' in loaded:
|
| 670 |
+
data[f'gt_vertex_edge_indices_{res}'] = torch.from_numpy(loaded[f'gt_vertex_edge_indices_{res}'])
|
| 671 |
+
|
| 672 |
+
if self.process_img:
|
| 673 |
+
data['image'] = image_obj
|
| 674 |
+
data['image_path'] = image_path
|
| 675 |
+
return data
|
| 676 |
+
|
| 677 |
+
except Exception as e:
|
| 678 |
+
print(f"[WARN] Corrupted NPZ cache {cache_path}, regenerating... {e}")
|
| 679 |
+
os.remove(cache_path)
|
| 680 |
+
|
| 681 |
+
try:
|
| 682 |
+
mesh_path = os.path.join(self.root_dir, self.obj_files[idx])
|
| 683 |
+
mesh = normalize_mesh(mesh_path)
|
| 684 |
+
if mesh.is_empty or not hasattr(mesh.vertices, 'shape') or mesh.vertices.shape[0] < 3 or not hasattr(mesh.faces, 'shape') or mesh.faces.shape[0] < 1:
|
| 685 |
+
raise ValueError("Invalid or empty mesh")
|
| 686 |
+
except Exception as e:
|
| 687 |
+
print(f"[ERROR] Failed to load mesh: {self.obj_files[idx]} | {e}")
|
| 688 |
+
return self.__getitem__((idx + 1) % len(self))
|
| 689 |
+
|
| 690 |
+
vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
|
| 691 |
+
faces = torch.tensor(mesh.faces, dtype=torch.long)
|
| 692 |
+
|
| 693 |
+
data = {'original_faces': faces.clone(), 'original_vertices': vertices.clone()}
|
| 694 |
+
|
| 695 |
+
for res in self.res_levels:
|
| 696 |
+
quantized = quantize_vertices(vertices, res)
|
| 697 |
+
tmesh = trimesh.Trimesh(vertices=quantized.numpy(), faces=faces.numpy())
|
| 698 |
+
tmesh.merge_vertices()
|
| 699 |
+
|
| 700 |
+
vertex_voxels_raw = torch.from_numpy(tmesh.vertices.astype(np.int32))
|
| 701 |
+
edges_raw = tmesh.edges_unique
|
| 702 |
+
|
| 703 |
+
edges_indices_raw = torch.from_numpy(tmesh.edges_unique.astype(np.long))
|
| 704 |
+
data[f'gt_vertex_edge_indices_{res}'] = edges_indices_raw
|
| 705 |
+
|
| 706 |
+
vertex_labels_raw = torch.zeros(vertex_voxels_raw.shape[0], dtype=torch.long)
|
| 707 |
+
|
| 708 |
+
all_edge_voxels = []
|
| 709 |
+
edge_endpoints = []
|
| 710 |
+
edge_errors = []
|
| 711 |
+
|
| 712 |
+
for u_idx, v_idx in edges_raw:
|
| 713 |
+
p1_grid, p2_grid = vertex_voxels_raw[u_idx].float(), vertex_voxels_raw[v_idx].float()
|
| 714 |
+
v, ep, err = get_voxel_line(p1_grid, p2_grid, mode='cpu')
|
| 715 |
+
all_edge_voxels.extend(v)
|
| 716 |
+
edge_endpoints.extend(ep)
|
| 717 |
+
edge_errors.extend(err)
|
| 718 |
+
|
| 719 |
+
if all_edge_voxels:
|
| 720 |
+
edge_voxels_np = np.array(all_edge_voxels, dtype=np.int32)
|
| 721 |
+
edge_endpoints_np = np.array([np.stack(pair) for pair in edge_endpoints], dtype=np.float32)
|
| 722 |
+
edge_errors_np = np.array(edge_errors, dtype=np.float32)
|
| 723 |
+
|
| 724 |
+
unique_edge_voxels_np, first_indices = np.unique(edge_voxels_np, axis=0, return_index=True)
|
| 725 |
+
edge_voxels_raw = torch.from_numpy(unique_edge_voxels_np)
|
| 726 |
+
edge_labels_raw = torch.ones(len(edge_voxels_raw), dtype=torch.long)
|
| 727 |
+
edge_endpoints_raw = torch.from_numpy(edge_endpoints_np[first_indices])
|
| 728 |
+
edge_errors_raw = torch.from_numpy(edge_errors_np[first_indices])
|
| 729 |
+
else:
|
| 730 |
+
edge_voxels_raw = torch.empty(0, 3, dtype=torch.int32)
|
| 731 |
+
edge_labels_raw = torch.empty(0, dtype=torch.long)
|
| 732 |
+
edge_endpoints_raw = torch.empty(0, 2, 3, dtype=torch.float32)
|
| 733 |
+
edge_errors_raw = torch.empty(0, 3, dtype=torch.float32)
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
if res == self.active_voxel_res:
|
| 737 |
+
try:
|
| 738 |
+
if self.sample_type == 'uniform':
|
| 739 |
+
# triposf-style, normilize wrong
|
| 740 |
+
ts_voxels, ts_points = load_quantized_mesh_original(
|
| 741 |
+
mesh_path=os.path.join(self.root_dir, self.obj_files[idx]),
|
| 742 |
+
mesh_load=mesh,
|
| 743 |
+
volume_resolution=res,
|
| 744 |
+
use_normals=True,
|
| 745 |
+
pc_sample_number=self.pc_sample_number,
|
| 746 |
+
)
|
| 747 |
+
else:
|
| 748 |
+
ts_voxels, ts_points = load_quantized_mesh_dora(
|
| 749 |
+
mesh_path=os.path.join(self.root_dir, self.obj_files[idx]),
|
| 750 |
+
mesh_load=mesh,
|
| 751 |
+
volume_resolution=res,
|
| 752 |
+
use_normals=True,
|
| 753 |
+
pc_sample_number=self.pc_sample_number,
|
| 754 |
+
edge_sample_ratio=0.5,
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
# Convert types
|
| 758 |
+
# Voxels from TripoSF are float Tensor (N, 3), convert to int32
|
| 759 |
+
data[f'active_voxels_{res}'] = ts_voxels.int()
|
| 760 |
+
data[f'point_cloud_{res}'] = ts_points
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
except Exception as e:
|
| 764 |
+
print(f"[ERROR] Failed to compute active voxels/points for {name} at res {res}: {e}")
|
| 765 |
+
data[f'active_voxels_{res}'] = torch.empty(0, 3, dtype=torch.int32)
|
| 766 |
+
data[f'point_cloud_{res}'] = torch.empty(0, 6, dtype=torch.float32)
|
| 767 |
+
|
| 768 |
+
combined_voxels = torch.cat([vertex_voxels_raw, edge_voxels_raw], dim=0)
|
| 769 |
+
combined_labels = torch.cat([vertex_labels_raw, edge_labels_raw], dim=0)
|
| 770 |
+
|
| 771 |
+
if combined_voxels.numel() > 0:
|
| 772 |
+
unique_voxels, inverse_indices = torch.unique(combined_voxels, dim=0, return_inverse=True)
|
| 773 |
+
|
| 774 |
+
zero_mask = (combined_labels == 0)
|
| 775 |
+
if zero_mask.any():
|
| 776 |
+
zero_per_unique = torch.zeros(len(unique_voxels), dtype=torch.bool)
|
| 777 |
+
zero_per_unique.scatter_(0, inverse_indices[zero_mask], True)
|
| 778 |
+
final_combined_labels = torch.where(zero_per_unique, 0, 1).long()
|
| 779 |
+
else:
|
| 780 |
+
final_combined_labels = torch.ones(len(unique_voxels), dtype=torch.long)
|
| 781 |
+
|
| 782 |
+
if edge_voxels_raw.numel() > 0:
|
| 783 |
+
edge_endpoint_map = {tuple(coord): ep for coord, ep in zip(edge_voxels_raw.numpy(), edge_endpoints_raw.numpy())}
|
| 784 |
+
|
| 785 |
+
endpoints_arr = np.empty((len(unique_voxels), 2, 3), dtype=np.float32)
|
| 786 |
+
unique_voxels_np = unique_voxels.numpy()
|
| 787 |
+
|
| 788 |
+
for j, coord in enumerate(unique_voxels_np):
|
| 789 |
+
coord_tuple = tuple(coord)
|
| 790 |
+
if coord_tuple in edge_endpoint_map:
|
| 791 |
+
endpoints_arr[j] = edge_endpoint_map[coord_tuple]
|
| 792 |
+
else:
|
| 793 |
+
endpoints_arr[j, 0, :] = coord
|
| 794 |
+
endpoints_arr[j, 1, :] = coord
|
| 795 |
+
final_combined_endpoints = torch.from_numpy(endpoints_arr)
|
| 796 |
+
else:
|
| 797 |
+
final_combined_endpoints = unique_voxels.float().unsqueeze(1).repeat(1, 2, 1)
|
| 798 |
+
else:
|
| 799 |
+
unique_voxels = torch.empty(0, 3, dtype=torch.int32)
|
| 800 |
+
final_combined_labels = torch.empty(0, dtype=torch.long)
|
| 801 |
+
final_combined_endpoints = torch.empty(0, 2, 3, dtype=torch.float32)
|
| 802 |
+
|
| 803 |
+
data[f'combined_voxels_{res}'] = unique_voxels
|
| 804 |
+
data[f'combined_voxel_labels_{res}'] = final_combined_labels
|
| 805 |
+
data[f'gt_combined_endpoints_{res}'] = final_combined_endpoints.reshape(-1, 6)
|
| 806 |
+
|
| 807 |
+
data[f'gt_vertex_voxels_{res}'] = vertex_voxels_raw
|
| 808 |
+
data[f'gt_edge_voxels_{res}'] = edge_voxels_raw
|
| 809 |
+
data[f'gt_edge_endpoints_{res}'] = edge_endpoints_raw.reshape(-1, 6)
|
| 810 |
+
data[f'gt_edge_errors_{res}'] = edge_errors_raw
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
save_dict = {
|
| 814 |
+
'original_faces': data['original_faces'].numpy(),
|
| 815 |
+
'original_vertices': data['original_vertices'].numpy(),
|
| 816 |
+
}
|
| 817 |
+
for res in self.res_levels:
|
| 818 |
+
for key_suffix in [
|
| 819 |
+
'combined_voxels', 'combined_voxel_labels', 'gt_combined_endpoints',
|
| 820 |
+
'gt_vertex_voxels', 'gt_edge_voxels', 'gt_edge_endpoints', 'gt_edge_errors',
|
| 821 |
+
'gt_vertex_edge_indices',
|
| 822 |
+
]:
|
| 823 |
+
full_key = f'{key_suffix}_{res}'
|
| 824 |
+
if full_key in data:
|
| 825 |
+
save_dict[full_key] = data[full_key].numpy()
|
| 826 |
+
|
| 827 |
+
if f'active_voxels_{res}' in data:
|
| 828 |
+
save_dict[f'active_voxels_{res}'] = data[f'active_voxels_{res}'].numpy()
|
| 829 |
+
|
| 830 |
+
if f'point_cloud_{res}' in data:
|
| 831 |
+
save_dict[f'point_cloud_{res}'] = data[f'point_cloud_{res}'].numpy()
|
| 832 |
+
|
| 833 |
+
# try:
|
| 834 |
+
# np.savez_compressed(cache_path, **save_dict)
|
| 835 |
+
# except Exception as e:
|
| 836 |
+
# print(f"[ERROR] Failed to save cache {cache_path}: {e}")
|
| 837 |
+
# if os.path.exists(cache_path): os.remove(cache_path)
|
| 838 |
+
|
| 839 |
+
if self.process_img:
|
| 840 |
+
data['image'] = image_obj
|
| 841 |
+
data['image_path'] = image_path
|
| 842 |
+
|
| 843 |
+
return data
|
| 844 |
+
|
| 845 |
+
def __len__(self) -> int:
|
| 846 |
+
return len(self.obj_files)
|
| 847 |
+
|
| 848 |
+
def collate_fn_pointnet(
|
| 849 |
+
batch: List[Dict[str, torch.Tensor]],
|
| 850 |
+
) -> Dict[str, torch.Tensor]:
|
| 851 |
+
|
| 852 |
+
if not batch:
|
| 853 |
+
return {}
|
| 854 |
+
|
| 855 |
+
batch = [b for b in batch if b is not None]
|
| 856 |
+
if not batch:
|
| 857 |
+
return {}
|
| 858 |
+
|
| 859 |
+
collated = {
|
| 860 |
+
'original_faces': [b['original_faces'] for b in batch],
|
| 861 |
+
'original_vertices': [b['original_vertices'] for b in batch],
|
| 862 |
+
'image_path': [b['image_path'] for b in batch],
|
| 863 |
+
}
|
| 864 |
+
|
| 865 |
+
if 'image' in batch[0] and batch[0]['image'] is not None:
|
| 866 |
+
collated['image'] = torch.stack([b['image'] for b in batch])
|
| 867 |
+
|
| 868 |
+
res_levels = []
|
| 869 |
+
for k in batch[0].keys():
|
| 870 |
+
if k.startswith('gt_vertex_voxels_'):
|
| 871 |
+
try:
|
| 872 |
+
res_levels.append(int(k.split('_')[-1]))
|
| 873 |
+
except ValueError:
|
| 874 |
+
pass
|
| 875 |
+
res_levels.sort()
|
| 876 |
+
|
| 877 |
+
for res in res_levels:
|
| 878 |
+
all_active_voxels_list = []
|
| 879 |
+
all_point_clouds_list = []
|
| 880 |
+
|
| 881 |
+
all_combined_voxels_list = []
|
| 882 |
+
all_combined_labels_list = []
|
| 883 |
+
all_vertex_voxels_only = []
|
| 884 |
+
all_edge_voxels_only = []
|
| 885 |
+
all_edge_endpoints_only = []
|
| 886 |
+
all_combined_endpoints = []
|
| 887 |
+
all_combined_errors_list = []
|
| 888 |
+
layout = []
|
| 889 |
+
|
| 890 |
+
vtx_offset = 0
|
| 891 |
+
adj_flat_offset = 0
|
| 892 |
+
start_idx = 0
|
| 893 |
+
|
| 894 |
+
# Attempt to find device from first tensor
|
| 895 |
+
device = torch.device('cpu')
|
| 896 |
+
for v in batch[0].values():
|
| 897 |
+
if isinstance(v, torch.Tensor):
|
| 898 |
+
device = v.device
|
| 899 |
+
break
|
| 900 |
+
|
| 901 |
+
all_edge_indices_list = []
|
| 902 |
+
vertex_count_offset = 0
|
| 903 |
+
|
| 904 |
+
for i, sample in enumerate(batch):
|
| 905 |
+
vertex_voxels = sample.get(f'gt_vertex_voxels_{res}', torch.empty(0,3,dtype=torch.int32)).to(device)
|
| 906 |
+
|
| 907 |
+
num_vertices = vertex_voxels.shape[0]
|
| 908 |
+
|
| 909 |
+
vertex_labels = torch.zeros(vertex_voxels.shape[0], dtype=torch.long, device=device)
|
| 910 |
+
edge_voxels = sample.get(f'gt_edge_voxels_{res}', torch.empty(0,3,dtype=torch.int32)).to(device)
|
| 911 |
+
edge_labels = torch.ones(edge_voxels.shape[0], dtype=torch.long, device=device)
|
| 912 |
+
edge_endpoints= sample.get(f'gt_edge_endpoints_{res}', torch.empty(0,6,dtype=torch.float32)).to(device)
|
| 913 |
+
edge_errors = sample.get(f'gt_edge_errors_{res}', torch.empty(0,3,dtype=torch.float32)).to(device)
|
| 914 |
+
|
| 915 |
+
vertex_errors = sample.get(f'gt_vertex_errors_{res}', torch.zeros_like(vertex_voxels, dtype=torch.float32)).to(device)
|
| 916 |
+
|
| 917 |
+
if vertex_voxels.numel() > 0:
|
| 918 |
+
idx = torch.full((vertex_voxels.shape[0],1), i, dtype=torch.int32, device=device)
|
| 919 |
+
all_vertex_voxels_only.append(torch.cat([idx, vertex_voxels], dim=1))
|
| 920 |
+
|
| 921 |
+
edge_indices = sample.get(f'gt_vertex_edge_indices_{res}', torch.empty(0, 2, dtype=torch.long)).to(device)
|
| 922 |
+
if edge_indices.numel() > 0:
|
| 923 |
+
shifted_indices = edge_indices + vertex_count_offset
|
| 924 |
+
all_edge_indices_list.append(shifted_indices)
|
| 925 |
+
vertex_count_offset += num_vertices
|
| 926 |
+
|
| 927 |
+
if edge_voxels.numel() > 0:
|
| 928 |
+
idx = torch.full((edge_voxels.shape[0],1), i, dtype=torch.int32, device=device)
|
| 929 |
+
all_edge_voxels_only.append(torch.cat([idx, edge_voxels], dim=1))
|
| 930 |
+
all_edge_endpoints_only.append(
|
| 931 |
+
torch.cat([idx.to(torch.float32), edge_endpoints], dim=1))
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
if vertex_voxels.numel() + edge_voxels.numel() > 0:
|
| 935 |
+
combined_voxels = torch.cat([vertex_voxels, edge_voxels], dim=0)
|
| 936 |
+
combined_labels = torch.cat([vertex_labels, edge_labels], dim=0)
|
| 937 |
+
|
| 938 |
+
endpoints = torch.zeros(combined_voxels.size(0), 6, dtype=torch.float32, device=device)
|
| 939 |
+
if edge_voxels.numel() > 0:
|
| 940 |
+
endpoints[-edge_voxels.size(0):] = edge_endpoints
|
| 941 |
+
if vertex_voxels.numel() > 0:
|
| 942 |
+
endpoints[:vertex_voxels.size(0)] = vertex_voxels.repeat(1,2).float()
|
| 943 |
+
|
| 944 |
+
combined_errors = torch.cat([vertex_errors, edge_errors], dim=0)
|
| 945 |
+
|
| 946 |
+
batch_idx_int = torch.full((combined_voxels.shape[0],1), i, dtype=torch.int32, device=device)
|
| 947 |
+
all_combined_voxels_list.append(torch.cat([batch_idx_int, combined_voxels], dim=1))
|
| 948 |
+
all_combined_labels_list.append(combined_labels)
|
| 949 |
+
|
| 950 |
+
batch_idx_float = batch_idx_int.to(torch.float32)
|
| 951 |
+
all_combined_endpoints.append(torch.cat([batch_idx_float, endpoints], dim=1))
|
| 952 |
+
all_combined_errors_list.append(torch.cat([batch_idx_float, combined_errors], dim=1))
|
| 953 |
+
|
| 954 |
+
layout.append(slice(start_idx, start_idx + combined_voxels.shape[0]))
|
| 955 |
+
start_idx += combined_voxels.shape[0]
|
| 956 |
+
else:
|
| 957 |
+
layout.append(slice(start_idx, start_idx))
|
| 958 |
+
|
| 959 |
+
# Active Voxels (Sparse Coords)
|
| 960 |
+
active_voxels = sample.get(f'active_voxels_{res}', torch.empty(0, 3, dtype=torch.int32)).to(device)
|
| 961 |
+
if active_voxels.numel() > 0:
|
| 962 |
+
idx = torch.full((active_voxels.shape[0], 1), i, dtype=torch.int32, device=device)
|
| 963 |
+
all_active_voxels_list.append(torch.cat([idx, active_voxels], dim=1))
|
| 964 |
+
|
| 965 |
+
# ==========================================
|
| 966 |
+
# Modified Section: Collect Point Clouds
|
| 967 |
+
# ==========================================
|
| 968 |
+
# pc = sample.get(f'point_cloud_{res}', torch.empty(0, 6, dtype=torch.float32)).to(device)
|
| 969 |
+
pc = sample.get(f'point_cloud_{res}', torch.empty(0, 15, dtype=torch.float32)).to(device)
|
| 970 |
+
# We expect all samples to have point clouds if res == active_voxel_res
|
| 971 |
+
if pc.numel() > 0:
|
| 972 |
+
all_point_clouds_list.append(pc)
|
| 973 |
+
|
| 974 |
+
collated[f'layout_{res}'] = layout
|
| 975 |
+
|
| 976 |
+
def cat_or_empty(lst, shape, dtype):
|
| 977 |
+
return torch.cat(lst, dim=0) if lst else torch.empty(shape, dtype=dtype, device=device)
|
| 978 |
+
|
| 979 |
+
collated[f'combined_voxels_{res}'] = cat_or_empty(all_combined_voxels_list,(0,4),torch.int32)
|
| 980 |
+
collated[f'combined_voxel_labels_{res}'] = cat_or_empty(all_combined_labels_list,(0,),torch.long)
|
| 981 |
+
collated[f'gt_vertex_voxels_{res}'] = cat_or_empty(all_vertex_voxels_only,(0,4),torch.int32)
|
| 982 |
+
collated[f'gt_edge_voxels_{res}'] = cat_or_empty(all_edge_voxels_only,(0,4),torch.int32)
|
| 983 |
+
collated[f'gt_edge_endpoints_{res}'] = cat_or_empty(all_edge_endpoints_only,(0,7),torch.float32)
|
| 984 |
+
collated[f'gt_combined_endpoints_{res}'] = cat_or_empty(all_combined_endpoints,(0,7),torch.float32)
|
| 985 |
+
collated[f'gt_combined_errors_{res}'] = cat_or_empty(all_combined_errors_list,(0,4),torch.float32)
|
| 986 |
+
|
| 987 |
+
collated[f'active_voxels_{res}'] = cat_or_empty(all_active_voxels_list, (0, 4), torch.int32)
|
| 988 |
+
|
| 989 |
+
if all_edge_indices_list:
|
| 990 |
+
collated[f'gt_vertex_edge_indices_{res}'] = torch.cat(all_edge_indices_list, dim=0)
|
| 991 |
+
else:
|
| 992 |
+
collated[f'gt_vertex_edge_indices_{res}'] = torch.empty((0, 2), dtype=torch.long, device=device)
|
| 993 |
+
|
| 994 |
+
if all_point_clouds_list:
|
| 995 |
+
collated[f'point_cloud_{res}'] = torch.stack(all_point_clouds_list, dim=0)
|
| 996 |
+
else:
|
| 997 |
+
# collated[f'point_cloud_{res}'] = torch.empty((0, 6), dtype=torch.float32, device=device)
|
| 998 |
+
collated[f'point_cloud_{res}'] = torch.empty((0, 15), dtype=torch.float32, device=device)
|
| 999 |
+
|
| 1000 |
+
return collated
|
debug_viz/step_0_batch_0.ply
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
debug_viz/step_0_batch_1.ply
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
filter_active_voxels.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 5 |
+
|
| 6 |
+
# === 配置 ===
|
| 7 |
+
cache_dir = "/gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to1024_819200"
|
| 8 |
+
|
| 9 |
+
# 1. Edge Voxels (512分辨率) 的筛选阈值
|
| 10 |
+
target_res_edge = 512
|
| 11 |
+
min_edge_voxels = 2000
|
| 12 |
+
max_edge_voxels = 75000
|
| 13 |
+
|
| 14 |
+
# 2. Active Voxels (64分辨率) 的筛选阈值
|
| 15 |
+
# 请根据你的需求调整这两个数值
|
| 16 |
+
target_res_active = 128
|
| 17 |
+
min_active_voxels = 2000 # 举例:最少要有100个粗糙体素
|
| 18 |
+
max_active_voxels = 326780 # 举例:最多8000个粗糙体素
|
| 19 |
+
|
| 20 |
+
save_txt_path = f"/gemini/user/private/zhaotianhao/Triposf/MERGED_DATASET_filtered_{min_edge_voxels}-{max_edge_voxels}edge_{min_active_voxels}-{max_active_voxels}active.txt"
|
| 21 |
+
|
| 22 |
+
# === 单文件统计函数 ===
|
| 23 |
+
def check_voxel_counts(npz_path):
|
| 24 |
+
try:
|
| 25 |
+
# 打开 npz 文件
|
| 26 |
+
with np.load(npz_path) as data:
|
| 27 |
+
# 键名定义
|
| 28 |
+
key_edge = f"combined_voxels_{target_res_edge}"
|
| 29 |
+
key_active = f"active_voxels_{target_res_active}"
|
| 30 |
+
|
| 31 |
+
# 检查键是否存在
|
| 32 |
+
if key_edge not in data or key_active not in data:
|
| 33 |
+
return None
|
| 34 |
+
|
| 35 |
+
# 获取数量
|
| 36 |
+
count_edge = len(data[key_edge])
|
| 37 |
+
count_active = len(data[key_active])
|
| 38 |
+
|
| 39 |
+
# === 核心筛选逻辑 (同时满足两个条件) ===
|
| 40 |
+
is_edge_valid = min_edge_voxels <= count_edge <= max_edge_voxels
|
| 41 |
+
is_active_valid = min_active_voxels <= count_active <= max_active_voxels
|
| 42 |
+
|
| 43 |
+
if is_edge_valid and is_active_valid:
|
| 44 |
+
base_name = os.path.basename(npz_path)
|
| 45 |
+
# 处理文件名
|
| 46 |
+
if base_name.endswith("_precombined.npz"):
|
| 47 |
+
original_name = base_name.replace("_precombined.npz", "")
|
| 48 |
+
else:
|
| 49 |
+
original_name = os.path.splitext(base_name)[0]
|
| 50 |
+
|
| 51 |
+
return (original_name, count_edge, count_active)
|
| 52 |
+
|
| 53 |
+
except Exception:
|
| 54 |
+
return None
|
| 55 |
+
return None
|
| 56 |
+
|
| 57 |
+
# === 获取所有 NPZ 文件 ===
|
| 58 |
+
if not os.path.exists(cache_dir):
|
| 59 |
+
print(f"错误: 缓存目录不存在 {cache_dir}")
|
| 60 |
+
exit()
|
| 61 |
+
|
| 62 |
+
npz_files = [os.path.join(cache_dir, f) for f in os.listdir(cache_dir) if f.endswith(".npz")]
|
| 63 |
+
print(f"共发现 {len(npz_files)} 个缓存文件。开始并行过滤...")
|
| 64 |
+
print(f"筛选条件:")
|
| 65 |
+
print(f" - Edge (512): {min_edge_voxels} ~ {max_edge_voxels}")
|
| 66 |
+
print(f" - Active (64): {min_active_voxels} ~ {max_active_voxels}")
|
| 67 |
+
|
| 68 |
+
# === 并行过滤 ===
|
| 69 |
+
filtered_files = []
|
| 70 |
+
counts_edge = []
|
| 71 |
+
counts_active = []
|
| 72 |
+
|
| 73 |
+
with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
|
| 74 |
+
futures = {executor.submit(check_voxel_counts, path): path for path in npz_files}
|
| 75 |
+
|
| 76 |
+
for future in tqdm(as_completed(futures), total=len(futures), desc="Filtering"):
|
| 77 |
+
result = future.result()
|
| 78 |
+
if result is not None:
|
| 79 |
+
fname, c_edge, c_active = result
|
| 80 |
+
filtered_files.append(fname)
|
| 81 |
+
counts_edge.append(c_edge)
|
| 82 |
+
counts_active.append(c_active)
|
| 83 |
+
|
| 84 |
+
# === 保存结果 ===
|
| 85 |
+
os.makedirs(os.path.dirname(save_txt_path), exist_ok=True)
|
| 86 |
+
with open(save_txt_path, "w") as f:
|
| 87 |
+
for fname in filtered_files:
|
| 88 |
+
f.write(f"{fname}\n")
|
| 89 |
+
|
| 90 |
+
# === 打印统计信息 ===
|
| 91 |
+
print(f"\n✅ 筛选完成:")
|
| 92 |
+
print(f" 符合条件的文件数: {len(filtered_files)} / {len(npz_files)} (保留率: {len(filtered_files)/len(npz_files)*100:.2f}%)")
|
| 93 |
+
|
| 94 |
+
if counts_edge:
|
| 95 |
+
print(f"\n[统计 - Edge Voxels (512)]")
|
| 96 |
+
print(f" 最小值: {min(counts_edge)}")
|
| 97 |
+
print(f" 最大值: {max(counts_edge)}")
|
| 98 |
+
print(f" 平均值: {np.mean(counts_edge):.2f}")
|
| 99 |
+
|
| 100 |
+
if counts_active:
|
| 101 |
+
print(f"\n[统计 - Active Voxels (64)]")
|
| 102 |
+
print(f" 最小值: {min(counts_active)}")
|
| 103 |
+
print(f" 最大值: {max(counts_active)}")
|
| 104 |
+
print(f" 平均值: {np.mean(counts_active):.2f}")
|
| 105 |
+
|
| 106 |
+
print(f"\n 结果已保存到: {save_txt_path}")
|
generate_npz.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import yaml
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
# Assuming your custom modules are in the same directory or in the Python path
|
| 10 |
+
# from dataset import VoxelVertexDataset_edge, collate_fn_edge
|
| 11 |
+
from dataset_triposf import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 12 |
+
|
| 13 |
+
def inspect_batch(batch, batch_idx, device):
|
| 14 |
+
"""
|
| 15 |
+
A detailed function to inspect and print information about a single batch.
|
| 16 |
+
"""
|
| 17 |
+
print(f"\n{'='*20} Inspecting Batch {batch_idx} {'='*20}")
|
| 18 |
+
|
| 19 |
+
# if batch is None:
|
| 20 |
+
# print("Batch is None. Skipping.")
|
| 21 |
+
# return
|
| 22 |
+
|
| 23 |
+
# print("Batch contains the following keys:")
|
| 24 |
+
# for key in batch.keys():
|
| 25 |
+
# print(f" - {key}")
|
| 26 |
+
|
| 27 |
+
# print(f"{'='*58}")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def main():
|
| 31 |
+
"""
|
| 32 |
+
Main function to load configuration, set up the dataset,
|
| 33 |
+
and process a few batches for inspection.
|
| 34 |
+
"""
|
| 35 |
+
import argparse
|
| 36 |
+
parser = argparse.ArgumentParser(description="Process and inspect data from the VoxelVertexDataset.")
|
| 37 |
+
# parser.add_argument('config_path', type=str, help='Path to the configuration YAML file.')
|
| 38 |
+
parser.add_argument('--num_batches', type=int, default=3, help='Number of batches to inspect.')
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
|
| 41 |
+
# 1. Load Configuration
|
| 42 |
+
# print(f"Loading configuration from: {args.config_path}")
|
| 43 |
+
# with open(args.config_path) as f:
|
| 44 |
+
# cfg = yaml.safe_load(f)
|
| 45 |
+
|
| 46 |
+
# 2. Initialize Device
|
| 47 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 48 |
+
print(f"Using device: {device}")
|
| 49 |
+
|
| 50 |
+
# 3. Initialize Dataset
|
| 51 |
+
print("Initializing dataset...")
|
| 52 |
+
# dataset = VoxelVertexDataset_edge(
|
| 53 |
+
# root_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/final_data_decimate_2',
|
| 54 |
+
# base_resolution=512,
|
| 55 |
+
# min_resolution=64,
|
| 56 |
+
# cache_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/dataset_cache/final_data_decimate_60w_2',
|
| 57 |
+
# renders_dir=None,
|
| 58 |
+
# )
|
| 59 |
+
dataset = VoxelVertexDataset_edge(
|
| 60 |
+
root_dir='/root/mesh_split_200complex/mesh_split_200complex_train',
|
| 61 |
+
base_resolution=512,
|
| 62 |
+
min_resolution=64,
|
| 63 |
+
cache_dir='/root/Trisf/dataset_cache/objaverse_200_2000_filtered_final_8354files_512to512',
|
| 64 |
+
renders_dir=None,
|
| 65 |
+
|
| 66 |
+
filter_active_voxels=False,
|
| 67 |
+
cache_filter_path='',
|
| 68 |
+
|
| 69 |
+
active_voxel_res=512,
|
| 70 |
+
|
| 71 |
+
sample_type='dora',
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# dataset = VoxelVertexDataset_edge(
|
| 75 |
+
# root_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/meshgpt_data/train/03001627',
|
| 76 |
+
# base_resolution=512,
|
| 77 |
+
# min_resolution=64,
|
| 78 |
+
# cache_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/dataset_cache/03001627',
|
| 79 |
+
# renders_dir=None,
|
| 80 |
+
# )
|
| 81 |
+
# dataset = VoxelVertexDataset_edge(
|
| 82 |
+
# root_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/meshgpt_data/train/03636649',
|
| 83 |
+
# base_resolution=512,
|
| 84 |
+
# min_resolution=64,
|
| 85 |
+
# cache_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/dataset_cache/03636649',
|
| 86 |
+
# renders_dir=None,
|
| 87 |
+
# )
|
| 88 |
+
# dataset = VoxelVertexDataset_edge(
|
| 89 |
+
# root_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/meshgpt_data/train/04379243',
|
| 90 |
+
# base_resolution=512,
|
| 91 |
+
# min_resolution=64,
|
| 92 |
+
# cache_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/dataset_cache/04379243',
|
| 93 |
+
# renders_dir=None,
|
| 94 |
+
# )
|
| 95 |
+
print(f"Dataset initialized with {len(dataset)} samples.")
|
| 96 |
+
|
| 97 |
+
# 4. Initialize DataLoader
|
| 98 |
+
# We don't need a DistributedSampler here, just a regular DataLoader.
|
| 99 |
+
print("Initializing DataLoader...")
|
| 100 |
+
dataloader = DataLoader(
|
| 101 |
+
dataset,
|
| 102 |
+
batch_size=1,
|
| 103 |
+
shuffle=False, # Shuffle for a random sample of batches
|
| 104 |
+
collate_fn=partial(collate_fn_pointnet,),
|
| 105 |
+
num_workers=24,
|
| 106 |
+
pin_memory=True,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# 5. Data Processing Loop
|
| 110 |
+
print(f"\nStarting data inspection loop for {args.num_batches} batches...")
|
| 111 |
+
for i, batch in enumerate(dataloader):
|
| 112 |
+
inspect_batch(batch, i, device)
|
| 113 |
+
|
| 114 |
+
print("\nData inspection complete.")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
if __name__ == '__main__':
|
| 118 |
+
main()
|
mesh_augment.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import open3d as o3d
|
| 4 |
+
import trimesh
|
| 5 |
+
import random
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
def augment_obj_file(input_path, output_path, n_augmentations=5):
|
| 10 |
+
"""
|
| 11 |
+
Augment an OBJ file with random transformations
|
| 12 |
+
Args:
|
| 13 |
+
input_path: Path to input OBJ file
|
| 14 |
+
output_path: Directory to save augmented files
|
| 15 |
+
n_augmentations: Number of augmented copies to create
|
| 16 |
+
"""
|
| 17 |
+
# Create output directory if it doesn't exist
|
| 18 |
+
os.makedirs(output_path, exist_ok=True)
|
| 19 |
+
|
| 20 |
+
# Load the original mesh
|
| 21 |
+
mesh = trimesh.load(input_path)
|
| 22 |
+
original_name = Path(input_path).stem
|
| 23 |
+
|
| 24 |
+
for i in range(n_augmentations):
|
| 25 |
+
# Create a copy of the original mesh
|
| 26 |
+
augmented_mesh = mesh.copy()
|
| 27 |
+
|
| 28 |
+
# Random rotation (0-360 degrees around each axis)
|
| 29 |
+
angle_x = np.random.uniform(0, 2*np.pi)
|
| 30 |
+
angle_y = np.random.uniform(0, 2*np.pi)
|
| 31 |
+
angle_z = np.random.uniform(0, 2*np.pi)
|
| 32 |
+
rotation_matrix = trimesh.transformations.euler_matrix(angle_x, angle_y, angle_z)
|
| 33 |
+
augmented_mesh.apply_transform(rotation_matrix)
|
| 34 |
+
|
| 35 |
+
# Random scaling (0.8-1.2 range)
|
| 36 |
+
scale_factor = np.random.uniform(0.8, 1.2, size=3)
|
| 37 |
+
scale_matrix = np.eye(4)
|
| 38 |
+
scale_matrix[:3, :3] *= scale_factor
|
| 39 |
+
augmented_mesh.apply_transform(scale_matrix)
|
| 40 |
+
|
| 41 |
+
# Random translation (-0.1 to 0.1 range in each dimension)
|
| 42 |
+
translation = np.random.uniform(-0.1, 0.1, size=3)
|
| 43 |
+
translation_matrix = np.eye(4)
|
| 44 |
+
translation_matrix[:3, 3] = translation
|
| 45 |
+
augmented_mesh.apply_transform(translation_matrix)
|
| 46 |
+
|
| 47 |
+
# Save the augmented mesh
|
| 48 |
+
output_file = os.path.join(output_path, f"{original_name}_aug_{i}.obj")
|
| 49 |
+
augmented_mesh.export(output_file)
|
| 50 |
+
|
| 51 |
+
def augment_all_objs(source_dir, target_dir, n_augmentations=5):
|
| 52 |
+
"""
|
| 53 |
+
Augment all OBJ files in a directory
|
| 54 |
+
Args:
|
| 55 |
+
source_dir: Directory containing original OBJ files
|
| 56 |
+
target_dir: Directory to save augmented files
|
| 57 |
+
n_augmentations: Number of augmented copies per file
|
| 58 |
+
"""
|
| 59 |
+
# Get all OBJ files in source directory
|
| 60 |
+
obj_files = [f for f in os.listdir(source_dir) if f.endswith('.obj')]
|
| 61 |
+
|
| 62 |
+
print(f"Found {len(obj_files)} OBJ files to augment")
|
| 63 |
+
print(f"Will create {n_augmentations} augmented versions per file")
|
| 64 |
+
|
| 65 |
+
# Process each file
|
| 66 |
+
for obj_file in tqdm(obj_files, desc="Augmenting OBJ files"):
|
| 67 |
+
input_path = os.path.join(source_dir, obj_file)
|
| 68 |
+
augment_obj_file(input_path, target_dir, n_augmentations)
|
| 69 |
+
|
| 70 |
+
print(f"Finished! Augmented files saved to: {target_dir}")
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
# Configuration
|
| 74 |
+
N_AUGMENTATIONS = 10 # Number of augmented copies per file
|
| 75 |
+
SOURCE_DIR = "/root/shapenet_data/train_mesh_data_under_25kb_1000/train" # Directory with original OBJ files
|
| 76 |
+
TARGET_DIR = f"/root/shapenet_data/train_mesh_data_under_25kb_1000/train_{N_AUGMENTATIONS}augment" # Where to save augmented files
|
| 77 |
+
|
| 78 |
+
# Run augmentation
|
| 79 |
+
augment_all_objs(SOURCE_DIR, TARGET_DIR, N_AUGMENTATIONS)
|
metric.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import warnings
|
| 6 |
+
import trimesh
|
| 7 |
+
from scipy.stats import entropy
|
| 8 |
+
from sklearn.neighbors import NearestNeighbors
|
| 9 |
+
from numpy.linalg import norm
|
| 10 |
+
from tqdm.auto import tqdm
|
| 11 |
+
|
| 12 |
+
# ==============================================================================
|
| 13 |
+
# 用户配置 (User Configuration)
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
# --- 路径设置 ---
|
| 16 |
+
# !!重要提示!!
|
| 17 |
+
# 当您计算真实指标时, 请确保这两个路径指向不同的文件夹
|
| 18 |
+
# 这里为了方便测试, 设置为相同路径。当两个路径相同时:
|
| 19 |
+
# MMD->0, COV->1.0, 1-NNA->0.5, JSD->0
|
| 20 |
+
# 而 CD 和 HD 会是一个很小的值, 代表同一mesh两次不同采样的差异。
|
| 21 |
+
|
| 22 |
+
# GENERATED_MESH_DIR = "/root/mesh_split_200complex/mesh_split_200complex_test" # 存放生成的 .obj 文件的文件夹路径
|
| 23 |
+
# GT_MESH_DIR = "/root/mesh_split_200complex/mesh_split_200complex_test" # 存放真实的 .obj 文件的文件夹路径
|
| 24 |
+
|
| 25 |
+
GENERATED_MESH_DIR = "/root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs" # 存放生成的 .obj 文件的文件夹路径
|
| 26 |
+
GT_MESH_DIR = "/root/Trisf/abalation_post_processing/gt_mesh" # 存放真实的 .obj 文件的文件夹路径
|
| 27 |
+
|
| 28 |
+
# --- 采样和计算参数 ---
|
| 29 |
+
NUM_POINTS_PER_MESH = 2048 # 从每个mesh表面采样的点数
|
| 30 |
+
BATCH_SIZE = 32 # 计算指标时使用的批次大小,根据显存调整
|
| 31 |
+
JSD_RESOLUTION = 28 # JSD计算中体素网格的分辨率
|
| 32 |
+
|
| 33 |
+
# ==============================================================================
|
| 34 |
+
# 核心功能函数: Mesh处理
|
| 35 |
+
# ==============================================================================
|
| 36 |
+
|
| 37 |
+
def process_meshes_in_folder(folder_path, num_points):
|
| 38 |
+
"""
|
| 39 |
+
加载文件夹中所有的 .obj 文件, 将它们采样成点云, 并进行归一化。
|
| 40 |
+
"""
|
| 41 |
+
# 按文件名排序以确保一一对应
|
| 42 |
+
mesh_files = sorted(glob.glob(os.path.join(folder_path, '*.obj')))
|
| 43 |
+
if not mesh_files:
|
| 44 |
+
raise FileNotFoundError(f"在文件夹 '{folder_path}' 中没有找到任何 .obj 文件。")
|
| 45 |
+
|
| 46 |
+
all_point_clouds = []
|
| 47 |
+
print(f"正在从 '{folder_path}' 处理 {len(mesh_files)} 个mesh...")
|
| 48 |
+
|
| 49 |
+
for mesh_path in tqdm(mesh_files, desc=f'处理 {os.path.basename(folder_path)}'):
|
| 50 |
+
try:
|
| 51 |
+
mesh = trimesh.load(mesh_path, process=False)
|
| 52 |
+
|
| 53 |
+
# 归一化: 移动到原点并缩放到单位球体内
|
| 54 |
+
center = mesh.bounds.mean(axis=0)
|
| 55 |
+
mesh.apply_translation(-center)
|
| 56 |
+
max_dist = np.max(np.linalg.norm(mesh.vertices, axis=1))
|
| 57 |
+
if max_dist > 0:
|
| 58 |
+
mesh.apply_scale(1.0 / max_dist)
|
| 59 |
+
|
| 60 |
+
points, _ = trimesh.sample.sample_surface(mesh, num_points)
|
| 61 |
+
|
| 62 |
+
if points.shape[0] != num_points:
|
| 63 |
+
# print(f"警告: {mesh_path} 采样点数 {points.shape[0]} != {num_points}, 进行重采样。")
|
| 64 |
+
indices = np.random.choice(points.shape[0], num_points, replace=True)
|
| 65 |
+
points = points[indices]
|
| 66 |
+
|
| 67 |
+
all_point_clouds.append(points)
|
| 68 |
+
|
| 69 |
+
except Exception as e:
|
| 70 |
+
print(f"错误:加载或处理文件 {mesh_path} 失败: {e}")
|
| 71 |
+
|
| 72 |
+
return np.array(all_point_clouds)
|
| 73 |
+
|
| 74 |
+
# ==============================================================================
|
| 75 |
+
# 评估指标代码 (来自 PointFlow 及新增)
|
| 76 |
+
# ==============================================================================
|
| 77 |
+
|
| 78 |
+
_EMD_NOT_IMPL_WARNED = False
|
| 79 |
+
def emd_approx(sample, ref):
|
| 80 |
+
global _EMD_NOT_IMPL_WARNED
|
| 81 |
+
emd = torch.zeros([sample.size(0)]).to(sample)
|
| 82 |
+
if not _EMD_NOT_IMPL_WARNED:
|
| 83 |
+
_EMD_NOT_IMPL_WARNED = True
|
| 84 |
+
print('\n\n[WARNING] EMD is not implemented. Setting to zero.')
|
| 85 |
+
return emd
|
| 86 |
+
|
| 87 |
+
def distChamfer(a, b):
|
| 88 |
+
x, y = a, b
|
| 89 |
+
bs, num_points, points_dim = x.size()
|
| 90 |
+
xx = torch.bmm(x, x.transpose(2, 1))
|
| 91 |
+
yy = torch.bmm(y, y.transpose(2, 1))
|
| 92 |
+
zz = torch.bmm(x, y.transpose(2, 1))
|
| 93 |
+
diag_ind = torch.arange(0, num_points, device=a.device).long()
|
| 94 |
+
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
|
| 95 |
+
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
|
| 96 |
+
P = (rx.transpose(2, 1) + ry - 2 * zz)
|
| 97 |
+
# P is batch_size x n_points x n_points matrix of squared distances
|
| 98 |
+
return P.min(1)[0], P.min(2)[0]
|
| 99 |
+
|
| 100 |
+
def compute_cd_hd(sample_pcs, ref_pcs, batch_size):
|
| 101 |
+
"""
|
| 102 |
+
计算平均成对的Chamfer Distance (CD) 和 Hausdorff Distance (HD)。
|
| 103 |
+
"""
|
| 104 |
+
print("\n--- 开始计算 平均Chamfer和Hausdorff距离 ---")
|
| 105 |
+
N_sample = sample_pcs.shape[0]
|
| 106 |
+
N_ref = ref_pcs.shape[0]
|
| 107 |
+
|
| 108 |
+
assert N_sample == N_ref, f"用于成对度量计算的集合大小必须相等, 但得到 {N_sample} 和 {N_ref}"
|
| 109 |
+
|
| 110 |
+
cd_all = []
|
| 111 |
+
hd_all = []
|
| 112 |
+
|
| 113 |
+
iterator = range(0, N_sample, batch_size)
|
| 114 |
+
for b_start in tqdm(iterator, desc='计算 CD/HD'):
|
| 115 |
+
b_end = min(N_sample, b_start + batch_size)
|
| 116 |
+
sample_batch = sample_pcs[b_start:b_end]
|
| 117 |
+
ref_batch = ref_pcs[b_start:b_end]
|
| 118 |
+
|
| 119 |
+
# distChamfer返回的是平方距离
|
| 120 |
+
dist1_sq, dist2_sq = distChamfer(sample_batch, ref_batch)
|
| 121 |
+
|
| 122 |
+
# 计算 Chamfer Distance
|
| 123 |
+
cd_batch = dist1_sq.mean(dim=1) + dist2_sq.mean(dim=1)
|
| 124 |
+
cd_all.append(cd_batch)
|
| 125 |
+
|
| 126 |
+
# 计算 Hausdorff Distance
|
| 127 |
+
# HD = max(max(min_dist_1), max(min_dist_2))
|
| 128 |
+
# 我们需要对平方距离开方来得到真实距离
|
| 129 |
+
hd_batch = torch.max(dist1_sq.max(dim=1)[0], dist2_sq.max(dim=1)[0]).sqrt()
|
| 130 |
+
hd_all.append(hd_batch)
|
| 131 |
+
|
| 132 |
+
cd_all = torch.cat(cd_all)
|
| 133 |
+
hd_all = torch.cat(hd_all)
|
| 134 |
+
|
| 135 |
+
results = {
|
| 136 |
+
'Chamfer-L2': cd_all.mean(),
|
| 137 |
+
'Hausdorff': hd_all.mean(),
|
| 138 |
+
}
|
| 139 |
+
return results
|
| 140 |
+
|
| 141 |
+
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, verbose=True):
|
| 142 |
+
N_sample = sample_pcs.shape[0]
|
| 143 |
+
N_ref = ref_pcs.shape[0]
|
| 144 |
+
all_cd = []
|
| 145 |
+
iterator = range(N_sample)
|
| 146 |
+
if verbose:
|
| 147 |
+
iterator = tqdm(iterator, desc='计算点云间距离')
|
| 148 |
+
for i in iterator:
|
| 149 |
+
sample_batch = sample_pcs[i]
|
| 150 |
+
cd_lst = []
|
| 151 |
+
sub_iterator = range(0, N_ref, batch_size)
|
| 152 |
+
for b_start in sub_iterator:
|
| 153 |
+
b_end = min(N_ref, b_start + batch_size)
|
| 154 |
+
ref_batch = ref_pcs[b_start:b_end]
|
| 155 |
+
batch_size_ref = ref_batch.size(0)
|
| 156 |
+
sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1).contiguous()
|
| 157 |
+
dl, dr = distChamfer(sample_batch_exp, ref_batch)
|
| 158 |
+
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
|
| 159 |
+
cd_lst = torch.cat(cd_lst, dim=1)
|
| 160 |
+
all_cd.append(cd_lst)
|
| 161 |
+
all_cd = torch.cat(all_cd, dim=0)
|
| 162 |
+
# EMD is not implemented, so we return a dummy tensor for it
|
| 163 |
+
all_emd = torch.zeros_like(all_cd)
|
| 164 |
+
return all_cd, all_emd
|
| 165 |
+
|
| 166 |
+
def knn(Mxx, Mxy, Myy, k, sqrt=False):
|
| 167 |
+
n0, n1 = Mxx.size(0), Myy.size(0)
|
| 168 |
+
device = Mxx.device
|
| 169 |
+
|
| 170 |
+
ones_tensor = torch.ones(n0, device=device)
|
| 171 |
+
zeros_tensor = torch.zeros(n1, device=device)
|
| 172 |
+
label = torch.cat((ones_tensor, zeros_tensor))
|
| 173 |
+
|
| 174 |
+
M = torch.cat([torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.t(), Myy), 1)], 0)
|
| 175 |
+
if sqrt: M = M.abs().sqrt()
|
| 176 |
+
|
| 177 |
+
diag_inf = torch.diag(torch.full((n0 + n1,), float('inf'), device=device))
|
| 178 |
+
val, idx = (M + diag_inf).topk(k, 0, False)
|
| 179 |
+
|
| 180 |
+
count = torch.zeros(n0 + n1, device=device)
|
| 181 |
+
for i in range(k):
|
| 182 |
+
count.add_(label.index_select(0, idx[i]))
|
| 183 |
+
|
| 184 |
+
threshold = torch.full((n0 + n1,), float(k) / 2, device=device)
|
| 185 |
+
pred = (count >= threshold).float()
|
| 186 |
+
|
| 187 |
+
return {'acc': (label == pred).float().mean()}
|
| 188 |
+
|
| 189 |
+
def lgan_mmd_cov(all_dist):
|
| 190 |
+
N_sample, N_ref = all_dist.shape
|
| 191 |
+
min_val, min_idx = all_dist.min(dim=1) # For each sample, find closest ref
|
| 192 |
+
mmd_smp = min_val.mean() # MMD-smp
|
| 193 |
+
|
| 194 |
+
min_val_ref, _ = all_dist.min(dim=0) # For each ref, find closest sample
|
| 195 |
+
mmd = min_val_ref.mean() # MMD-ref
|
| 196 |
+
|
| 197 |
+
cov = min_idx.unique().numel() / float(N_ref)
|
| 198 |
+
cov = torch.tensor(cov, device=all_dist.device)
|
| 199 |
+
|
| 200 |
+
return {'lgan_mmd': mmd, 'lgan_cov': cov}
|
| 201 |
+
|
| 202 |
+
def compute_mmd_cov_1nna(sample_pcs, ref_pcs, batch_size):
|
| 203 |
+
results = {}
|
| 204 |
+
print("\n--- 开始计算 MMD-CD, COV-CD, 1-NNA-CD ---")
|
| 205 |
+
|
| 206 |
+
M_rs_cd, _ = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size) # ref vs sample
|
| 207 |
+
|
| 208 |
+
res_cd = lgan_mmd_cov(M_rs_cd.t()) # Transpose to get sample vs ref
|
| 209 |
+
results.update({f"{k}-CD": v for k, v in res_cd.items()})
|
| 210 |
+
|
| 211 |
+
M_rr_cd, _ = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
|
| 212 |
+
M_ss_cd, _ = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
|
| 213 |
+
|
| 214 |
+
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1)
|
| 215 |
+
results.update({"1-NNA-CD": one_nn_cd_res['acc']})
|
| 216 |
+
|
| 217 |
+
return results
|
| 218 |
+
|
| 219 |
+
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
|
| 220 |
+
grid = np.linspace(-0.5, 0.5, resolution)
|
| 221 |
+
x, y, z = np.meshgrid(grid, grid, grid, indexing='ij')
|
| 222 |
+
grid = np.stack([x, y, z], axis=-1).reshape(-1, 3)
|
| 223 |
+
if clip_sphere:
|
| 224 |
+
grid = grid[norm(grid, axis=1) <= 0.5]
|
| 225 |
+
return grid
|
| 226 |
+
|
| 227 |
+
def entropy_of_occupancy_grid(pclouds, grid_resolution):
|
| 228 |
+
grid_coords = unit_cube_grid_point_cloud(grid_resolution, True)
|
| 229 |
+
grid_counters = np.zeros(len(grid_coords))
|
| 230 |
+
nn = NearestNeighbors(n_neighbors=1).fit(grid_coords)
|
| 231 |
+
|
| 232 |
+
for pc in tqdm(pclouds, desc='计算占据网格'):
|
| 233 |
+
_, indices = nn.kneighbors(pc)
|
| 234 |
+
indices = np.unique(indices.squeeze())
|
| 235 |
+
grid_counters[indices] += 1
|
| 236 |
+
return grid_counters
|
| 237 |
+
|
| 238 |
+
def jensen_shannon_divergence(P, Q):
|
| 239 |
+
P_ = P / (P.sum() + 1e-9)
|
| 240 |
+
Q_ = Q / (Q.sum() + 1e-9)
|
| 241 |
+
M = 0.5 * (P_ + Q_)
|
| 242 |
+
return 0.5 * (entropy(P_, M, base=2) + entropy(Q_, M, base=2))
|
| 243 |
+
|
| 244 |
+
def compute_jsd(sample_pcs, ref_pcs, resolution):
|
| 245 |
+
print("\n--- 开始计算 JSD ---")
|
| 246 |
+
sample_grid_dist = entropy_of_occupancy_grid(sample_pcs, resolution)
|
| 247 |
+
ref_grid_dist = entropy_of_occupancy_grid(ref_pcs, resolution)
|
| 248 |
+
jsd = jensen_shannon_divergence(sample_grid_dist, ref_grid_dist)
|
| 249 |
+
return jsd
|
| 250 |
+
|
| 251 |
+
# ==============================================================================
|
| 252 |
+
# 主执行函数 (Main Execution)
|
| 253 |
+
# ==============================================================================
|
| 254 |
+
if __name__ == '__main__':
|
| 255 |
+
# 1. 加载并处理Meshes为点云 (Numpy arrays)
|
| 256 |
+
sample_pcs_np = process_meshes_in_folder(GENERATED_MESH_DIR, NUM_POINTS_PER_MESH)
|
| 257 |
+
ref_pcs_np = process_meshes_in_folder(GT_MESH_DIR, NUM_POINTS_PER_MESH)
|
| 258 |
+
|
| 259 |
+
print(f"\n加载完成: {sample_pcs_np.shape[0]} 个生成点云, {ref_pcs_np.shape[0]} 个真实点云。")
|
| 260 |
+
print(f"每个点云包含 {sample_pcs_np.shape[1]} 个点。")
|
| 261 |
+
|
| 262 |
+
# 2. 设置设备并转换数据为PyTorch Tensors
|
| 263 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 264 |
+
print(f"使用设备: {device}")
|
| 265 |
+
|
| 266 |
+
sample_pcs_torch = torch.from_numpy(sample_pcs_np).float().to(device)
|
| 267 |
+
ref_pcs_torch = torch.from_numpy(ref_pcs_np).float().to(device)
|
| 268 |
+
|
| 269 |
+
# 3. 计算分布度量: MMD, COV, 1-NNA (使用PyTorch)
|
| 270 |
+
metrics_results = compute_mmd_cov_1nna(sample_pcs_torch, ref_pcs_torch, BATCH_SIZE)
|
| 271 |
+
|
| 272 |
+
# 4. 计算成对几何度量: CD, HD (使用PyTorch)
|
| 273 |
+
cd_hd_results = compute_cd_hd(sample_pcs_torch, ref_pcs_torch, BATCH_SIZE)
|
| 274 |
+
metrics_results.update(cd_hd_results) # 合并结果
|
| 275 |
+
|
| 276 |
+
# 5. 计算JSD (使用Numpy)
|
| 277 |
+
jsd_result = compute_jsd(sample_pcs_np, ref_pcs_np, JSD_RESOLUTION)
|
| 278 |
+
|
| 279 |
+
# 6. 打印最终结果
|
| 280 |
+
print("\n==================================================")
|
| 281 |
+
print(" 评估结果")
|
| 282 |
+
print("==================================================")
|
| 283 |
+
|
| 284 |
+
print("\n--- 分布质量与多样性 (Distribution Metrics) ---")
|
| 285 |
+
# MMD: 越低越好 (质量)
|
| 286 |
+
print(f"{'lgan_mmd-CD':<12s}: {metrics_results['lgan_mmd-CD'].item():.6f} (↓ Lower is better)")
|
| 287 |
+
# COV: 越高越好 (多样性)
|
| 288 |
+
print(f"{'lgan_cov-CD':<12s}: {metrics_results['lgan_cov-CD'].item():.6f} (↑ Higher is better)")
|
| 289 |
+
# 1-NNA: 越接近0.5越好 (真实性)
|
| 290 |
+
print(f"{'1-NNA-CD':<12s}: {metrics_results['1-NNA-CD'].item():.6f} (→ Closer to 0.5 is better)")
|
| 291 |
+
# JSD: 越低越好 (分布相似性)
|
| 292 |
+
print(f"{'JSD':<12s}: {jsd_result:.6f} (↓ Lower is better)")
|
| 293 |
+
|
| 294 |
+
print("\n--- 平均几何保真度 (Average Geometric Fidelity) ---")
|
| 295 |
+
# CD: 越低越好
|
| 296 |
+
print(f"{'Chamfer-L2':<12s}: {metrics_results['Chamfer-L2'].item():.6f} (↓ Lower is better)")
|
| 297 |
+
# HD: 越低越好
|
| 298 |
+
print(f"{'Hausdorff':<12s}: {metrics_results['Hausdorff'].item():.6f} (↓ Lower is better)")
|
| 299 |
+
|
| 300 |
+
print("==================================================")
|
metric_cd.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import trimesh
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# =====================================================
|
| 9 |
+
# 🔹 Mesh归一化函数
|
| 10 |
+
# =====================================================
|
| 11 |
+
def normalize_to_unit_sphere(mesh: trimesh.Trimesh) -> trimesh.Trimesh:
|
| 12 |
+
"""将mesh平移到原点并缩放到单位球内"""
|
| 13 |
+
vertices = mesh.vertices
|
| 14 |
+
centroid = vertices.mean(axis=0)
|
| 15 |
+
vertices = vertices - centroid
|
| 16 |
+
scale = np.max(np.linalg.norm(vertices, axis=1))
|
| 17 |
+
vertices = vertices / scale
|
| 18 |
+
mesh.vertices = vertices
|
| 19 |
+
return mesh
|
| 20 |
+
|
| 21 |
+
def normalize_to_unit_cube(mesh: trimesh.Trimesh) -> trimesh.Trimesh:
|
| 22 |
+
"""将mesh平移并缩放到[-1,1]^3单位立方体内"""
|
| 23 |
+
bbox_min, bbox_max = mesh.bounds
|
| 24 |
+
center = (bbox_min + bbox_max) / 2
|
| 25 |
+
scale = (bbox_max - bbox_min).max() / 2
|
| 26 |
+
mesh.vertices = (mesh.vertices - center) / scale
|
| 27 |
+
return mesh
|
| 28 |
+
|
| 29 |
+
# =====================================================
|
| 30 |
+
# 🔹 点云采样函数 + 返回面片数
|
| 31 |
+
# =====================================================
|
| 32 |
+
def sample_points_from_mesh(mesh_path: str, num_points: int, normalize: str = "none"):
|
| 33 |
+
"""
|
| 34 |
+
从mesh文件采样点云,并可选归一化。
|
| 35 |
+
返回: (points: Tensor, face_count: int)
|
| 36 |
+
"""
|
| 37 |
+
try:
|
| 38 |
+
mesh = trimesh.load(mesh_path, force='mesh', process=False)
|
| 39 |
+
if normalize == "sphere":
|
| 40 |
+
mesh = normalize_to_unit_sphere(mesh)
|
| 41 |
+
elif normalize == "cube":
|
| 42 |
+
mesh = normalize_to_unit_cube(mesh)
|
| 43 |
+
points, _ = trimesh.sample.sample_surface(mesh, num_points)
|
| 44 |
+
face_count = len(mesh.faces)
|
| 45 |
+
return torch.from_numpy(points).float(), face_count
|
| 46 |
+
except Exception as e:
|
| 47 |
+
print(f"[-] 警告:加载或采样文件失败 {mesh_path}。错误: {e}")
|
| 48 |
+
return None, 0
|
| 49 |
+
|
| 50 |
+
# =====================================================
|
| 51 |
+
# 🔹 Chamfer Distance 计算函数
|
| 52 |
+
# =====================================================
|
| 53 |
+
def find_minimum_cd_batched(gen_pc: torch.Tensor, gt_pcs_batch: torch.Tensor):
|
| 54 |
+
"""计算生成点云到一批GT点云的最小CD及对应索引"""
|
| 55 |
+
gen_pc_batch = gen_pc.unsqueeze(0).expand(gt_pcs_batch.size(0), -1, -1)
|
| 56 |
+
dist_matrix = torch.cdist(gen_pc_batch, gt_pcs_batch)
|
| 57 |
+
min_dist_gen_to_gt = dist_matrix.min(2).values.mean(1)
|
| 58 |
+
min_dist_gt_to_gen = dist_matrix.min(1).values.mean(1)
|
| 59 |
+
cd_scores_for_one_gen = min_dist_gen_to_gt + min_dist_gt_to_gen
|
| 60 |
+
min_cd, min_idx = cd_scores_for_one_gen.min(0)
|
| 61 |
+
return min_cd.item(), min_idx.item()
|
| 62 |
+
|
| 63 |
+
# =====================================================
|
| 64 |
+
# 🔹 主流程
|
| 65 |
+
# =====================================================
|
| 66 |
+
def main(args):
|
| 67 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 68 |
+
print(f"[*] 使用设备: {device}")
|
| 69 |
+
print(f"[*] 归一化模式: {args.normalize}")
|
| 70 |
+
|
| 71 |
+
# --- Step 1: 加载GT网格并采样 ---
|
| 72 |
+
print("[*] 正在预加载并采样所有GT网格...")
|
| 73 |
+
gt_files = sorted([f for f in os.listdir(args.gt_dir) if f.endswith(('.obj', '.ply', '.off'))])
|
| 74 |
+
if not gt_files:
|
| 75 |
+
print(f"[-] 错误: GT目录中未找到mesh文件: {args.gt_dir}")
|
| 76 |
+
return
|
| 77 |
+
|
| 78 |
+
gt_point_clouds, gt_faces_counts = [], []
|
| 79 |
+
for gt_filename in tqdm(gt_files, desc="预处理GT网格"):
|
| 80 |
+
gt_filepath = os.path.join(args.gt_dir, gt_filename)
|
| 81 |
+
pc, fnum = sample_points_from_mesh(gt_filepath, args.num_points, args.normalize)
|
| 82 |
+
if pc is not None:
|
| 83 |
+
gt_point_clouds.append(pc.to(device))
|
| 84 |
+
gt_faces_counts.append(fnum)
|
| 85 |
+
|
| 86 |
+
if not gt_point_clouds:
|
| 87 |
+
print("[-] 错误: 无法从任何GT文件采样点云。")
|
| 88 |
+
return
|
| 89 |
+
|
| 90 |
+
print(f"[*] 成功加载 {len(gt_point_clouds)} 个GT点云。")
|
| 91 |
+
|
| 92 |
+
# --- Step 2: 遍历生成的网格 ---
|
| 93 |
+
gen_files = sorted([f for f in os.listdir(args.generated_dir) if f.endswith(('.obj', '.ply', '.off'))])
|
| 94 |
+
if not gen_files:
|
| 95 |
+
print(f"[-] 错误: 生成目录中未找到mesh文件: {args.generated_dir}")
|
| 96 |
+
return
|
| 97 |
+
|
| 98 |
+
all_min_cd_scores = []
|
| 99 |
+
face_ratios = []
|
| 100 |
+
pred_faces_all = []
|
| 101 |
+
gt_faces_matched = []
|
| 102 |
+
|
| 103 |
+
for gen_filename in tqdm(gen_files, desc="评估生成的网格"):
|
| 104 |
+
gen_filepath = os.path.join(args.generated_dir, gen_filename)
|
| 105 |
+
gen_pc, gen_face_count = sample_points_from_mesh(gen_filepath, args.num_points, args.normalize)
|
| 106 |
+
if gen_pc is None:
|
| 107 |
+
continue
|
| 108 |
+
|
| 109 |
+
gen_pc = gen_pc.to(device)
|
| 110 |
+
batch_size = args.batch_size
|
| 111 |
+
min_cd_for_this_gen = float('inf')
|
| 112 |
+
matched_gt_idx = -1
|
| 113 |
+
|
| 114 |
+
for i in range(0, len(gt_point_clouds), batch_size):
|
| 115 |
+
gt_pcs_batch = torch.stack(gt_point_clouds[i:i + batch_size])
|
| 116 |
+
min_cd_in_batch, idx_in_batch = find_minimum_cd_batched(gen_pc, gt_pcs_batch)
|
| 117 |
+
if min_cd_in_batch < min_cd_for_this_gen:
|
| 118 |
+
min_cd_for_this_gen = min_cd_in_batch
|
| 119 |
+
matched_gt_idx = i + idx_in_batch
|
| 120 |
+
|
| 121 |
+
all_min_cd_scores.append(min_cd_for_this_gen)
|
| 122 |
+
if matched_gt_idx >= 0:
|
| 123 |
+
gt_face_count = gt_faces_counts[matched_gt_idx]
|
| 124 |
+
face_ratio = gen_face_count / gt_face_count if gt_face_count > 0 else 0
|
| 125 |
+
face_ratios.append(face_ratio)
|
| 126 |
+
pred_faces_all.append(gen_face_count)
|
| 127 |
+
gt_faces_matched.append(gt_face_count)
|
| 128 |
+
if not args.quiet:
|
| 129 |
+
print(f" -> {gen_filename}: 最小CD={min_cd_for_this_gen:.6f}, Pred面数={gen_face_count}, GT面数={gt_face_count}, 比值={face_ratio:.3f}")
|
| 130 |
+
|
| 131 |
+
# --- Step 3: 汇总 ---
|
| 132 |
+
if not all_min_cd_scores:
|
| 133 |
+
print("\n[-] 评估结束,但没有成功处理任何网格。")
|
| 134 |
+
else:
|
| 135 |
+
mean_min_cd = np.mean(all_min_cd_scores)
|
| 136 |
+
mean_face_ratio = np.mean(face_ratios) if face_ratios else 0
|
| 137 |
+
mean_pred_faces = np.mean(pred_faces_all) if pred_faces_all else 0
|
| 138 |
+
mean_gt_faces = np.mean(gt_faces_matched) if gt_faces_matched else 0
|
| 139 |
+
|
| 140 |
+
print("\n" + "="*70)
|
| 141 |
+
print(f"[*] 评估完成 (基于最小CD匹配)")
|
| 142 |
+
print(f"[*] 共评估 {len(all_min_cd_scores)} 个生成网格")
|
| 143 |
+
print(f"[*] 平均最小倒角距离 (Mean Min CD): {mean_min_cd:.6f}")
|
| 144 |
+
print(f"[*] 平均Pred面片数: {mean_pred_faces:.1f}")
|
| 145 |
+
print(f"[*] 平均GT面片数: {mean_gt_faces:.1f}")
|
| 146 |
+
print(f"[*] 平均面片比 (Pred/GT): {mean_face_ratio:.3f}")
|
| 147 |
+
print("="*70)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# =====================================================
|
| 151 |
+
# 🔹 命令行接口
|
| 152 |
+
# =====================================================
|
| 153 |
+
if __name__ == "__main__":
|
| 154 |
+
parser = argparse.ArgumentParser(description="评估生成mesh与GT集合的最小Chamfer Distance及面片数比")
|
| 155 |
+
|
| 156 |
+
parser.add_argument("--generated_dir", type=str, required=True, help="生成的mesh文件夹路径")
|
| 157 |
+
parser.add_argument("--gt_dir", type=str, required=True, help="GT网格文件夹路径")
|
| 158 |
+
parser.add_argument("--num_points", type=int, default=10000, help="每个mesh采样点数")
|
| 159 |
+
parser.add_argument("--batch_size", type=int, default=16, help="与多少个GT点云进行批处理比较")
|
| 160 |
+
parser.add_argument("--normalize", type=str, default="none", choices=["none", "sphere", "cube"], help="归一化模式: none | sphere | cube")
|
| 161 |
+
parser.add_argument("--quiet", action="store_true", help="静默模式,只输出最终平均CD")
|
| 162 |
+
|
| 163 |
+
args = parser.parse_args()
|
| 164 |
+
main(args)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
'''
|
| 170 |
+
# 不归一化
|
| 171 |
+
python metric_cd.py \
|
| 172 |
+
--generated_dir /root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs \
|
| 173 |
+
--gt_dir /root/Trisf/abalation_post_processing/gt_mesh \
|
| 174 |
+
--num_points 4096 \
|
| 175 |
+
--normalize none
|
| 176 |
+
|
| 177 |
+
# 归一化到单位球
|
| 178 |
+
python metric_cd.py \
|
| 179 |
+
--generated_dir /root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs/0.8_1.5 \
|
| 180 |
+
--gt_dir /root/Trisf/abalation_post_processing/gt_mesh \
|
| 181 |
+
--num_points 4096 \
|
| 182 |
+
--normalize sphere
|
| 183 |
+
|
| 184 |
+
# 归一化到单位立方体
|
| 185 |
+
python metric_cd.py \
|
| 186 |
+
--generated_dir /root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs \
|
| 187 |
+
--gt_dir /root/Trisf/abalation_post_processing/gt_mesh \
|
| 188 |
+
--num_points 4096 \
|
| 189 |
+
--normalize cube
|
| 190 |
+
'''
|
query_point.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch import einsum
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from functools import partial
|
| 6 |
+
from timm.models.layers import DropPath
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
|
| 9 |
+
# ---- PE: NeRF-style Position Encoding ----
|
| 10 |
+
class Embedder:
|
| 11 |
+
def __init__(self, **kwargs):
|
| 12 |
+
self.kwargs = kwargs
|
| 13 |
+
self.create_embedding_fn()
|
| 14 |
+
|
| 15 |
+
def create_embedding_fn(self):
|
| 16 |
+
embed_fns = []
|
| 17 |
+
d = self.kwargs['input_dims']
|
| 18 |
+
out_dim = 0
|
| 19 |
+
if self.kwargs['include_input']:
|
| 20 |
+
embed_fns.append(self.identity_fn)
|
| 21 |
+
out_dim += d
|
| 22 |
+
|
| 23 |
+
max_freq = self.kwargs['max_freq_log2']
|
| 24 |
+
N_freqs = self.kwargs['num_freqs']
|
| 25 |
+
|
| 26 |
+
if self.kwargs['log_sampling']:
|
| 27 |
+
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
|
| 28 |
+
else:
|
| 29 |
+
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
|
| 30 |
+
|
| 31 |
+
for freq in freq_bands:
|
| 32 |
+
for p_fn in self.kwargs['periodic_fns']:
|
| 33 |
+
embed_fns.append(partial(self.periodic_fn, p_fn=p_fn, freq=freq))
|
| 34 |
+
out_dim += d
|
| 35 |
+
|
| 36 |
+
self.embed_fns = embed_fns
|
| 37 |
+
self.out_dim = out_dim
|
| 38 |
+
|
| 39 |
+
def identity_fn(self, x):
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
def periodic_fn(self, x, p_fn, freq):
|
| 43 |
+
return p_fn(x * freq)
|
| 44 |
+
|
| 45 |
+
def embed(self, inputs):
|
| 46 |
+
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
| 47 |
+
|
| 48 |
+
def get_embedder(multires, i=0):
|
| 49 |
+
if i == -1:
|
| 50 |
+
return nn.Identity(), 1
|
| 51 |
+
|
| 52 |
+
embed_kwargs = {
|
| 53 |
+
'include_input': True,
|
| 54 |
+
'input_dims': 1,
|
| 55 |
+
'max_freq_log2': multires-1,
|
| 56 |
+
'num_freqs': multires,
|
| 57 |
+
'log_sampling': True,
|
| 58 |
+
'periodic_fns': [torch.sin, torch.cos],
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
embedder_obj = Embedder(**embed_kwargs)
|
| 62 |
+
embed = embedder_obj.embed
|
| 63 |
+
return embed, embedder_obj.out_dim
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class PE_NeRF(nn.Module):
|
| 67 |
+
def __init__(self, out_channels=512, multires=10):
|
| 68 |
+
super().__init__()
|
| 69 |
+
|
| 70 |
+
self.multires = multires
|
| 71 |
+
self.embed_fn, embed_dim_per_dim = get_embedder(multires) # per-dim embed
|
| 72 |
+
self.embed_dim = embed_dim_per_dim * 3 # since 3D: x, y, z
|
| 73 |
+
|
| 74 |
+
self.coor_embed = nn.Sequential(
|
| 75 |
+
nn.Linear(self.embed_dim, 256),
|
| 76 |
+
nn.GELU(),
|
| 77 |
+
nn.Linear(256, out_channels)
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
def forward(self, vertices: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
"""
|
| 82 |
+
Args:
|
| 83 |
+
vertices: [B, 3] or [N, 3], coordinates in [-0.5, 0.5]
|
| 84 |
+
Returns:
|
| 85 |
+
encoded: [B, out_channels * 3]
|
| 86 |
+
"""
|
| 87 |
+
x_embed = self.embed_fn(vertices[..., 0:1]) # [N, D]
|
| 88 |
+
y_embed = self.embed_fn(vertices[..., 1:2])
|
| 89 |
+
z_embed = self.embed_fn(vertices[..., 2:3])
|
| 90 |
+
|
| 91 |
+
pos_enc = torch.cat([x_embed, y_embed, z_embed], dim=-1) # [N, D * 3]
|
| 92 |
+
|
| 93 |
+
return self.coor_embed(pos_enc)
|
| 94 |
+
|
| 95 |
+
def exists(val):
|
| 96 |
+
return val is not None
|
| 97 |
+
|
| 98 |
+
def default(val, d):
|
| 99 |
+
return val if exists(val) else d
|
| 100 |
+
|
| 101 |
+
# ---- Attention & FF blocks ----
|
| 102 |
+
class GEGLU(nn.Module):
|
| 103 |
+
def forward(self, x):
|
| 104 |
+
x, gate = x.chunk(2, dim=-1)
|
| 105 |
+
return x * F.gelu(gate)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class FeedForward(nn.Module):
|
| 109 |
+
def __init__(self, dim, mult=4):
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.net = nn.Sequential(
|
| 112 |
+
nn.Linear(dim, dim * mult * 2),
|
| 113 |
+
GEGLU(),
|
| 114 |
+
nn.Linear(dim * mult, dim)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def forward(self, x):
|
| 118 |
+
return self.net(x)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class PreNorm(nn.Module):
|
| 122 |
+
def __init__(self, dim, fn, context_dim = None):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.fn = fn
|
| 125 |
+
self.norm = nn.LayerNorm(dim)
|
| 126 |
+
self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
|
| 127 |
+
|
| 128 |
+
def forward(self, x, **kwargs):
|
| 129 |
+
x = self.norm(x)
|
| 130 |
+
|
| 131 |
+
if exists(self.norm_context):
|
| 132 |
+
context = kwargs['context']
|
| 133 |
+
normed_context = self.norm_context(context)
|
| 134 |
+
kwargs.update(context = normed_context)
|
| 135 |
+
|
| 136 |
+
return self.fn(x, **kwargs)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class Attention(nn.Module):
|
| 140 |
+
def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0):
|
| 141 |
+
super().__init__()
|
| 142 |
+
inner_dim = dim_head * heads
|
| 143 |
+
context_dim = default(context_dim, query_dim)
|
| 144 |
+
self.scale = dim_head ** -0.5
|
| 145 |
+
self.heads = heads
|
| 146 |
+
|
| 147 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
|
| 148 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
| 149 |
+
self.to_out = nn.Linear(inner_dim, query_dim)
|
| 150 |
+
|
| 151 |
+
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
| 152 |
+
|
| 153 |
+
def forward(self, x, context = None, mask = None):
|
| 154 |
+
h = self.heads
|
| 155 |
+
|
| 156 |
+
q = self.to_q(x)
|
| 157 |
+
context = default(context, x)
|
| 158 |
+
k, v = self.to_kv(context).chunk(2, dim = -1)
|
| 159 |
+
|
| 160 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
|
| 161 |
+
|
| 162 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
| 163 |
+
|
| 164 |
+
if exists(mask):
|
| 165 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
| 166 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
| 167 |
+
mask = repeat(mask, 'b j -> (b h) () j', h = h)
|
| 168 |
+
sim.masked_fill_(~mask, max_neg_value)
|
| 169 |
+
|
| 170 |
+
# attention, what we cannot get enough of
|
| 171 |
+
attn = sim.softmax(dim = -1)
|
| 172 |
+
|
| 173 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
| 174 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
|
| 175 |
+
return self.drop_path(self.to_out(out))
|
| 176 |
+
|
| 177 |
+
class QueryPointDecoder(nn.Module):
|
| 178 |
+
def __init__(self, query_dim=1536, context_dim=512, output_dim=1, depth=8,
|
| 179 |
+
using_nerf=True, quantize_bits=10, dim=512, heads=8, multires=10):
|
| 180 |
+
super().__init__()
|
| 181 |
+
self.using_nerf = using_nerf
|
| 182 |
+
self.depth = depth
|
| 183 |
+
|
| 184 |
+
if using_nerf:
|
| 185 |
+
self.pe = PE_NeRF(out_channels=query_dim, multires=multires)
|
| 186 |
+
else:
|
| 187 |
+
self.embedding_x = nn.Embedding(2**quantize_bits, query_dim // 3)
|
| 188 |
+
self.embedding_y = nn.Embedding(2**quantize_bits, query_dim // 3)
|
| 189 |
+
self.embedding_z = nn.Embedding(2**quantize_bits, query_dim // 3)
|
| 190 |
+
self.coord_proj = nn.Sequential(
|
| 191 |
+
nn.Linear(query_dim, query_dim * 4),
|
| 192 |
+
nn.GELU(),
|
| 193 |
+
nn.Linear(query_dim * 4, query_dim)
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# self.context_proj = nn.Linear(context_dim, query_dim)
|
| 197 |
+
self.context_proj = nn.Linear(context_dim, dim)
|
| 198 |
+
|
| 199 |
+
self.pe_ctx = PE_NeRF(out_channels=dim, multires=multires)
|
| 200 |
+
|
| 201 |
+
self.context_self_attn_layers = nn.ModuleList([
|
| 202 |
+
nn.ModuleList([
|
| 203 |
+
PreNorm(dim, Attention(dim, dim_head=64, heads=heads)),
|
| 204 |
+
PreNorm(dim, FeedForward(dim))
|
| 205 |
+
]) for _ in range(depth)
|
| 206 |
+
])
|
| 207 |
+
|
| 208 |
+
self.cross_attn = PreNorm(dim,
|
| 209 |
+
Attention(dim, dim,
|
| 210 |
+
dim_head=dim, heads=1))
|
| 211 |
+
self.cross_ff = PreNorm(dim, FeedForward(dim))
|
| 212 |
+
|
| 213 |
+
self.to_outputs = nn.Linear(dim, output_dim)
|
| 214 |
+
|
| 215 |
+
def forward(self, query_points, context_feats, context_mask=None, voxels_coords=None,):
|
| 216 |
+
B, N, _ = query_points.shape
|
| 217 |
+
|
| 218 |
+
if self.using_nerf:
|
| 219 |
+
# print('query_points.min()', query_points.min())
|
| 220 |
+
# print('query_points.max()', query_points.max())
|
| 221 |
+
x = self.pe(query_points.view(-1, 3)).view(B, N, -1)
|
| 222 |
+
else:
|
| 223 |
+
embeddings = torch.cat([
|
| 224 |
+
self.embedding_x(query_points[..., 0]),
|
| 225 |
+
self.embedding_y(query_points[..., 1]),
|
| 226 |
+
self.embedding_z(query_points[..., 2]),
|
| 227 |
+
], dim=-1)
|
| 228 |
+
x = self.coord_proj(embeddings)
|
| 229 |
+
|
| 230 |
+
context = self.context_proj(context_feats)
|
| 231 |
+
|
| 232 |
+
if voxels_coords is not None:
|
| 233 |
+
M = voxels_coords.shape[1]
|
| 234 |
+
normalized_coords = 2.0 * (voxels_coords.float() / 1024.) - 1.0
|
| 235 |
+
context += self.pe_ctx(normalized_coords.view(-1, 3)).view(B, M, -1)
|
| 236 |
+
|
| 237 |
+
attn_mask = context_mask[:, None, None, :] if context_mask is not None else None
|
| 238 |
+
|
| 239 |
+
for self_attn, ff in self.context_self_attn_layers:
|
| 240 |
+
context = self_attn(context, mask=attn_mask) + context
|
| 241 |
+
context = ff(context) + context
|
| 242 |
+
|
| 243 |
+
latents = self.cross_attn(x, context=context, mask=attn_mask)
|
| 244 |
+
latents = self.cross_ff(x) + latents
|
| 245 |
+
|
| 246 |
+
return self.to_outputs(latents).squeeze(-1)
|
| 247 |
+
|
| 248 |
+
if __name__ == '__main__':
|
| 249 |
+
torch.manual_seed(42)
|
| 250 |
+
model = QueryPointDecoder().cuda()
|
| 251 |
+
model.eval()
|
| 252 |
+
|
| 253 |
+
B, N, M = 2, 64, 20
|
| 254 |
+
query_pts = torch.rand(B, N, 3).cuda() - 0.5 # [-0.5, 0.5]
|
| 255 |
+
context_feats = torch.randn(B, M, 512).cuda()
|
| 256 |
+
|
| 257 |
+
with torch.no_grad():
|
| 258 |
+
logits = model(query_pts, context_feats)
|
| 259 |
+
print("Logits shape:", logits.shape) # [B, N, 1]
|
test_slat_flow_128to1024_pointnet.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import yaml
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from functools import partial
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.amp import GradScaler, autocast
|
| 12 |
+
from typing import *
|
| 13 |
+
from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig, Dinov2Model, AutoImageProcessor, Dinov2Config
|
| 14 |
+
import torch
|
| 15 |
+
import re
|
| 16 |
+
from utils import load_pretrained_woself
|
| 17 |
+
|
| 18 |
+
from dataset_triposf import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 19 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 20 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet
|
| 21 |
+
|
| 22 |
+
from trellis.models.structured_latent_flow import SLatFlowModel
|
| 23 |
+
from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
|
| 24 |
+
|
| 25 |
+
from trellis.pipelines.samplers import FlowEulerSampler
|
| 26 |
+
from safetensors.torch import load_file
|
| 27 |
+
import open3d as o3d
|
| 28 |
+
from PIL import Image
|
| 29 |
+
|
| 30 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 31 |
+
from trellis.modules.sparse.basic import SparseTensor as SparseTensor_trellis
|
| 32 |
+
|
| 33 |
+
from triposf.modules.utils import DiagonalGaussianDistribution
|
| 34 |
+
|
| 35 |
+
from sklearn.decomposition import PCA
|
| 36 |
+
import trimesh
|
| 37 |
+
import torchvision.transforms as transforms
|
| 38 |
+
|
| 39 |
+
# --- Helper Functions ---
|
| 40 |
+
def save_colored_ply(points, colors, filename):
|
| 41 |
+
if len(points) == 0:
|
| 42 |
+
print(f"[Warning] No points to save for {filename}")
|
| 43 |
+
return
|
| 44 |
+
# Ensure colors are uint8
|
| 45 |
+
if colors.max() <= 1.0:
|
| 46 |
+
colors = (colors * 255).astype(np.uint8)
|
| 47 |
+
colors = colors.astype(np.uint8)
|
| 48 |
+
|
| 49 |
+
# Add Alpha if missing
|
| 50 |
+
if colors.shape[1] == 3:
|
| 51 |
+
colors = np.hstack([colors, np.full((len(colors), 1), 255, dtype=np.uint8)])
|
| 52 |
+
|
| 53 |
+
cloud = trimesh.PointCloud(points, colors=colors)
|
| 54 |
+
cloud.export(filename)
|
| 55 |
+
print(f"Saved colored point cloud to {filename}")
|
| 56 |
+
|
| 57 |
+
def normalize_to_rgb(features_3d):
|
| 58 |
+
min_vals = features_3d.min(axis=0)
|
| 59 |
+
max_vals = features_3d.max(axis=0)
|
| 60 |
+
range_vals = max_vals - min_vals
|
| 61 |
+
range_vals[range_vals == 0] = 1
|
| 62 |
+
normalized = (features_3d - min_vals) / range_vals
|
| 63 |
+
return (normalized * 255).astype(np.uint8)
|
| 64 |
+
|
| 65 |
+
class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
|
| 66 |
+
def __init__(self, *args, **kwargs):
|
| 67 |
+
super().__init__(*args, **kwargs)
|
| 68 |
+
self.cfg = kwargs.pop('cfg', None)
|
| 69 |
+
if self.cfg is None:
|
| 70 |
+
raise ValueError("Configuration dictionary 'cfg' must be provided.")
|
| 71 |
+
|
| 72 |
+
self.sampler = FlowEulerSampler(sigma_min=1.e-5)
|
| 73 |
+
self.device = torch.device("cuda")
|
| 74 |
+
|
| 75 |
+
# Based on PointNet Encoder setting
|
| 76 |
+
self.resolution = 128
|
| 77 |
+
|
| 78 |
+
self.condition_type = 'image'
|
| 79 |
+
self.is_cond = False
|
| 80 |
+
|
| 81 |
+
self.img_res = 518
|
| 82 |
+
self.feature_dim = self.cfg['model']['latent_dim']
|
| 83 |
+
|
| 84 |
+
self._init_components(
|
| 85 |
+
clip_model_path=self.cfg['training'].get('clip_model_path', None),
|
| 86 |
+
dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None),
|
| 87 |
+
vae_path=self.cfg['training']['vae_path'],
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Classifier head removed as it is not part of the Active Voxel pipeline
|
| 91 |
+
|
| 92 |
+
def _load_denoiser(self, denoiser_checkpoint_path):
|
| 93 |
+
path = denoiser_checkpoint_path
|
| 94 |
+
if not path or not os.path.isfile(path):
|
| 95 |
+
print("No valid checkpoint path provided for fine-tuning. Starting from scratch.")
|
| 96 |
+
return
|
| 97 |
+
|
| 98 |
+
print(f"Loading checkpoint from: {path}")
|
| 99 |
+
checkpoint = torch.load(path, map_location=self.device)
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
denoiser_state_dict = checkpoint['denoiser']
|
| 103 |
+
# Handle DDP prefix
|
| 104 |
+
if next(iter(denoiser_state_dict)).startswith('module.'):
|
| 105 |
+
denoiser_state_dict = {k[7:]: v for k, v in denoiser_state_dict.items()}
|
| 106 |
+
|
| 107 |
+
self.denoiser.load_state_dict(denoiser_state_dict)
|
| 108 |
+
print("Denoiser weights loaded successfully.")
|
| 109 |
+
except KeyError:
|
| 110 |
+
print("[WARN] 'denoiser' key not found in checkpoint. Skipping.")
|
| 111 |
+
except Exception as e:
|
| 112 |
+
print(f"[ERROR] Failed to load denoiser state_dict: {e}")
|
| 113 |
+
|
| 114 |
+
def _init_components(self,
|
| 115 |
+
clip_model_path=None,
|
| 116 |
+
dinov2_model_path=None,
|
| 117 |
+
vae_path=None,
|
| 118 |
+
):
|
| 119 |
+
|
| 120 |
+
# 1. Initialize PointNet Voxel Encoder (Matches Training)
|
| 121 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 122 |
+
in_channels=15,
|
| 123 |
+
hidden_dim=256,
|
| 124 |
+
out_channels=1024,
|
| 125 |
+
scatter_type='mean',
|
| 126 |
+
n_blocks=5,
|
| 127 |
+
resolution=128,
|
| 128 |
+
add_label=False,
|
| 129 |
+
).to(self.device)
|
| 130 |
+
|
| 131 |
+
# 2. Initialize VAE
|
| 132 |
+
self.vae = VoxelVAE(
|
| 133 |
+
in_channels=self.cfg['model']['in_channels'],
|
| 134 |
+
latent_dim=self.cfg['model']['latent_dim'],
|
| 135 |
+
encoder_blocks=self.cfg['model']['encoder_blocks'],
|
| 136 |
+
decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
|
| 137 |
+
decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
|
| 138 |
+
num_heads=8,
|
| 139 |
+
num_head_channels=64,
|
| 140 |
+
mlp_ratio=4.0,
|
| 141 |
+
attn_mode="swin",
|
| 142 |
+
window_size=8,
|
| 143 |
+
pe_mode="ape",
|
| 144 |
+
use_fp16=False,
|
| 145 |
+
use_checkpoint=False,
|
| 146 |
+
qk_rms_norm=False,
|
| 147 |
+
using_subdivide=True,
|
| 148 |
+
using_attn=self.cfg['model']['using_attn'],
|
| 149 |
+
attn_first=self.cfg['model'].get('attn_first', True),
|
| 150 |
+
pred_direction=self.cfg['model'].get('pred_direction', False),
|
| 151 |
+
).to(self.device)
|
| 152 |
+
|
| 153 |
+
# 3. Initialize Dataset with collate_fn_pointnet
|
| 154 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 155 |
+
root_dir=self.cfg['dataset']['path'],
|
| 156 |
+
base_resolution=self.cfg['dataset']['base_resolution'],
|
| 157 |
+
min_resolution=self.cfg['dataset']['min_resolution'],
|
| 158 |
+
cache_dir=self.cfg['dataset']['cache_dir'],
|
| 159 |
+
renders_dir=self.cfg['dataset']['renders_dir'],
|
| 160 |
+
|
| 161 |
+
process_img=False,
|
| 162 |
+
|
| 163 |
+
active_voxel_res=128,
|
| 164 |
+
filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
|
| 165 |
+
cache_filter_path=self.cfg['dataset']['cache_filter_path'],
|
| 166 |
+
sample_type=self.cfg['dataset'].get('sample_type', 'dora'),
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
self.dataloader = DataLoader(
|
| 170 |
+
self.dataset,
|
| 171 |
+
batch_size=1,
|
| 172 |
+
shuffle=True,
|
| 173 |
+
collate_fn=partial(collate_fn_pointnet,), # Critical Change
|
| 174 |
+
num_workers=4,
|
| 175 |
+
pin_memory=True,
|
| 176 |
+
persistent_workers=True,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# 4. Load Pretrained Weights
|
| 180 |
+
# Assuming vae_path contains 'voxel_encoder' and 'vae'
|
| 181 |
+
print(f"Loading VAE/Encoder from {vae_path}")
|
| 182 |
+
ckpt = torch.load(vae_path, map_location='cpu')
|
| 183 |
+
|
| 184 |
+
# Load VAE
|
| 185 |
+
if 'vae' in ckpt:
|
| 186 |
+
self.vae.load_state_dict(ckpt['vae'], strict=False)
|
| 187 |
+
else:
|
| 188 |
+
self.vae.load_state_dict(ckpt) # Fallback
|
| 189 |
+
|
| 190 |
+
# Load Encoder
|
| 191 |
+
if 'voxel_encoder' in ckpt:
|
| 192 |
+
self.voxel_encoder.load_state_dict(ckpt['voxel_encoder'])
|
| 193 |
+
else:
|
| 194 |
+
print("[WARN] 'voxel_encoder' not found in checkpoint, random init (BAD for inference).")
|
| 195 |
+
|
| 196 |
+
self.voxel_encoder.eval()
|
| 197 |
+
self.vae.eval()
|
| 198 |
+
|
| 199 |
+
# 5. Initialize Conditioning Model
|
| 200 |
+
if self.condition_type == 'text':
|
| 201 |
+
self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
|
| 202 |
+
self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
|
| 203 |
+
elif self.condition_type == 'image':
|
| 204 |
+
model_name = 'dinov2_vitl14_reg'
|
| 205 |
+
local_repo_path = "/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main"
|
| 206 |
+
weights_path = "/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"
|
| 207 |
+
|
| 208 |
+
dinov2_model = torch.hub.load(
|
| 209 |
+
repo_or_dir=local_repo_path,
|
| 210 |
+
model=model_name,
|
| 211 |
+
source='local',
|
| 212 |
+
pretrained=False
|
| 213 |
+
)
|
| 214 |
+
self.condition_model = dinov2_model
|
| 215 |
+
self.condition_model.load_state_dict(torch.load(weights_path))
|
| 216 |
+
|
| 217 |
+
self.image_cond_model_transform = transforms.Compose([
|
| 218 |
+
transforms.ToTensor(),
|
| 219 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 220 |
+
])
|
| 221 |
+
else:
|
| 222 |
+
raise ValueError(f"Unsupported condition type: {self.condition_type}")
|
| 223 |
+
|
| 224 |
+
self.condition_model.to(self.device).eval()
|
| 225 |
+
|
| 226 |
+
@torch.no_grad()
|
| 227 |
+
def encode_image(self, images) -> torch.Tensor:
|
| 228 |
+
if isinstance(images, torch.Tensor):
|
| 229 |
+
batch_tensor = images.to(self.device)
|
| 230 |
+
elif isinstance(images, list):
|
| 231 |
+
assert all(isinstance(i, Image.Image) for i in images)
|
| 232 |
+
image = [i.resize((518, 518), Image.LANCZOS) for i in images]
|
| 233 |
+
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
|
| 234 |
+
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
|
| 235 |
+
batch_tensor = torch.stack(image).to(self.device)
|
| 236 |
+
else:
|
| 237 |
+
raise ValueError(f"Unsupported type of image: {type(images)}")
|
| 238 |
+
|
| 239 |
+
if batch_tensor.shape[-2:] != (518, 518):
|
| 240 |
+
batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False)
|
| 241 |
+
|
| 242 |
+
features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
|
| 243 |
+
patchtokens = F.layer_norm(features, features.shape[-1:])
|
| 244 |
+
return patchtokens
|
| 245 |
+
|
| 246 |
+
def process_batch(self, batch):
|
| 247 |
+
preprocessed_images = batch['image']
|
| 248 |
+
cond_ = self.encode_image(preprocessed_images)
|
| 249 |
+
return cond_
|
| 250 |
+
|
| 251 |
+
def eval(self):
|
| 252 |
+
# Unconditional Setup
|
| 253 |
+
if self.is_cond == False:
|
| 254 |
+
if self.condition_type == 'text':
|
| 255 |
+
txt = ['']
|
| 256 |
+
encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
|
| 257 |
+
tokens = encoding['input_ids'].to(self.device)
|
| 258 |
+
with torch.no_grad():
|
| 259 |
+
cond_ = self.condition_model(input_ids=tokens).last_hidden_state
|
| 260 |
+
else:
|
| 261 |
+
blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
|
| 262 |
+
with torch.no_grad():
|
| 263 |
+
dummy_cond = self.encode_image([blank_img])
|
| 264 |
+
cond_ = torch.zeros_like(dummy_cond)
|
| 265 |
+
print(f"Generated unconditional image prompt (zero tensor) with shape: {cond_.shape}")
|
| 266 |
+
|
| 267 |
+
self.denoiser.eval()
|
| 268 |
+
|
| 269 |
+
# Load Denoiser Checkpoint
|
| 270 |
+
# Update this path to your ACTIVE VOXEL trained checkpoint
|
| 271 |
+
checkpoint_path = '/gemini/user/private/zhaotianhao/output_slat_flow_matching_active/ckpts/8780_complex_128to1024_rope/checkpoint_step90000_loss0_694290.pt'
|
| 272 |
+
|
| 273 |
+
self._load_denoiser(checkpoint_path)
|
| 274 |
+
|
| 275 |
+
filename = os.path.basename(checkpoint_path)
|
| 276 |
+
match = re.search(r'step(\d+)', filename)
|
| 277 |
+
step_str = match.group(1) if match else "eval"
|
| 278 |
+
save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_1000complex")
|
| 279 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 280 |
+
print(f"Results will be saved to: {save_dir}")
|
| 281 |
+
|
| 282 |
+
for i, batch in enumerate(self.dataloader):
|
| 283 |
+
if i > 50: exit() # Visualize first 10
|
| 284 |
+
|
| 285 |
+
if self.is_cond and self.condition_type == 'image':
|
| 286 |
+
cond_ = self.process_batch(batch)
|
| 287 |
+
|
| 288 |
+
if cond_.shape[0] != 1:
|
| 289 |
+
cond_ = cond_.expand(batch['active_voxels_128'].shape[0], -1, -1).contiguous().to(self.device)
|
| 290 |
+
else:
|
| 291 |
+
cond_ = cond_.to(self.device)
|
| 292 |
+
|
| 293 |
+
# --- Data Retrieval (Matches collate_fn_pointnet) ---
|
| 294 |
+
point_cloud = batch['point_cloud_128'].to(self.device)
|
| 295 |
+
active_coords = batch['active_voxels_128'].to(self.device) # [N, 4]
|
| 296 |
+
|
| 297 |
+
with autocast(device_type='cuda', dtype=torch.bfloat16):
|
| 298 |
+
with torch.no_grad():
|
| 299 |
+
# 1. Encode Ground Truth Latents
|
| 300 |
+
active_voxel_feats = self.voxel_encoder(
|
| 301 |
+
p=point_cloud,
|
| 302 |
+
sparse_coords=active_coords,
|
| 303 |
+
res=128,
|
| 304 |
+
bbox_size=(-0.5, 0.5),
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
sparse_input = SparseTensor(
|
| 308 |
+
feats=active_voxel_feats,
|
| 309 |
+
coords=active_coords.int()
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Encode to get GT distribution
|
| 313 |
+
gt_latents, posterior = self.vae.encode(sparse_input)
|
| 314 |
+
|
| 315 |
+
print(f"Batch {i}: Active voxels: {active_coords.shape[0]}")
|
| 316 |
+
|
| 317 |
+
# 2. Generation / Sampling
|
| 318 |
+
# Generate noise on the SAME active coordinates
|
| 319 |
+
noise = SparseTensor_trellis(
|
| 320 |
+
coords=active_coords.int(),
|
| 321 |
+
feats=torch.randn(
|
| 322 |
+
active_coords.shape[0],
|
| 323 |
+
self.feature_dim,
|
| 324 |
+
device=self.device,
|
| 325 |
+
),
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
sample_results = self.sampler.sample(
|
| 329 |
+
model=self.denoiser.float(),
|
| 330 |
+
noise=noise.to(self.device).float(),
|
| 331 |
+
cond=cond_.to(self.device).float(),
|
| 332 |
+
steps=50,
|
| 333 |
+
rescale_t=1.0,
|
| 334 |
+
verbose=True,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
generated_sparse_tensor = sample_results.samples
|
| 338 |
+
generated_coords = generated_sparse_tensor.coords
|
| 339 |
+
generated_features = generated_sparse_tensor.feats
|
| 340 |
+
|
| 341 |
+
print('Gen features mean:', generated_features.mean().item(), 'std:', generated_features.std().item())
|
| 342 |
+
print('GT features mean:', gt_latents.feats.mean().item(), 'std:', gt_latents.feats.std().item())
|
| 343 |
+
print('MSE:', F.mse_loss(generated_features, gt_latents.feats).item())
|
| 344 |
+
|
| 345 |
+
# --- Visualization (PCA) ---
|
| 346 |
+
gt_feats_np = gt_latents.feats.detach().cpu().numpy()
|
| 347 |
+
gen_feats_np = generated_features.detach().cpu().numpy()
|
| 348 |
+
coords_np = active_coords[:, 1:4].detach().cpu().numpy() # x, y, z
|
| 349 |
+
|
| 350 |
+
print("Visualizing features using PCA...")
|
| 351 |
+
pca = PCA(n_components=3)
|
| 352 |
+
|
| 353 |
+
# Fit PCA on GT, transform both
|
| 354 |
+
pca.fit(gt_feats_np)
|
| 355 |
+
gt_feats_3d = pca.transform(gt_feats_np)
|
| 356 |
+
gen_feats_3d = pca.transform(gen_feats_np)
|
| 357 |
+
|
| 358 |
+
gt_colors = normalize_to_rgb(gt_feats_3d)
|
| 359 |
+
gen_colors = normalize_to_rgb(gen_feats_3d)
|
| 360 |
+
|
| 361 |
+
# Save PLYs
|
| 362 |
+
save_colored_ply(coords_np, gt_colors, os.path.join(save_dir, f"batch_{i}_gt_pca.ply"))
|
| 363 |
+
save_colored_ply(coords_np, gen_colors, os.path.join(save_dir, f"batch_{i}_gen_pca.ply"))
|
| 364 |
+
|
| 365 |
+
# Save Tensors for further analysis
|
| 366 |
+
torch.save(gt_latents, os.path.join(save_dir, f"gt_latent_{i}.pt"))
|
| 367 |
+
|
| 368 |
+
torch.save(batch, os.path.join(save_dir, f"gt_data_batch_{i}.pt"))
|
| 369 |
+
torch.save(sample_results.samples, os.path.join(save_dir, f"sample_latent_{i}.pt"))
|
| 370 |
+
|
| 371 |
+
if __name__ == '__main__':
|
| 372 |
+
torch.manual_seed(42)
|
| 373 |
+
config_path = "/gemini/user/private/zhaotianhao/Triposf/config_slat_flow_128to1024_pointnet_test.yaml"
|
| 374 |
+
with open(config_path) as f:
|
| 375 |
+
cfg = yaml.safe_load(f)
|
| 376 |
+
|
| 377 |
+
# Initialize Model on CPU first
|
| 378 |
+
diffusion_model = SLatFlowModel(
|
| 379 |
+
resolution=cfg['flow']['resolution'],
|
| 380 |
+
in_channels=cfg['flow']['in_channels'],
|
| 381 |
+
out_channels=cfg['flow']['out_channels'],
|
| 382 |
+
model_channels=cfg['flow']['model_channels'],
|
| 383 |
+
cond_channels=cfg['flow']['cond_channels'],
|
| 384 |
+
num_blocks=cfg['flow']['num_blocks'],
|
| 385 |
+
num_heads=cfg['flow']['num_heads'],
|
| 386 |
+
mlp_ratio=cfg['flow']['mlp_ratio'],
|
| 387 |
+
patch_size=cfg['flow']['patch_size'],
|
| 388 |
+
num_io_res_blocks=cfg['flow']['num_io_res_blocks'],
|
| 389 |
+
io_block_channels=cfg['flow']['io_block_channels'],
|
| 390 |
+
pe_mode=cfg['flow']['pe_mode'],
|
| 391 |
+
qk_rms_norm=cfg['flow']['qk_rms_norm'],
|
| 392 |
+
qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'],
|
| 393 |
+
use_fp16=cfg['flow'].get('use_fp16', False),
|
| 394 |
+
).to("cuda" if torch.cuda.is_available() else "cpu")
|
| 395 |
+
|
| 396 |
+
trainer = SLatFlowMatchingTrainer(
|
| 397 |
+
denoiser=diffusion_model,
|
| 398 |
+
t_schedule=cfg['t_schedule'],
|
| 399 |
+
sigma_min=cfg['sigma_min'],
|
| 400 |
+
cfg=cfg,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
trainer.eval()
|
test_slat_flow_128to256_pointnet.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import yaml
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from functools import partial
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.amp import GradScaler, autocast
|
| 12 |
+
from typing import *
|
| 13 |
+
from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig, Dinov2Model, AutoImageProcessor, Dinov2Config
|
| 14 |
+
import torch
|
| 15 |
+
import re
|
| 16 |
+
from utils import load_pretrained_woself
|
| 17 |
+
|
| 18 |
+
from dataset_triposf import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 19 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 20 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet
|
| 21 |
+
|
| 22 |
+
from trellis.models.structured_latent_flow import SLatFlowModel
|
| 23 |
+
from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
|
| 24 |
+
|
| 25 |
+
from trellis.pipelines.samplers import FlowEulerSampler
|
| 26 |
+
from safetensors.torch import load_file
|
| 27 |
+
import open3d as o3d
|
| 28 |
+
from PIL import Image
|
| 29 |
+
|
| 30 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 31 |
+
from trellis.modules.sparse.basic import SparseTensor as SparseTensor_trellis
|
| 32 |
+
|
| 33 |
+
from triposf.modules.utils import DiagonalGaussianDistribution
|
| 34 |
+
|
| 35 |
+
from sklearn.decomposition import PCA
|
| 36 |
+
import trimesh
|
| 37 |
+
import torchvision.transforms as transforms
|
| 38 |
+
|
| 39 |
+
# --- Helper Functions ---
|
| 40 |
+
def save_colored_ply(points, colors, filename):
|
| 41 |
+
if len(points) == 0:
|
| 42 |
+
print(f"[Warning] No points to save for {filename}")
|
| 43 |
+
return
|
| 44 |
+
# Ensure colors are uint8
|
| 45 |
+
if colors.max() <= 1.0:
|
| 46 |
+
colors = (colors * 255).astype(np.uint8)
|
| 47 |
+
colors = colors.astype(np.uint8)
|
| 48 |
+
|
| 49 |
+
# Add Alpha if missing
|
| 50 |
+
if colors.shape[1] == 3:
|
| 51 |
+
colors = np.hstack([colors, np.full((len(colors), 1), 255, dtype=np.uint8)])
|
| 52 |
+
|
| 53 |
+
cloud = trimesh.PointCloud(points, colors=colors)
|
| 54 |
+
cloud.export(filename)
|
| 55 |
+
print(f"Saved colored point cloud to {filename}")
|
| 56 |
+
|
| 57 |
+
def normalize_to_rgb(features_3d):
|
| 58 |
+
min_vals = features_3d.min(axis=0)
|
| 59 |
+
max_vals = features_3d.max(axis=0)
|
| 60 |
+
range_vals = max_vals - min_vals
|
| 61 |
+
range_vals[range_vals == 0] = 1
|
| 62 |
+
normalized = (features_3d - min_vals) / range_vals
|
| 63 |
+
return (normalized * 255).astype(np.uint8)
|
| 64 |
+
|
| 65 |
+
class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
|
| 66 |
+
def __init__(self, *args, **kwargs):
|
| 67 |
+
super().__init__(*args, **kwargs)
|
| 68 |
+
self.cfg = kwargs.pop('cfg', None)
|
| 69 |
+
if self.cfg is None:
|
| 70 |
+
raise ValueError("Configuration dictionary 'cfg' must be provided.")
|
| 71 |
+
|
| 72 |
+
self.sampler = FlowEulerSampler(sigma_min=1.e-5)
|
| 73 |
+
self.device = torch.device("cuda")
|
| 74 |
+
|
| 75 |
+
# Based on PointNet Encoder setting
|
| 76 |
+
self.resolution = 128
|
| 77 |
+
|
| 78 |
+
self.condition_type = 'image'
|
| 79 |
+
self.is_cond = False
|
| 80 |
+
|
| 81 |
+
self.img_res = 518
|
| 82 |
+
self.feature_dim = self.cfg['model']['latent_dim']
|
| 83 |
+
|
| 84 |
+
self._init_components(
|
| 85 |
+
clip_model_path=self.cfg['training'].get('clip_model_path', None),
|
| 86 |
+
dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None),
|
| 87 |
+
vae_path=self.cfg['training']['vae_path'],
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Classifier head removed as it is not part of the Active Voxel pipeline
|
| 91 |
+
|
| 92 |
+
def _load_denoiser(self, denoiser_checkpoint_path):
|
| 93 |
+
path = denoiser_checkpoint_path
|
| 94 |
+
if not path or not os.path.isfile(path):
|
| 95 |
+
print("No valid checkpoint path provided for fine-tuning. Starting from scratch.")
|
| 96 |
+
return
|
| 97 |
+
|
| 98 |
+
print(f"Loading checkpoint from: {path}")
|
| 99 |
+
checkpoint = torch.load(path, map_location=self.device)
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
denoiser_state_dict = checkpoint['denoiser']
|
| 103 |
+
# Handle DDP prefix
|
| 104 |
+
if next(iter(denoiser_state_dict)).startswith('module.'):
|
| 105 |
+
denoiser_state_dict = {k[7:]: v for k, v in denoiser_state_dict.items()}
|
| 106 |
+
|
| 107 |
+
self.denoiser.load_state_dict(denoiser_state_dict)
|
| 108 |
+
print("Denoiser weights loaded successfully.")
|
| 109 |
+
except KeyError:
|
| 110 |
+
print("[WARN] 'denoiser' key not found in checkpoint. Skipping.")
|
| 111 |
+
except Exception as e:
|
| 112 |
+
print(f"[ERROR] Failed to load denoiser state_dict: {e}")
|
| 113 |
+
|
| 114 |
+
def _init_components(self,
|
| 115 |
+
clip_model_path=None,
|
| 116 |
+
dinov2_model_path=None,
|
| 117 |
+
vae_path=None,
|
| 118 |
+
):
|
| 119 |
+
|
| 120 |
+
# 1. Initialize PointNet Voxel Encoder (Matches Training)
|
| 121 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 122 |
+
in_channels=15,
|
| 123 |
+
hidden_dim=256,
|
| 124 |
+
out_channels=1024,
|
| 125 |
+
scatter_type='mean',
|
| 126 |
+
n_blocks=5,
|
| 127 |
+
resolution=128,
|
| 128 |
+
add_label=False,
|
| 129 |
+
).to(self.device)
|
| 130 |
+
|
| 131 |
+
# 2. Initialize VAE
|
| 132 |
+
self.vae = VoxelVAE(
|
| 133 |
+
in_channels=self.cfg['model']['in_channels'],
|
| 134 |
+
latent_dim=self.cfg['model']['latent_dim'],
|
| 135 |
+
encoder_blocks=self.cfg['model']['encoder_blocks'],
|
| 136 |
+
decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
|
| 137 |
+
decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
|
| 138 |
+
num_heads=8,
|
| 139 |
+
num_head_channels=64,
|
| 140 |
+
mlp_ratio=4.0,
|
| 141 |
+
attn_mode="swin",
|
| 142 |
+
window_size=8,
|
| 143 |
+
pe_mode="ape",
|
| 144 |
+
use_fp16=False,
|
| 145 |
+
use_checkpoint=False,
|
| 146 |
+
qk_rms_norm=False,
|
| 147 |
+
using_subdivide=True,
|
| 148 |
+
using_attn=self.cfg['model']['using_attn'],
|
| 149 |
+
attn_first=self.cfg['model'].get('attn_first', True),
|
| 150 |
+
pred_direction=self.cfg['model'].get('pred_direction', False),
|
| 151 |
+
).to(self.device)
|
| 152 |
+
|
| 153 |
+
# 3. Initialize Dataset with collate_fn_pointnet
|
| 154 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 155 |
+
root_dir=self.cfg['dataset']['path'],
|
| 156 |
+
base_resolution=self.cfg['dataset']['base_resolution'],
|
| 157 |
+
min_resolution=self.cfg['dataset']['min_resolution'],
|
| 158 |
+
cache_dir=self.cfg['dataset']['cache_dir'],
|
| 159 |
+
renders_dir=self.cfg['dataset']['renders_dir'],
|
| 160 |
+
|
| 161 |
+
process_img=False,
|
| 162 |
+
|
| 163 |
+
active_voxel_res=128,
|
| 164 |
+
filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
|
| 165 |
+
cache_filter_path=self.cfg['dataset']['cache_filter_path'],
|
| 166 |
+
sample_type=self.cfg['dataset'].get('sample_type', 'dora'),
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
self.dataloader = DataLoader(
|
| 170 |
+
self.dataset,
|
| 171 |
+
batch_size=1,
|
| 172 |
+
shuffle=True,
|
| 173 |
+
collate_fn=partial(collate_fn_pointnet,), # Critical Change
|
| 174 |
+
num_workers=4,
|
| 175 |
+
pin_memory=True,
|
| 176 |
+
persistent_workers=True,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# 4. Load Pretrained Weights
|
| 180 |
+
# Assuming vae_path contains 'voxel_encoder' and 'vae'
|
| 181 |
+
print(f"Loading VAE/Encoder from {vae_path}")
|
| 182 |
+
ckpt = torch.load(vae_path, map_location='cpu')
|
| 183 |
+
|
| 184 |
+
# Load VAE
|
| 185 |
+
if 'vae' in ckpt:
|
| 186 |
+
self.vae.load_state_dict(ckpt['vae'], strict=False)
|
| 187 |
+
else:
|
| 188 |
+
self.vae.load_state_dict(ckpt) # Fallback
|
| 189 |
+
|
| 190 |
+
# Load Encoder
|
| 191 |
+
if 'voxel_encoder' in ckpt:
|
| 192 |
+
self.voxel_encoder.load_state_dict(ckpt['voxel_encoder'])
|
| 193 |
+
else:
|
| 194 |
+
print("[WARN] 'voxel_encoder' not found in checkpoint, random init (BAD for inference).")
|
| 195 |
+
|
| 196 |
+
self.voxel_encoder.eval()
|
| 197 |
+
self.vae.eval()
|
| 198 |
+
|
| 199 |
+
# 5. Initialize Conditioning Model
|
| 200 |
+
if self.condition_type == 'text':
|
| 201 |
+
self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
|
| 202 |
+
self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
|
| 203 |
+
elif self.condition_type == 'image':
|
| 204 |
+
model_name = 'dinov2_vitl14_reg'
|
| 205 |
+
local_repo_path = "/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main"
|
| 206 |
+
weights_path = "/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"
|
| 207 |
+
|
| 208 |
+
dinov2_model = torch.hub.load(
|
| 209 |
+
repo_or_dir=local_repo_path,
|
| 210 |
+
model=model_name,
|
| 211 |
+
source='local',
|
| 212 |
+
pretrained=False
|
| 213 |
+
)
|
| 214 |
+
self.condition_model = dinov2_model
|
| 215 |
+
self.condition_model.load_state_dict(torch.load(weights_path))
|
| 216 |
+
|
| 217 |
+
self.image_cond_model_transform = transforms.Compose([
|
| 218 |
+
transforms.ToTensor(),
|
| 219 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 220 |
+
])
|
| 221 |
+
else:
|
| 222 |
+
raise ValueError(f"Unsupported condition type: {self.condition_type}")
|
| 223 |
+
|
| 224 |
+
self.condition_model.to(self.device).eval()
|
| 225 |
+
|
| 226 |
+
@torch.no_grad()
|
| 227 |
+
def encode_image(self, images) -> torch.Tensor:
|
| 228 |
+
if isinstance(images, torch.Tensor):
|
| 229 |
+
batch_tensor = images.to(self.device)
|
| 230 |
+
elif isinstance(images, list):
|
| 231 |
+
assert all(isinstance(i, Image.Image) for i in images)
|
| 232 |
+
image = [i.resize((518, 518), Image.LANCZOS) for i in images]
|
| 233 |
+
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
|
| 234 |
+
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
|
| 235 |
+
batch_tensor = torch.stack(image).to(self.device)
|
| 236 |
+
else:
|
| 237 |
+
raise ValueError(f"Unsupported type of image: {type(images)}")
|
| 238 |
+
|
| 239 |
+
if batch_tensor.shape[-2:] != (518, 518):
|
| 240 |
+
batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False)
|
| 241 |
+
|
| 242 |
+
features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
|
| 243 |
+
patchtokens = F.layer_norm(features, features.shape[-1:])
|
| 244 |
+
return patchtokens
|
| 245 |
+
|
| 246 |
+
def process_batch(self, batch):
|
| 247 |
+
preprocessed_images = batch['image']
|
| 248 |
+
cond_ = self.encode_image(preprocessed_images)
|
| 249 |
+
return cond_
|
| 250 |
+
|
| 251 |
+
def eval(self):
|
| 252 |
+
# Unconditional Setup
|
| 253 |
+
if self.is_cond == False:
|
| 254 |
+
if self.condition_type == 'text':
|
| 255 |
+
txt = ['']
|
| 256 |
+
encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
|
| 257 |
+
tokens = encoding['input_ids'].to(self.device)
|
| 258 |
+
with torch.no_grad():
|
| 259 |
+
cond_ = self.condition_model(input_ids=tokens).last_hidden_state
|
| 260 |
+
else:
|
| 261 |
+
blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
|
| 262 |
+
with torch.no_grad():
|
| 263 |
+
dummy_cond = self.encode_image([blank_img])
|
| 264 |
+
cond_ = torch.zeros_like(dummy_cond)
|
| 265 |
+
print(f"Generated unconditional image prompt (zero tensor) with shape: {cond_.shape}")
|
| 266 |
+
|
| 267 |
+
self.denoiser.eval()
|
| 268 |
+
|
| 269 |
+
# Load Denoiser Checkpoint
|
| 270 |
+
# Update this path to your ACTIVE VOXEL trained checkpoint
|
| 271 |
+
checkpoint_path = '/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/checkpoint_step215000_loss0_332666.pt'
|
| 272 |
+
|
| 273 |
+
self._load_denoiser(checkpoint_path)
|
| 274 |
+
|
| 275 |
+
filename = os.path.basename(checkpoint_path)
|
| 276 |
+
match = re.search(r'step(\d+)', filename)
|
| 277 |
+
step_str = match.group(1) if match else "eval"
|
| 278 |
+
save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_1000complex")
|
| 279 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 280 |
+
print(f"Results will be saved to: {save_dir}")
|
| 281 |
+
|
| 282 |
+
for i, batch in enumerate(self.dataloader):
|
| 283 |
+
if i > 50: exit() # Visualize first 10
|
| 284 |
+
|
| 285 |
+
if self.is_cond and self.condition_type == 'image':
|
| 286 |
+
cond_ = self.process_batch(batch)
|
| 287 |
+
|
| 288 |
+
if cond_.shape[0] != 1:
|
| 289 |
+
cond_ = cond_.expand(batch['active_voxels_128'].shape[0], -1, -1).contiguous().to(self.device)
|
| 290 |
+
else:
|
| 291 |
+
cond_ = cond_.to(self.device)
|
| 292 |
+
|
| 293 |
+
# --- Data Retrieval (Matches collate_fn_pointnet) ---
|
| 294 |
+
point_cloud = batch['point_cloud_128'].to(self.device)
|
| 295 |
+
active_coords = batch['active_voxels_128'].to(self.device) # [N, 4]
|
| 296 |
+
|
| 297 |
+
with autocast(device_type='cuda', dtype=torch.bfloat16):
|
| 298 |
+
with torch.no_grad():
|
| 299 |
+
# 1. Encode Ground Truth Latents
|
| 300 |
+
active_voxel_feats = self.voxel_encoder(
|
| 301 |
+
p=point_cloud,
|
| 302 |
+
sparse_coords=active_coords,
|
| 303 |
+
res=128,
|
| 304 |
+
bbox_size=(-0.5, 0.5),
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
sparse_input = SparseTensor(
|
| 308 |
+
feats=active_voxel_feats,
|
| 309 |
+
coords=active_coords.int()
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Encode to get GT distribution
|
| 313 |
+
gt_latents, posterior = self.vae.encode(sparse_input)
|
| 314 |
+
|
| 315 |
+
print(f"Batch {i}: Active voxels: {active_coords.shape[0]}")
|
| 316 |
+
|
| 317 |
+
# 2. Generation / Sampling
|
| 318 |
+
# Generate noise on the SAME active coordinates
|
| 319 |
+
noise = SparseTensor_trellis(
|
| 320 |
+
coords=active_coords.int(),
|
| 321 |
+
feats=torch.randn(
|
| 322 |
+
active_coords.shape[0],
|
| 323 |
+
self.feature_dim,
|
| 324 |
+
device=self.device,
|
| 325 |
+
),
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
sample_results = self.sampler.sample(
|
| 329 |
+
model=self.denoiser.float(),
|
| 330 |
+
noise=noise.to(self.device).float(),
|
| 331 |
+
cond=cond_.to(self.device).float(),
|
| 332 |
+
steps=50,
|
| 333 |
+
rescale_t=1.0,
|
| 334 |
+
verbose=True,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
generated_sparse_tensor = sample_results.samples
|
| 338 |
+
generated_coords = generated_sparse_tensor.coords
|
| 339 |
+
generated_features = generated_sparse_tensor.feats
|
| 340 |
+
|
| 341 |
+
print('Gen features mean:', generated_features.mean().item(), 'std:', generated_features.std().item())
|
| 342 |
+
print('GT features mean:', gt_latents.feats.mean().item(), 'std:', gt_latents.feats.std().item())
|
| 343 |
+
print('MSE:', F.mse_loss(generated_features, gt_latents.feats).item())
|
| 344 |
+
|
| 345 |
+
# --- Visualization (PCA) ---
|
| 346 |
+
gt_feats_np = gt_latents.feats.detach().cpu().numpy()
|
| 347 |
+
gen_feats_np = generated_features.detach().cpu().numpy()
|
| 348 |
+
coords_np = active_coords[:, 1:4].detach().cpu().numpy() # x, y, z
|
| 349 |
+
|
| 350 |
+
print("Visualizing features using PCA...")
|
| 351 |
+
pca = PCA(n_components=3)
|
| 352 |
+
|
| 353 |
+
# Fit PCA on GT, transform both
|
| 354 |
+
pca.fit(gt_feats_np)
|
| 355 |
+
gt_feats_3d = pca.transform(gt_feats_np)
|
| 356 |
+
gen_feats_3d = pca.transform(gen_feats_np)
|
| 357 |
+
|
| 358 |
+
gt_colors = normalize_to_rgb(gt_feats_3d)
|
| 359 |
+
gen_colors = normalize_to_rgb(gen_feats_3d)
|
| 360 |
+
|
| 361 |
+
# Save PLYs
|
| 362 |
+
save_colored_ply(coords_np, gt_colors, os.path.join(save_dir, f"batch_{i}_gt_pca.ply"))
|
| 363 |
+
save_colored_ply(coords_np, gen_colors, os.path.join(save_dir, f"batch_{i}_gen_pca.ply"))
|
| 364 |
+
|
| 365 |
+
# Save Tensors for further analysis
|
| 366 |
+
torch.save(gt_latents, os.path.join(save_dir, f"gt_latent_{i}.pt"))
|
| 367 |
+
|
| 368 |
+
torch.save(batch, os.path.join(save_dir, f"gt_data_batch_{i}.pt"))
|
| 369 |
+
torch.save(sample_results.samples, os.path.join(save_dir, f"sample_latent_{i}.pt"))
|
| 370 |
+
|
| 371 |
+
if __name__ == '__main__':
|
| 372 |
+
torch.manual_seed(42)
|
| 373 |
+
config_path = "/gemini/user/private/zhaotianhao/Triposf/config_slat_flow_128to256_pointnet_test.yaml"
|
| 374 |
+
with open(config_path) as f:
|
| 375 |
+
cfg = yaml.safe_load(f)
|
| 376 |
+
|
| 377 |
+
# Initialize Model on CPU first
|
| 378 |
+
diffusion_model = SLatFlowModel(
|
| 379 |
+
resolution=cfg['flow']['resolution'],
|
| 380 |
+
in_channels=cfg['flow']['in_channels'],
|
| 381 |
+
out_channels=cfg['flow']['out_channels'],
|
| 382 |
+
model_channels=cfg['flow']['model_channels'],
|
| 383 |
+
cond_channels=cfg['flow']['cond_channels'],
|
| 384 |
+
num_blocks=cfg['flow']['num_blocks'],
|
| 385 |
+
num_heads=cfg['flow']['num_heads'],
|
| 386 |
+
mlp_ratio=cfg['flow']['mlp_ratio'],
|
| 387 |
+
patch_size=cfg['flow']['patch_size'],
|
| 388 |
+
num_io_res_blocks=cfg['flow']['num_io_res_blocks'],
|
| 389 |
+
io_block_channels=cfg['flow']['io_block_channels'],
|
| 390 |
+
pe_mode=cfg['flow']['pe_mode'],
|
| 391 |
+
qk_rms_norm=cfg['flow']['qk_rms_norm'],
|
| 392 |
+
qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'],
|
| 393 |
+
use_fp16=cfg['flow'].get('use_fp16', False),
|
| 394 |
+
).to("cuda" if torch.cuda.is_available() else "cpu")
|
| 395 |
+
|
| 396 |
+
trainer = SLatFlowMatchingTrainer(
|
| 397 |
+
denoiser=diffusion_model,
|
| 398 |
+
t_schedule=cfg['t_schedule'],
|
| 399 |
+
sigma_min=cfg['sigma_min'],
|
| 400 |
+
cfg=cfg,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
trainer.eval()
|
test_slat_vae_128to1024_pointnet.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test_slat_vae_128to1024_pointnet_vae.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test_slat_vae_128to1024_pointnet_vae_addhead.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test_slat_vae_128to1024_pointnet_vae_head.py
ADDED
|
@@ -0,0 +1,1339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
import plotly.graph_objects as go
|
| 9 |
+
from plotly.subplots import make_subplots
|
| 10 |
+
from torch.utils.data import DataLoader, Subset
|
| 11 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 12 |
+
|
| 13 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 14 |
+
|
| 15 |
+
from vertex_encoder import VoxelFeatureEncoder_edge, VoxelFeatureEncoder_vtx, VoxelFeatureEncoder_active, VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 16 |
+
from utils import load_pretrained_woself
|
| 17 |
+
|
| 18 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
from functools import partial
|
| 22 |
+
import itertools
|
| 23 |
+
from typing import List, Tuple, Set
|
| 24 |
+
from collections import OrderedDict
|
| 25 |
+
from scipy.spatial import cKDTree
|
| 26 |
+
from sklearn.neighbors import KDTree
|
| 27 |
+
|
| 28 |
+
import trimesh
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
import time
|
| 33 |
+
|
| 34 |
+
from sklearn.decomposition import PCA
|
| 35 |
+
import matplotlib.pyplot as plt
|
| 36 |
+
|
| 37 |
+
import networkx as nx
|
| 38 |
+
|
| 39 |
+
def predict_mesh_connectivity(
|
| 40 |
+
connection_head,
|
| 41 |
+
vtx_feats,
|
| 42 |
+
vtx_coords,
|
| 43 |
+
batch_size=10000,
|
| 44 |
+
threshold=0.5,
|
| 45 |
+
k_neighbors=64, # 限制每个点只检测最近的 K 个邻居,设为 -1 则全连接检测
|
| 46 |
+
device='cuda'
|
| 47 |
+
):
|
| 48 |
+
"""
|
| 49 |
+
Args:
|
| 50 |
+
connection_head: 训练好的 MLP 模型
|
| 51 |
+
vtx_feats: [N, C] 顶点特征
|
| 52 |
+
vtx_coords: [N, 3] 顶点坐标 (用于 KNN 筛选候选边)
|
| 53 |
+
batch_size: MLP 推理的 batch size
|
| 54 |
+
threshold: 判定连接的概率阈值
|
| 55 |
+
k_neighbors: K-NN 数量。如果是 None 或 -1,则检测所有 N*(N-1)/2 对。
|
| 56 |
+
"""
|
| 57 |
+
num_verts = vtx_feats.shape[0]
|
| 58 |
+
if num_verts < 3:
|
| 59 |
+
return [], [] # 无法构成三角形
|
| 60 |
+
|
| 61 |
+
connection_head.eval()
|
| 62 |
+
|
| 63 |
+
# --- 1. 生成候选边 (Candidate Edges) ---
|
| 64 |
+
if k_neighbors is not None and k_neighbors > 0 and k_neighbors < num_verts:
|
| 65 |
+
# 策略 A: 局部 KNN (推荐)
|
| 66 |
+
# 计算距离矩阵可能会 OOM,使用分块或 KDTree/Faiss,这里用 PyTorch 的 cdist 分块简化版
|
| 67 |
+
# 或者直接暴力 cdist 如果 N < 10000
|
| 68 |
+
|
| 69 |
+
# 为了简单且高效,这里演示简单的 cdist (注意显存)
|
| 70 |
+
# 如果 N 很大 (>5000),建议使用 faiss 或 scipy.spatial.cKDTree
|
| 71 |
+
dist_mat = torch.cdist(vtx_coords.float(), vtx_coords.float()) # [N, N]
|
| 72 |
+
|
| 73 |
+
# 取 topk (smallest distance),排除自己
|
| 74 |
+
# values: [N, K], indices: [N, K]
|
| 75 |
+
_, indices = torch.topk(dist_mat, k=k_neighbors + 1, dim=1, largest=False)
|
| 76 |
+
neighbor_indices = indices[:, 1:] # 去掉第一列(自己)
|
| 77 |
+
|
| 78 |
+
# 构建 source, target 索引
|
| 79 |
+
src = torch.arange(num_verts, device=device).unsqueeze(1).repeat(1, k_neighbors).flatten()
|
| 80 |
+
dst = neighbor_indices.flatten()
|
| 81 |
+
|
| 82 |
+
# 此时得到的边是双向的 (u->v 和 v->u 可能都存在),为了效率可以去重
|
| 83 |
+
# 但为了利用你的 symmetric MLP,保留双向或者只保留 u < v 均可
|
| 84 |
+
# 这里为了简单,我们生成 u < v 的 mask
|
| 85 |
+
mask = src < dst
|
| 86 |
+
u_indices = src[mask]
|
| 87 |
+
v_indices = dst[mask]
|
| 88 |
+
|
| 89 |
+
else:
|
| 90 |
+
# 策略 B: 全连接 (O(N^2)) - 仅当 N 较小时使用
|
| 91 |
+
u_indices, v_indices = torch.triu_indices(num_verts, num_verts, offset=1, device=device)
|
| 92 |
+
|
| 93 |
+
# --- 2. 批量推理 ---
|
| 94 |
+
all_probs = []
|
| 95 |
+
num_candidates = u_indices.shape[0]
|
| 96 |
+
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
for i in range(0, num_candidates, batch_size):
|
| 99 |
+
end = min(i + batch_size, num_candidates)
|
| 100 |
+
batch_u = u_indices[i:end]
|
| 101 |
+
batch_v = v_indices[i:end]
|
| 102 |
+
|
| 103 |
+
feat_u = vtx_feats[batch_u]
|
| 104 |
+
feat_v = vtx_feats[batch_v]
|
| 105 |
+
|
| 106 |
+
# Symmetric Forward (和你训练时保持一致)
|
| 107 |
+
# A -> B
|
| 108 |
+
input_uv = torch.cat([feat_u, feat_v], dim=-1)
|
| 109 |
+
logits_uv = connection_head(input_uv)
|
| 110 |
+
|
| 111 |
+
# B -> A
|
| 112 |
+
input_vu = torch.cat([feat_v, feat_u], dim=-1)
|
| 113 |
+
logits_vu = connection_head(input_vu)
|
| 114 |
+
|
| 115 |
+
# Sum logits
|
| 116 |
+
logits = (logits_uv + logits_vu) / 2.
|
| 117 |
+
probs = torch.sigmoid(logits)
|
| 118 |
+
all_probs.append(probs)
|
| 119 |
+
|
| 120 |
+
all_probs = torch.cat(all_probs).squeeze() # [M]
|
| 121 |
+
|
| 122 |
+
# --- 3. 筛选连接边 ---
|
| 123 |
+
connected_mask = all_probs > threshold
|
| 124 |
+
final_u = u_indices[connected_mask].cpu().numpy()
|
| 125 |
+
final_v = v_indices[connected_mask].cpu().numpy()
|
| 126 |
+
|
| 127 |
+
edges = np.stack([final_u, final_v], axis=1) # [E, 2]
|
| 128 |
+
|
| 129 |
+
return edges
|
| 130 |
+
|
| 131 |
+
def build_triangles_from_edges(edges, num_verts):
|
| 132 |
+
"""
|
| 133 |
+
从边列表构建三角形。
|
| 134 |
+
寻找图中所有的 3-Cliques (三元环)。
|
| 135 |
+
这在图论中是一个经典问题,可以使用 networkx 库。
|
| 136 |
+
"""
|
| 137 |
+
if len(edges) == 0:
|
| 138 |
+
return np.empty((0, 3), dtype=int)
|
| 139 |
+
|
| 140 |
+
G = nx.Graph()
|
| 141 |
+
G.add_nodes_from(range(num_verts))
|
| 142 |
+
G.add_edges_from(edges)
|
| 143 |
+
|
| 144 |
+
# 寻找所有的 3-cliques (三角形)
|
| 145 |
+
# enumerate_all_cliques 返回所有大小的 clique,我们需要过滤大小为 3 的
|
| 146 |
+
# 或者使用 nx.triangles ? 不,那个只返回数量
|
| 147 |
+
# 使用 nx.enumerate_all_cliques 效率可能较低,对于稀疏图还可以
|
| 148 |
+
|
| 149 |
+
# 更快的方法:迭代每条边 (u, v),查找 u 和 v 的公共邻居 w
|
| 150 |
+
triangles = []
|
| 151 |
+
adj = [set(G.neighbors(n)) for n in range(num_verts)]
|
| 152 |
+
|
| 153 |
+
# 为了避免重复 (u, v, w), (v, w, u)... 我们可以强制 u < v < w
|
| 154 |
+
# 既然 edges 已经是 u < v (如果我们之前做了 triu),则只需要找 w > v 且 w in adj[u]
|
| 155 |
+
|
| 156 |
+
# 优化算法:
|
| 157 |
+
for u, v in edges:
|
| 158 |
+
if u > v: u, v = v, u # 确保有序
|
| 159 |
+
|
| 160 |
+
# 找公共邻居
|
| 161 |
+
common = adj[u].intersection(adj[v])
|
| 162 |
+
for w in common:
|
| 163 |
+
if w > v: # 强制顺序 u < v < w 防止重复
|
| 164 |
+
triangles.append([u, v, w])
|
| 165 |
+
|
| 166 |
+
return np.array(triangles)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def downsample_voxels(
|
| 170 |
+
voxels: torch.Tensor,
|
| 171 |
+
input_resolution: int,
|
| 172 |
+
output_resolution: int
|
| 173 |
+
) -> torch.Tensor:
|
| 174 |
+
if input_resolution % output_resolution != 0:
|
| 175 |
+
raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
|
| 176 |
+
f"by output_resolution ({output_resolution}).")
|
| 177 |
+
|
| 178 |
+
factor = input_resolution // output_resolution
|
| 179 |
+
|
| 180 |
+
downsampled_voxels = voxels.clone().to(torch.long)
|
| 181 |
+
|
| 182 |
+
downsampled_voxels[:, 1:] //= factor
|
| 183 |
+
|
| 184 |
+
unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
|
| 185 |
+
return unique_downsampled_voxels
|
| 186 |
+
|
| 187 |
+
def visualize_colored_points_ply(coords, vectors, filename):
|
| 188 |
+
"""
|
| 189 |
+
可视化点云,并用向量方向的颜色来表示,保存为 PLY 文件。
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
coords (torch.Tensor or np.ndarray): 3D坐标,形状为 (N, 3)。
|
| 193 |
+
vectors (torch.Tensor or np.ndarray): 方向向量,形状为 (N, 3)。
|
| 194 |
+
filename (str): 保存输出文件的名称,必须是 .ply 格式。
|
| 195 |
+
"""
|
| 196 |
+
# 确保输入是 numpy 数组
|
| 197 |
+
if isinstance(coords, torch.Tensor):
|
| 198 |
+
coords = coords.detach().cpu().numpy()
|
| 199 |
+
if isinstance(vectors, torch.Tensor):
|
| 200 |
+
vectors = vectors.detach().cpu().to(torch.float32).numpy()
|
| 201 |
+
|
| 202 |
+
# 检查输入数据是否为空,防止崩溃
|
| 203 |
+
if coords.size == 0 or vectors.size == 0:
|
| 204 |
+
print(f"警告:输入数据为空,未生成 {filename} 文件。")
|
| 205 |
+
return
|
| 206 |
+
|
| 207 |
+
# 将向量分量从 [-1, 1] 映射到 [0, 255]
|
| 208 |
+
# np.clip 用于将数值限制在 -1 和 1 之间,防止颜色溢出
|
| 209 |
+
# (vectors + 1) 将范围从 [-1, 1] 移动到 [0, 2]
|
| 210 |
+
# * 127.5 将范围从 [0, 2] 缩放到 [0, 255]
|
| 211 |
+
colors = np.clip((vectors + 1) * 127.5, 0, 255).astype(np.uint8)
|
| 212 |
+
|
| 213 |
+
# 创建一个点云对象,并传入颜色信息
|
| 214 |
+
# trimesh.PointCloud 能够自动处理带颜色的点
|
| 215 |
+
points = trimesh.points.PointCloud(coords, colors=colors)
|
| 216 |
+
# 导出为 PLY 文件
|
| 217 |
+
points.export(filename, file_type='ply')
|
| 218 |
+
print(f"可视化文件已成功保存为: {filename}")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def compute_vertex_matching(pred_coords, gt_coords, threshold=1.0):
|
| 222 |
+
"""
|
| 223 |
+
使用 KDTree 最近邻 + 贪心匹配算法计算顶点匹配 (欧式距离)
|
| 224 |
+
|
| 225 |
+
参数:
|
| 226 |
+
pred_coords: 预测坐标 (Tensor)
|
| 227 |
+
gt_coords: 真实坐标 (Tensor)
|
| 228 |
+
threshold: 匹配误差阈值 (默认1.0)
|
| 229 |
+
|
| 230 |
+
返回:
|
| 231 |
+
matches: 匹配成功的顶点数量
|
| 232 |
+
match_rate: 匹配率 (基于真实顶点数)
|
| 233 |
+
pred_total: 预测顶点总数
|
| 234 |
+
gt_total: 真实顶点总数
|
| 235 |
+
"""
|
| 236 |
+
# 转换为整数坐标并去重
|
| 237 |
+
print('len(pred_coords)', len(pred_coords))
|
| 238 |
+
|
| 239 |
+
pred_array = np.unique(pred_coords.detach().to(torch.float32).cpu().numpy(), axis=0)
|
| 240 |
+
gt_array = np.unique(gt_coords.detach().cpu().to(torch.float32).numpy(), axis=0)
|
| 241 |
+
print('len(pred_array)', len(pred_array))
|
| 242 |
+
pred_total = len(pred_array)
|
| 243 |
+
gt_total = len(gt_array)
|
| 244 |
+
|
| 245 |
+
# 如果没有点,直接返回
|
| 246 |
+
if pred_total == 0 or gt_total == 0:
|
| 247 |
+
return 0, 0.0, pred_total, gt_total
|
| 248 |
+
|
| 249 |
+
# 建立 KDTree(以 gt 为基准)
|
| 250 |
+
tree = KDTree(gt_array)
|
| 251 |
+
|
| 252 |
+
# 查找预测点到最近的 gt 点
|
| 253 |
+
dist, indices = tree.query(pred_array, k=1)
|
| 254 |
+
dist = dist.squeeze()
|
| 255 |
+
indices = indices.squeeze()
|
| 256 |
+
|
| 257 |
+
# 贪心去重:确保 1 对 1
|
| 258 |
+
matches = 0
|
| 259 |
+
used_gt = set()
|
| 260 |
+
for d, idx in zip(dist, indices):
|
| 261 |
+
if d <= threshold and idx not in used_gt:
|
| 262 |
+
matches += 1
|
| 263 |
+
used_gt.add(idx)
|
| 264 |
+
|
| 265 |
+
match_rate = matches / max(gt_total, 1)
|
| 266 |
+
|
| 267 |
+
return matches, match_rate, pred_total, gt_total
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def flatten_coords_3d(coords_3d: torch.Tensor):
|
| 271 |
+
coords_3d_long = coords_3d #.long()
|
| 272 |
+
|
| 273 |
+
base_x = 1024
|
| 274 |
+
base_y = 1024 * 1024
|
| 275 |
+
base_z = 1024 * 1024 * 1024
|
| 276 |
+
|
| 277 |
+
flat_coords = coords_3d_long[:, 0] * base_z + \
|
| 278 |
+
coords_3d_long[:, 1] * base_y + \
|
| 279 |
+
coords_3d_long[:, 2] * base_x
|
| 280 |
+
return flat_coords
|
| 281 |
+
|
| 282 |
+
class Tester:
|
| 283 |
+
def __init__(self, ckpt_path, config_path=None, dataset_path=None):
|
| 284 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 285 |
+
self.ckpt_path = ckpt_path
|
| 286 |
+
|
| 287 |
+
self.config = self._load_config(config_path)
|
| 288 |
+
self.dataset_path = dataset_path # or self.config['dataset']['path']
|
| 289 |
+
checkpoint = torch.load(self.ckpt_path, map_location='cpu')
|
| 290 |
+
self.epoch = checkpoint.get('epoch', 0)
|
| 291 |
+
|
| 292 |
+
self._init_models()
|
| 293 |
+
self._init_dataset()
|
| 294 |
+
|
| 295 |
+
self.result_dir = os.path.join(os.path.dirname(ckpt_path), "evaluation_results")
|
| 296 |
+
os.makedirs(self.result_dir, exist_ok=True)
|
| 297 |
+
|
| 298 |
+
dataset_name_clean = os.path.basename(self.dataset_path).replace('.npz', '').replace('.npy', '')
|
| 299 |
+
self.output_voxel_dir = os.path.join(os.path.dirname(ckpt_path),
|
| 300 |
+
f"epoch_{self.epoch}_{dataset_name_clean}_voxels_0_gs")
|
| 301 |
+
os.makedirs(self.output_voxel_dir, exist_ok=True)
|
| 302 |
+
|
| 303 |
+
self.output_obj_dir = os.path.join(os.path.dirname(ckpt_path),
|
| 304 |
+
f"epoch_{self.epoch}_{dataset_name_clean}_obj_0_gs")
|
| 305 |
+
os.makedirs(self.output_obj_dir, exist_ok=True)
|
| 306 |
+
|
| 307 |
+
def _save_voxel_ply(self, coords: torch.Tensor, labels: torch.Tensor, filename: str):
|
| 308 |
+
if coords.numel() == 0:
|
| 309 |
+
return
|
| 310 |
+
|
| 311 |
+
coords_np = coords.cpu().to(torch.float32).numpy()
|
| 312 |
+
labels_np = labels.cpu().to(torch.float32).numpy()
|
| 313 |
+
|
| 314 |
+
colors = np.zeros((coords_np.shape[0], 3), dtype=np.uint8)
|
| 315 |
+
colors[labels_np == 0] = [255, 0, 0]
|
| 316 |
+
colors[labels_np == 1] = [0, 0, 255]
|
| 317 |
+
|
| 318 |
+
try:
|
| 319 |
+
import trimesh
|
| 320 |
+
point_cloud = trimesh.PointCloud(vertices=coords_np, colors=colors)
|
| 321 |
+
ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
|
| 322 |
+
point_cloud.export(ply_path)
|
| 323 |
+
except ImportError:
|
| 324 |
+
ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
|
| 325 |
+
with open(ply_path, 'w') as f:
|
| 326 |
+
f.write("ply\n")
|
| 327 |
+
f.write("format ascii 1.0\n")
|
| 328 |
+
f.write(f"element vertex {coords_np.shape[0]}\n")
|
| 329 |
+
f.write("property float x\n")
|
| 330 |
+
f.write("property float y\n")
|
| 331 |
+
f.write("property float z\n")
|
| 332 |
+
f.write("property uchar red\n")
|
| 333 |
+
f.write("property uchar green\n")
|
| 334 |
+
f.write("property uchar blue\n")
|
| 335 |
+
f.write("end_header\n")
|
| 336 |
+
for i in range(coords_np.shape[0]):
|
| 337 |
+
f.write(f"{coords_np[i,0]} {coords_np[i,1]} {coords_np[i,2]} {colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
|
| 338 |
+
|
| 339 |
+
def _load_config(self, config_path=None):
|
| 340 |
+
if config_path and os.path.exists(config_path):
|
| 341 |
+
with open(config_path) as f:
|
| 342 |
+
return yaml.safe_load(f)
|
| 343 |
+
|
| 344 |
+
ckpt_dir = os.path.dirname(self.ckpt_path)
|
| 345 |
+
possible_configs = [
|
| 346 |
+
os.path.join(ckpt_dir, "config.yaml"),
|
| 347 |
+
os.path.join(os.path.dirname(ckpt_dir), "config.yaml")
|
| 348 |
+
]
|
| 349 |
+
|
| 350 |
+
for config_file in possible_configs:
|
| 351 |
+
if os.path.exists(config_file):
|
| 352 |
+
with open(config_file) as f:
|
| 353 |
+
print(f"Loaded config from: {config_file}")
|
| 354 |
+
return yaml.safe_load(f)
|
| 355 |
+
|
| 356 |
+
checkpoint = torch.load(self.ckpt_path, map_location='cpu')
|
| 357 |
+
if 'config' in checkpoint:
|
| 358 |
+
print("Loaded config from checkpoint")
|
| 359 |
+
return checkpoint['config']
|
| 360 |
+
|
| 361 |
+
raise FileNotFoundError("Could not find config_edge.yaml in checkpoint directory or parent, and config not saved in checkpoint.")
|
| 362 |
+
|
| 363 |
+
def _init_models(self):
|
| 364 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 365 |
+
in_channels=15,
|
| 366 |
+
hidden_dim=256,
|
| 367 |
+
out_channels=1024,
|
| 368 |
+
scatter_type='mean',
|
| 369 |
+
n_blocks=5,
|
| 370 |
+
resolution=128,
|
| 371 |
+
|
| 372 |
+
).to(self.device)
|
| 373 |
+
|
| 374 |
+
self.connection_head = ConnectionHead(
|
| 375 |
+
channels=32 * 2,
|
| 376 |
+
out_channels=1,
|
| 377 |
+
mlp_ratio=16,
|
| 378 |
+
).to(self.device)
|
| 379 |
+
|
| 380 |
+
self.vae = VoxelVAE( # abalation: VoxelVAE_1volume_dilation
|
| 381 |
+
in_channels=self.config['model']['in_channels'],
|
| 382 |
+
latent_dim=self.config['model']['latent_dim'],
|
| 383 |
+
encoder_blocks=self.config['model']['encoder_blocks'],
|
| 384 |
+
# decoder_blocks=self.config['model']['decoder_blocks'],
|
| 385 |
+
decoder_blocks_vtx=self.config['model']['decoder_blocks_vtx'],
|
| 386 |
+
decoder_blocks_edge=self.config['model']['decoder_blocks_edge'],
|
| 387 |
+
num_heads=8,
|
| 388 |
+
num_head_channels=64,
|
| 389 |
+
mlp_ratio=4.0,
|
| 390 |
+
attn_mode="swin",
|
| 391 |
+
window_size=8,
|
| 392 |
+
pe_mode="ape",
|
| 393 |
+
use_fp16=False,
|
| 394 |
+
use_checkpoint=False,
|
| 395 |
+
qk_rms_norm=False,
|
| 396 |
+
using_subdivide=True,
|
| 397 |
+
using_attn=self.config['model']['using_attn'],
|
| 398 |
+
attn_first=self.config['model'].get('attn_first', True),
|
| 399 |
+
pred_direction=self.config['model'].get('pred_direction', False),
|
| 400 |
+
).to(self.device)
|
| 401 |
+
|
| 402 |
+
load_pretrained_woself(
|
| 403 |
+
checkpoint_path=self.ckpt_path,
|
| 404 |
+
voxel_encoder=self.voxel_encoder,
|
| 405 |
+
connection_head=self.connection_head,
|
| 406 |
+
vae=self.vae,
|
| 407 |
+
)
|
| 408 |
+
# --- 【新增】在这里添加权重检查逻辑 ---
|
| 409 |
+
print(f"--- 正在检查权重文件中的 NaN/Inf 值... ---")
|
| 410 |
+
has_nan_inf = False
|
| 411 |
+
if self._check_weights_for_nan_inf(self.vae, "VoxelVAE"):
|
| 412 |
+
has_nan_inf = True
|
| 413 |
+
|
| 414 |
+
if self._check_weights_for_nan_inf(self.voxel_encoder, "Vertex Encoder"):
|
| 415 |
+
has_nan_inf = True
|
| 416 |
+
|
| 417 |
+
if self._check_weights_for_nan_inf(self.connection_head, "Connection Head"):
|
| 418 |
+
has_nan_inf = True
|
| 419 |
+
|
| 420 |
+
if not has_nan_inf:
|
| 421 |
+
print("--- 权重检查通过。未发现 NaN/Inf 值。 ---")
|
| 422 |
+
else:
|
| 423 |
+
# 如果发现坏值,直接抛出异常,因为评估无法继续
|
| 424 |
+
raise ValueError(f"在检查点 '{self.ckpt_path}' 中发现了 NaN 或 Inf 值。请检查导致训练不稳定的权重文件。")
|
| 425 |
+
# --- 检查逻辑结束 ---
|
| 426 |
+
|
| 427 |
+
self.vae.eval()
|
| 428 |
+
self.voxel_encoder.eval()
|
| 429 |
+
self.connection_head.eval()
|
| 430 |
+
|
| 431 |
+
def _init_dataset(self):
|
| 432 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 433 |
+
root_dir=self.dataset_path,
|
| 434 |
+
base_resolution=self.config['dataset']['base_resolution'],
|
| 435 |
+
min_resolution=self.config['dataset']['min_resolution'],
|
| 436 |
+
cache_dir='/gemini/user/private/zhaotianhao/dataset_cache/test_15c_dora_128to1024',
|
| 437 |
+
# cache_dir=self.config['dataset']['cache_dir'],
|
| 438 |
+
renders_dir=self.config['dataset']['renders_dir'],
|
| 439 |
+
|
| 440 |
+
# filter_active_voxels=self.config['dataset']['filter_active_voxels'],
|
| 441 |
+
filter_active_voxels=False,
|
| 442 |
+
cache_filter_path=self.config['dataset']['cache_filter_path'],
|
| 443 |
+
|
| 444 |
+
sample_type=self.config['dataset']['sample_type'],
|
| 445 |
+
active_voxel_res=128,
|
| 446 |
+
pc_sample_number=819200,
|
| 447 |
+
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
self.dataloader = DataLoader(
|
| 451 |
+
self.dataset,
|
| 452 |
+
batch_size=1,
|
| 453 |
+
shuffle=False,
|
| 454 |
+
collate_fn=partial(collate_fn_pointnet),
|
| 455 |
+
num_workers=0,
|
| 456 |
+
pin_memory=True,
|
| 457 |
+
# prefetch_factor=4,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
def _check_weights_for_nan_inf(self, model: torch.nn.Module, model_name: str) -> bool:
|
| 461 |
+
"""
|
| 462 |
+
检查模型的所有参数中是否存在 NaN 或 Inf 值。
|
| 463 |
+
|
| 464 |
+
Args:
|
| 465 |
+
model (torch.nn.Module): 要检查的模型。
|
| 466 |
+
model_name (str): 模型的名称,用于打印日志。
|
| 467 |
+
|
| 468 |
+
Returns:
|
| 469 |
+
bool: 如果找到 NaN 或 Inf,则返回 True,否则返回 False。
|
| 470 |
+
"""
|
| 471 |
+
found_issue = False
|
| 472 |
+
for name, param in model.named_parameters():
|
| 473 |
+
if torch.isnan(param.data).any():
|
| 474 |
+
print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 NaN 值!")
|
| 475 |
+
found_issue = True
|
| 476 |
+
if torch.isinf(param.data).any():
|
| 477 |
+
print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 Inf 值!")
|
| 478 |
+
found_issue = True
|
| 479 |
+
return found_issue
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
|
| 483 |
+
"""
|
| 484 |
+
修改后的函数,确保一对一匹配,并优先匹配最近的点对。
|
| 485 |
+
"""
|
| 486 |
+
pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
|
| 487 |
+
gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
|
| 488 |
+
|
| 489 |
+
pred_total = len(pred_array)
|
| 490 |
+
gt_total = len(gt_array)
|
| 491 |
+
|
| 492 |
+
if pred_total == 0 or gt_total == 0:
|
| 493 |
+
return {
|
| 494 |
+
'recall': 0.0,
|
| 495 |
+
'precision': 0.0,
|
| 496 |
+
'f1': 0.0,
|
| 497 |
+
'matches': 0,
|
| 498 |
+
'pred_count': pred_total,
|
| 499 |
+
'gt_count': gt_total
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
# 依然在预测点上构建KD-Tree,为每个真实点查找最近的预测点
|
| 503 |
+
tree = cKDTree(pred_array)
|
| 504 |
+
dists, pred_idxs = tree.query(gt_array, k=1)
|
| 505 |
+
|
| 506 |
+
# --- 核心修改部分 ---
|
| 507 |
+
|
| 508 |
+
# 1. 创建一个列表,包含 (距离, 真实点索引, 预测点索引)
|
| 509 |
+
# 这样我们就可以按距离对所有可能的匹配进行排序
|
| 510 |
+
possible_matches = []
|
| 511 |
+
for gt_idx, (dist, pred_idx) in enumerate(zip(dists, pred_idxs)):
|
| 512 |
+
if dist <= threshold:
|
| 513 |
+
possible_matches.append((dist, gt_idx, pred_idx))
|
| 514 |
+
|
| 515 |
+
# 2. 按距离从小到大排序(贪心策略)
|
| 516 |
+
possible_matches.sort(key=lambda x: x[0])
|
| 517 |
+
|
| 518 |
+
matches = 0
|
| 519 |
+
# 使用集合来跟踪已经使用过的预测点和真实点,确保一对一匹配
|
| 520 |
+
used_pred_indices = set()
|
| 521 |
+
used_gt_indices = set() # 虽然当前逻辑下gt不会重复,但加上更严谨
|
| 522 |
+
|
| 523 |
+
# 3. 遍历排序后的可能匹配,进行一对一分配
|
| 524 |
+
for dist, gt_idx, pred_idx in possible_matches:
|
| 525 |
+
# 如果这个预测点和这个真实点都还没有被使用过
|
| 526 |
+
if pred_idx not in used_pred_indices and gt_idx not in used_gt_indices:
|
| 527 |
+
matches += 1
|
| 528 |
+
used_pred_indices.add(pred_idx)
|
| 529 |
+
used_gt_indices.add(gt_idx)
|
| 530 |
+
|
| 531 |
+
# --- 修改结束 ---
|
| 532 |
+
|
| 533 |
+
# matches 现在是真正的 True Positives 数量,它绝不会超过 pred_total 或 gt_total
|
| 534 |
+
recall = matches / gt_total if gt_total > 0 else 0.0
|
| 535 |
+
precision = matches / pred_total if pred_total > 0 else 0.0
|
| 536 |
+
|
| 537 |
+
# 计算F1时,使用标准的 Precision 和 Recall 定义
|
| 538 |
+
if (precision + recall) == 0:
|
| 539 |
+
f1 = 0.0
|
| 540 |
+
else:
|
| 541 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 542 |
+
|
| 543 |
+
return {
|
| 544 |
+
'recall': recall,
|
| 545 |
+
'precision': precision,
|
| 546 |
+
'f1': f1,
|
| 547 |
+
'matches': matches,
|
| 548 |
+
'pred_count': pred_total,
|
| 549 |
+
'gt_count': gt_total
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
|
| 553 |
+
"""
|
| 554 |
+
一个折衷的顶点指标计算方案。
|
| 555 |
+
它沿用“为每个真实点寻找最近预测点”的逻辑,
|
| 556 |
+
但通过修正计算方式,确保Precision和F1值不会超过1.0。
|
| 557 |
+
"""
|
| 558 |
+
# 假设 pred_coords 和 gt_coords 是 PyTorch 张量
|
| 559 |
+
pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
|
| 560 |
+
gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
|
| 561 |
+
|
| 562 |
+
pred_total = len(pred_array)
|
| 563 |
+
gt_total = len(gt_array)
|
| 564 |
+
|
| 565 |
+
if pred_total == 0 or gt_total == 0:
|
| 566 |
+
return {
|
| 567 |
+
'recall': 0.0,
|
| 568 |
+
'precision': 0.0,
|
| 569 |
+
'f1': 0.0,
|
| 570 |
+
'matches': 0,
|
| 571 |
+
'pred_count': pred_total,
|
| 572 |
+
'gt_count': gt_total
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
# 在预测点上构建KD-Tree,为每个真实点查找最近的预测点
|
| 576 |
+
tree = cKDTree(pred_array)
|
| 577 |
+
dists, _ = tree.query(gt_array, k=1) # 我们在这里其实不需要 pred 的索引
|
| 578 |
+
|
| 579 |
+
# 1. 计算从 gt 角度出发的匹配数 (True Positives for Recall)
|
| 580 |
+
# 这和您的第一个函数完全一样。
|
| 581 |
+
# 这个值代表了“有多少个真实点被成功找到了”。
|
| 582 |
+
matches_from_gt = np.sum(dists <= threshold)
|
| 583 |
+
|
| 584 |
+
# 2. 计算 Recall (召回率)
|
| 585 |
+
# 召回率的分母是真实点的总数,所以这里的计算是合理的。
|
| 586 |
+
recall = matches_from_gt / gt_total if gt_total > 0 else 0.0
|
| 587 |
+
|
| 588 |
+
# 3. 计算 Precision (精确率) - ✅ 这是核心修正点
|
| 589 |
+
# 精确率的分母是预测点的总数。
|
| 590 |
+
# 分子(True Positives)不能超过预测点的总数。
|
| 591 |
+
# 因此,我们取 matches_from_gt 和 pred_total 中的较小值。
|
| 592 |
+
# 这解决了 Precision > 1 的问题。
|
| 593 |
+
tp_for_precision = min(matches_from_gt, pred_total)
|
| 594 |
+
precision = tp_for_precision / pred_total if pred_total > 0 else 0.0
|
| 595 |
+
|
| 596 |
+
# 4. 使用标准的F1分数公式
|
| 597 |
+
# 您原来的 F1 公式 `2 * matches / (pred + gt)` 是 L1-Score,
|
| 598 |
+
# 更常用的是基于 Precision 和 Recall 的调和平均数。
|
| 599 |
+
if (precision + recall) == 0:
|
| 600 |
+
f1 = 0.0
|
| 601 |
+
else:
|
| 602 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 603 |
+
|
| 604 |
+
return {
|
| 605 |
+
'recall': recall,
|
| 606 |
+
'precision': precision,
|
| 607 |
+
'f1': f1,
|
| 608 |
+
'matches': matches_from_gt, # 仍然报告原始的匹配数,便于观察
|
| 609 |
+
'pred_count': pred_total,
|
| 610 |
+
'gt_count': gt_total
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
def _compute_chamfer_distance(self, p1: torch.Tensor, p2: torch.Tensor, one_sided: bool = False):
|
| 614 |
+
if len(p1) == 0 or len(p2) == 0:
|
| 615 |
+
return float('nan')
|
| 616 |
+
|
| 617 |
+
dist_p1_p2 = torch.min(torch.cdist(p1, p2), dim=1)[0].mean()
|
| 618 |
+
|
| 619 |
+
if one_sided:
|
| 620 |
+
return dist_p1_p2.item()
|
| 621 |
+
else:
|
| 622 |
+
dist_p2_p1 = torch.min(torch.cdist(p2, p1), dim=1)[0].mean()
|
| 623 |
+
return (dist_p1_p2 + dist_p2_p1).item() / 2
|
| 624 |
+
|
| 625 |
+
def visualize_latent_space_pca(self, sample_idx: int):
|
| 626 |
+
"""
|
| 627 |
+
Encodes a sample, performs PCA on its latent features, and saves a
|
| 628 |
+
colored PLY file for visualization.
|
| 629 |
+
|
| 630 |
+
The position of each point in the PLY file corresponds to the spatial
|
| 631 |
+
location in the latent grid.
|
| 632 |
+
|
| 633 |
+
The color of each point represents the first three principal components
|
| 634 |
+
of its feature vector.
|
| 635 |
+
"""
|
| 636 |
+
print(f"--- Starting Latent Space PCA Visualization for Sample {sample_idx} ---")
|
| 637 |
+
self.vae.eval()
|
| 638 |
+
|
| 639 |
+
try:
|
| 640 |
+
# 1. Get the latent representation for the sample
|
| 641 |
+
latent = self._get_latent_for_sample(sample_idx)
|
| 642 |
+
except ValueError as e:
|
| 643 |
+
print(f"Error: {e}")
|
| 644 |
+
return
|
| 645 |
+
|
| 646 |
+
latent_coords = latent.coords.detach().cpu().numpy()
|
| 647 |
+
latent_feats = latent.feats.detach().cpu().numpy()
|
| 648 |
+
|
| 649 |
+
if latent_feats.shape[0] < 3:
|
| 650 |
+
print(f"Warning: Not enough latent points ({latent_feats.shape[0]}) to perform PCA. Skipping.")
|
| 651 |
+
return
|
| 652 |
+
|
| 653 |
+
print(f"--> Performing PCA on {latent_feats.shape[0]} latent vectors of dimension {latent_feats.shape[1]}...")
|
| 654 |
+
|
| 655 |
+
# 2. Perform PCA to reduce feature dimensions to 3
|
| 656 |
+
pca = PCA(n_components=3)
|
| 657 |
+
pca_features = pca.fit_transform(latent_feats)
|
| 658 |
+
|
| 659 |
+
print(f" Explained variance ratio by 3 components: {pca.explained_variance_ratio_}")
|
| 660 |
+
print(f" Total explained variance: {np.sum(pca.explained_variance_ratio_):.4f}")
|
| 661 |
+
|
| 662 |
+
# 3. Normalize the PCA components to be used as RGB colors [0, 255]
|
| 663 |
+
# We normalize each component independently to maximize color contrast
|
| 664 |
+
normalized_colors = np.zeros_like(pca_features)
|
| 665 |
+
for i in range(3):
|
| 666 |
+
min_val = pca_features[:, i].min()
|
| 667 |
+
max_val = pca_features[:, i].max()
|
| 668 |
+
if max_val - min_val > 1e-6:
|
| 669 |
+
normalized_colors[:, i] = (pca_features[:, i] - min_val) / (max_val - min_val)
|
| 670 |
+
else:
|
| 671 |
+
normalized_colors[:, i] = 0.5 # Handle case of constant value
|
| 672 |
+
|
| 673 |
+
colors_uint8 = (normalized_colors * 255).astype(np.uint8)
|
| 674 |
+
|
| 675 |
+
# 4. Prepare spatial coordinates for the point cloud
|
| 676 |
+
# latent_coords is (batch_idx, x, y, z), we want the xyz part
|
| 677 |
+
spatial_coords = latent_coords[:, 1:]
|
| 678 |
+
|
| 679 |
+
# 5. Create and save the colored PLY file
|
| 680 |
+
try:
|
| 681 |
+
# Create a Trimesh PointCloud object
|
| 682 |
+
point_cloud = trimesh.points.PointCloud(vertices=spatial_coords, colors=colors_uint8)
|
| 683 |
+
|
| 684 |
+
# Define the output filename
|
| 685 |
+
filename = f"sample_{sample_idx}_latent_pca.ply"
|
| 686 |
+
ply_path = os.path.join(self.output_voxel_dir, filename)
|
| 687 |
+
|
| 688 |
+
# Export the file
|
| 689 |
+
point_cloud.export(ply_path)
|
| 690 |
+
print(f"--> Successfully saved PCA visualization to: {ply_path}")
|
| 691 |
+
|
| 692 |
+
except Exception as e:
|
| 693 |
+
print(f"Error during Trimesh export: {e}")
|
| 694 |
+
print("Please ensure 'trimesh' is installed correctly.")
|
| 695 |
+
|
| 696 |
+
def _get_latent_for_sample(self, sample_idx: int) -> SparseTensor:
|
| 697 |
+
"""
|
| 698 |
+
Encodes a single sample and returns its latent representation.
|
| 699 |
+
"""
|
| 700 |
+
print(f"--> Encoding sample {sample_idx} to get its latent vector...")
|
| 701 |
+
# Get data for the specified sample
|
| 702 |
+
batch_data = self.dataset[sample_idx]
|
| 703 |
+
if batch_data is None:
|
| 704 |
+
raise ValueError(f"Sample at index {sample_idx} could not be loaded.")
|
| 705 |
+
|
| 706 |
+
# Use the collate function to form a batch
|
| 707 |
+
batch_data = collate_fn_pointnet([batch_data])
|
| 708 |
+
|
| 709 |
+
with torch.no_grad():
|
| 710 |
+
# 1. Get input data and move to device
|
| 711 |
+
gt_vertex_voxels_1024 = batch_data['gt_vertex_voxels_1024'].to(self.device)
|
| 712 |
+
combined_voxels_1024 = batch_data['combined_voxels_1024'].to(self.device)
|
| 713 |
+
active_coords = batch_data['active_voxels_128'].to(self.device)
|
| 714 |
+
point_cloud = batch_data['point_cloud_128'].to(self.device)
|
| 715 |
+
|
| 716 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 717 |
+
edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
active_voxel_feats = self.voxel_encoder(
|
| 721 |
+
p=point_cloud,
|
| 722 |
+
sparse_coords=active_coords,
|
| 723 |
+
res=128,
|
| 724 |
+
bbox_size=(-0.5, 0.5),
|
| 725 |
+
|
| 726 |
+
# voxel_label=active_labels,
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
sparse_input = SparseTensor(
|
| 730 |
+
feats=active_voxel_feats,
|
| 731 |
+
coords=active_coords.int()
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
# 2. Encode to get the latent representation
|
| 735 |
+
latent_128, posterior = self.vae.encode(sparse_input, sample_posterior=True,)
|
| 736 |
+
print(f" Latent for sample {sample_idx} obtained. Shape: {latent_128.feats.shape}")
|
| 737 |
+
return latent_128
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
def evaluate(self, num_samples=None, visualize=False, chamfer_threshold=0.9, threshold=1.):
|
| 741 |
+
total_samples = len(self.dataset)
|
| 742 |
+
eval_samples = min(num_samples or total_samples, total_samples)
|
| 743 |
+
# sample_indices = random.sample(range(total_samples), eval_samples) if num_samples else range(total_samples)
|
| 744 |
+
sample_indices = range(eval_samples)
|
| 745 |
+
|
| 746 |
+
eval_dataset = Subset(self.dataset, sample_indices)
|
| 747 |
+
eval_loader = DataLoader(
|
| 748 |
+
eval_dataset,
|
| 749 |
+
batch_size=1,
|
| 750 |
+
shuffle=False,
|
| 751 |
+
collate_fn=partial(collate_fn_pointnet),
|
| 752 |
+
num_workers=self.config['training']['num_workers'],
|
| 753 |
+
pin_memory=True,
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
per_sample_metrics = {
|
| 757 |
+
'vertex': {res: [] for res in [128, 256, 512, 1024]},
|
| 758 |
+
'edge': {res: [] for res in [128, 256, 512, 1024]},
|
| 759 |
+
'sample_names': []
|
| 760 |
+
}
|
| 761 |
+
avg_metrics = {
|
| 762 |
+
'vertex': {res: defaultdict(list) for res in [128, 256, 512, 1024]},
|
| 763 |
+
'edge': {res: defaultdict(list) for res in [128, 256, 512, 1024]},
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
self.vae.eval()
|
| 767 |
+
|
| 768 |
+
for batch_idx, batch_data in enumerate(tqdm(eval_loader, desc="Evaluating")):
|
| 769 |
+
if batch_data is None:
|
| 770 |
+
continue
|
| 771 |
+
sample_idx = sample_indices[batch_idx]
|
| 772 |
+
sample_name = f'sample_{sample_idx}'
|
| 773 |
+
per_sample_metrics['sample_names'].append(sample_name)
|
| 774 |
+
|
| 775 |
+
# batch_save_path = f"/root/Trisf/output_slat_flow_matching_active/ckpts/8w/195000_sample_active_vis_42seed_trellis_generate/gt_data_batch_{batch_idx}.pt"
|
| 776 |
+
# if not os.path.exists(batch_save_path):
|
| 777 |
+
# print(f"Warning: Saved batch file not found: {batch_save_path}")
|
| 778 |
+
# continue
|
| 779 |
+
# batch_data = torch.load(batch_save_path, map_location=self.device)
|
| 780 |
+
|
| 781 |
+
with torch.no_grad():
|
| 782 |
+
# 1. Get input data
|
| 783 |
+
combined_voxels_1024 = batch_data['combined_voxels_1024'].to(self.device)
|
| 784 |
+
combined_voxel_labels_1024 = batch_data['combined_voxel_labels_1024'].to(self.device)
|
| 785 |
+
gt_combined_endpoints_1024 = batch_data['gt_combined_endpoints_1024'].to(self.device)
|
| 786 |
+
gt_combined_errors_1024 = batch_data['gt_combined_errors_1024'].to(self.device)
|
| 787 |
+
|
| 788 |
+
edge_mask = (combined_voxel_labels_1024 == 1)
|
| 789 |
+
|
| 790 |
+
gt_edge_endpoints_1024 = gt_combined_endpoints_1024[edge_mask].to(self.device)
|
| 791 |
+
gt_edge_errors_1024 = gt_combined_errors_1024[edge_mask].to(self.device)
|
| 792 |
+
gt_edge_voxels_1024 = combined_voxels_1024[edge_mask].to(self.device)
|
| 793 |
+
|
| 794 |
+
p1 = gt_edge_endpoints_1024[:, 1:4].float()
|
| 795 |
+
p2 = gt_edge_endpoints_1024[:, 4:7].float()
|
| 796 |
+
|
| 797 |
+
mask = ( (p1[:,0] < p2[:,0]) |
|
| 798 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
|
| 799 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
|
| 800 |
+
|
| 801 |
+
pA = torch.where(mask[:, None], p1, p2) # smaller one
|
| 802 |
+
pB = torch.where(mask[:, None], p2, p1) # larger one
|
| 803 |
+
|
| 804 |
+
d = pB - pA
|
| 805 |
+
dir_gt = F.normalize(d, dim=-1, eps=1e-6)
|
| 806 |
+
|
| 807 |
+
gt_vertex_voxels_1024 = batch_data['gt_vertex_voxels_1024'].to(self.device).int()
|
| 808 |
+
|
| 809 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 810 |
+
|
| 811 |
+
edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 812 |
+
edge_512 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=512)
|
| 813 |
+
edge_256 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=256)
|
| 814 |
+
edge_1024 = combined_voxels_1024
|
| 815 |
+
|
| 816 |
+
print('vtx_128.shape', vtx_128.shape)
|
| 817 |
+
print('edge_128.shape', edge_128.shape)
|
| 818 |
+
|
| 819 |
+
gt_edge_voxels_list = [
|
| 820 |
+
edge_128,
|
| 821 |
+
edge_256,
|
| 822 |
+
edge_512,
|
| 823 |
+
edge_1024,
|
| 824 |
+
]
|
| 825 |
+
|
| 826 |
+
active_coords = batch_data['active_voxels_128'].to(self.device)
|
| 827 |
+
point_cloud = batch_data['point_cloud_128'].to(self.device)
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
active_voxel_feats = self.voxel_encoder(
|
| 832 |
+
p=point_cloud,
|
| 833 |
+
sparse_coords=active_coords,
|
| 834 |
+
res=128,
|
| 835 |
+
bbox_size=(-0.5, 0.5),
|
| 836 |
+
|
| 837 |
+
# voxel_label=active_labels,
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
sparse_input = SparseTensor(
|
| 841 |
+
feats=active_voxel_feats,
|
| 842 |
+
coords=active_coords.int()
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
latent_128, posterior = self.vae.encode(sparse_input)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
# load_path = f'/root/Trisf/output_slat_flow_matching_active/ckpts/8w/195000_sample_active_vis_42seed_trellis_generate/sample_latent_{batch_idx}.pt'
|
| 849 |
+
# latent_128 = torch.load(load_path, map_location=self.device)
|
| 850 |
+
|
| 851 |
+
print('latent_128.feats.mean()', latent_128.feats.mean(), 'latent_128.feats.std()', latent_128.feats.std())
|
| 852 |
+
print('posterior.mean', posterior.mean.mean(), 'posterior.std', posterior.std.mean(), 'posterior.var', posterior.var.mean())
|
| 853 |
+
print('latent_128.coords.shape', latent_128.coords.shape)
|
| 854 |
+
|
| 855 |
+
# latent_128 = torch.load(f"/root/Trisf/output_slat_flow_matching/ckpts/1100_chair_sample/110000step_sample/sample_results_samples_{batch_idx}.pt", map_location=self.device)
|
| 856 |
+
latent_128 = SparseTensor(
|
| 857 |
+
coords=latent_128.coords,
|
| 858 |
+
feats=latent_128.feats + 0. * torch.randn_like(latent_128.feats),
|
| 859 |
+
)
|
| 860 |
+
|
| 861 |
+
# self.output_voxel_dir = os.path.dirname(load_path)
|
| 862 |
+
# self.output_obj_dir = os.path.dirname(load_path)
|
| 863 |
+
|
| 864 |
+
# 7. Decoding with separate vertex and edge processing
|
| 865 |
+
decoded_results = self.vae.decode(
|
| 866 |
+
latent_128,
|
| 867 |
+
gt_vertex_voxels_list=[],
|
| 868 |
+
gt_edge_voxels_list=[],
|
| 869 |
+
training=False,
|
| 870 |
+
|
| 871 |
+
inference_threshold=0.5,
|
| 872 |
+
vis_last_layer=False,
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
error = 0 #decoded_results[-1]['edge']['predicted_offset_feats']
|
| 876 |
+
|
| 877 |
+
if self.config['model'].get('pred_direction', False):
|
| 878 |
+
pred_dir = decoded_results[-1]['edge']['predicted_direction_feats']
|
| 879 |
+
zero_mask = (pred_dir == 0).all(dim=1) # [N],True 表示这一行全为0
|
| 880 |
+
num_zeros = zero_mask.sum().item()
|
| 881 |
+
print("Number of zero vectors:", num_zeros)
|
| 882 |
+
|
| 883 |
+
pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
|
| 884 |
+
print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
|
| 885 |
+
print('pred_dir.shape', pred_dir.shape)
|
| 886 |
+
if pred_edge_coords_3d.shape[-1] == 4:
|
| 887 |
+
pred_edge_coords_3d = pred_edge_coords_3d[:, 1:]
|
| 888 |
+
# visualize_directions(pred_edge_coords_3d, pred_dir, sample_ratio=0.02)
|
| 889 |
+
save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction.ply")
|
| 890 |
+
# visualize_colored_points_ply(pred_edge_coords_3d - error / 2. + 0.5, pred_dir, save_pth)
|
| 891 |
+
visualize_colored_points_ply(pred_edge_coords_3d, pred_dir, save_pth)
|
| 892 |
+
|
| 893 |
+
save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction_gt.ply")
|
| 894 |
+
# visualize_colored_points_ply((gt_edge_voxels_1024[:, 1:] - gt_edge_errors_1024[:, 1:] + 0.5), dir_gt, save_pth)
|
| 895 |
+
visualize_colored_points_ply((gt_edge_voxels_1024[:, 1:]), dir_gt, save_pth)
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
pred_vtx_coords_3d = decoded_results[-1]['vertex']['coords']
|
| 899 |
+
pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
|
| 900 |
+
|
| 901 |
+
pred_edge_coords_np = np.round(pred_edge_coords_3d.cpu().numpy()).astype(int)
|
| 902 |
+
|
| 903 |
+
gt_vertex_voxels_1024 = batch_data['gt_vertex_voxels_1024'][:, 1:].to(self.device)
|
| 904 |
+
gt_edge_voxels_1024 = batch_data['gt_edge_voxels_1024'][:, 1:].to(self.device)
|
| 905 |
+
gt_edge_coords_np = np.round(gt_edge_voxels_1024.cpu().numpy()).astype(int)
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
# Calculate metrics and save results
|
| 909 |
+
matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_vtx_coords_3d, gt_vertex_voxels_1024, threshold=threshold,)
|
| 910 |
+
print(f"\n----- Resolution {1024} vtx -----")
|
| 911 |
+
print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
|
| 912 |
+
print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
|
| 913 |
+
|
| 914 |
+
self._save_voxel_ply(pred_vtx_coords_3d / 1024., torch.zeros(len(pred_vtx_coords_3d)), f"{sample_name}_pred_vtx")
|
| 915 |
+
self._save_voxel_ply((pred_edge_coords_3d) / 1024, torch.zeros(len(pred_edge_coords_3d)), f"{sample_name}_pred_edge")
|
| 916 |
+
|
| 917 |
+
self._save_voxel_ply(gt_vertex_voxels_1024 / 1024, torch.zeros(len(gt_vertex_voxels_1024)), f"{sample_name}_gt_vertex")
|
| 918 |
+
self._save_voxel_ply((combined_voxels_1024[:, 1:]) / 1024., torch.zeros(len(gt_combined_errors_1024)), f"{sample_name}_gt_edge")
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
# Calculate vertex-specific metrics
|
| 922 |
+
matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_edge_coords_3d, combined_voxels_1024[:, 1:], threshold=threshold,)
|
| 923 |
+
print(f"\n----- Resolution {1024} edge -----")
|
| 924 |
+
print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
|
| 925 |
+
print('gt_edge_voxels_1024.shape', gt_edge_voxels_1024.shape)
|
| 926 |
+
|
| 927 |
+
print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
|
| 928 |
+
print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
|
| 929 |
+
|
| 930 |
+
pred_vertex_coords_np = np.round(pred_vtx_coords_3d.cpu().numpy()).astype(int)
|
| 931 |
+
pred_edges = []
|
| 932 |
+
gt_vertex_coords_np = np.round(gt_vertex_voxels_1024.cpu().numpy()).astype(int)
|
| 933 |
+
if visualize:
|
| 934 |
+
if pred_vtx_coords_3d.shape[-1] == 4:
|
| 935 |
+
pred_vtx_coords_float = pred_vtx_coords_3d[:, 1:].float()
|
| 936 |
+
else:
|
| 937 |
+
pred_vtx_coords_float = pred_vtx_coords_3d.float()
|
| 938 |
+
|
| 939 |
+
pred_vtx_feats = decoded_results[-1]['vertex']['feats']
|
| 940 |
+
|
| 941 |
+
# ==========================================
|
| 942 |
+
# Link Prediction & Mesh Generation
|
| 943 |
+
# ==========================================
|
| 944 |
+
print("Predicting connectivity...")
|
| 945 |
+
|
| 946 |
+
# 1. 预测边
|
| 947 |
+
# 注意:K_neighbors 的设置。如果是物体,64 足够了。
|
| 948 |
+
# 如果点非常稀疏,可能需要更大。
|
| 949 |
+
pred_edges = predict_mesh_connectivity(
|
| 950 |
+
connection_head=self.connection_head, # 或者是 self.connection_head,取决于你在哪里定义的
|
| 951 |
+
vtx_feats=pred_vtx_feats,
|
| 952 |
+
vtx_coords=pred_vtx_coords_float,
|
| 953 |
+
batch_size=4096,
|
| 954 |
+
threshold=0.5,
|
| 955 |
+
k_neighbors=None,
|
| 956 |
+
device=self.device
|
| 957 |
+
)
|
| 958 |
+
print(f"Predicted {len(pred_edges)} edges.")
|
| 959 |
+
|
| 960 |
+
# 2. 构建三角形
|
| 961 |
+
num_verts = pred_vtx_coords_float.shape[0]
|
| 962 |
+
pred_faces = build_triangles_from_edges(pred_edges, num_verts)
|
| 963 |
+
print(f"Constructed {len(pred_faces)} triangles.")
|
| 964 |
+
|
| 965 |
+
# 3. 保存 OBJ
|
| 966 |
+
import trimesh
|
| 967 |
+
|
| 968 |
+
# 坐标归一化/还原 (根据你的需求,这里假设你是 0-1024 的体素坐标)
|
| 969 |
+
# 如果想保存为归一化坐标:
|
| 970 |
+
mesh_verts = pred_vtx_coords_float.cpu().numpy() / 1024.0
|
| 971 |
+
|
| 972 |
+
# 如果有 error offset,记得加上!
|
| 973 |
+
# 你之前的代码好像没有对 vertex 加 offset,只对 edge 加了
|
| 974 |
+
# 如果 vertex 也有 offset (如 dual contouring),在这里加上
|
| 975 |
+
|
| 976 |
+
# 移动到中心 (可选)
|
| 977 |
+
mesh_verts = mesh_verts - 0.5
|
| 978 |
+
|
| 979 |
+
mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
|
| 980 |
+
|
| 981 |
+
# 过滤孤立点 (可选)
|
| 982 |
+
# mesh.remove_unreferenced_vertices()
|
| 983 |
+
|
| 984 |
+
output_obj_path = os.path.join(self.output_voxel_dir, f"{sample_name}_recon.obj")
|
| 985 |
+
mesh.export(output_obj_path)
|
| 986 |
+
print(f"Saved mesh to {output_obj_path}")
|
| 987 |
+
|
| 988 |
+
# 保存边线 (用于 Debug)
|
| 989 |
+
# 有时候三角形很难形成,只看边也很有用
|
| 990 |
+
edges_path = os.path.join(self.output_voxel_dir, f"{sample_name}_edges.ply")
|
| 991 |
+
# self._visualize_vertices(pred_edge_coords_np, gt_edge_coords_np, f"{sample_name}_edge_comparison")
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
# Process results at different resolutions
|
| 995 |
+
for i, res in enumerate([128, 256, 512, 1024]):
|
| 996 |
+
if i >= len(decoded_results):
|
| 997 |
+
continue
|
| 998 |
+
|
| 999 |
+
gt_key = f'gt_vertex_voxels_{res}'
|
| 1000 |
+
if gt_key not in batch_data:
|
| 1001 |
+
continue
|
| 1002 |
+
if i == 0:
|
| 1003 |
+
pred_coords_res = decoded_results[i]['vtx_sp'].coords[:, 1:].float()
|
| 1004 |
+
gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
|
| 1005 |
+
else:
|
| 1006 |
+
pred_coords_res = decoded_results[i]['vertex']['coords'].float()
|
| 1007 |
+
gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
+
v_metrics = self._compute_vertex_metrics(pred_coords_res, gt_coords_res, threshold=threshold)
|
| 1011 |
+
|
| 1012 |
+
per_sample_metrics['vertex'][res].append({
|
| 1013 |
+
'recall': v_metrics['recall'],
|
| 1014 |
+
'precision': v_metrics['precision'],
|
| 1015 |
+
'f1': v_metrics['f1'],
|
| 1016 |
+
'num_pred': len(pred_coords_res),
|
| 1017 |
+
'num_gt': len(gt_coords_res)
|
| 1018 |
+
})
|
| 1019 |
+
|
| 1020 |
+
avg_metrics['vertex'][res]['recall'].append(v_metrics['recall'])
|
| 1021 |
+
avg_metrics['vertex'][res]['precision'].append(v_metrics['precision'])
|
| 1022 |
+
avg_metrics['vertex'][res]['f1'].append(v_metrics['f1'])
|
| 1023 |
+
|
| 1024 |
+
gt_edge_key = f'gt_edge_voxels_{res}'
|
| 1025 |
+
if gt_edge_key not in batch_data:
|
| 1026 |
+
continue
|
| 1027 |
+
|
| 1028 |
+
if i == 0:
|
| 1029 |
+
pred_edge_coords_res = decoded_results[i]['edge_sp'].coords[:, 1:].float()
|
| 1030 |
+
# gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
|
| 1031 |
+
idx = i
|
| 1032 |
+
gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
|
| 1033 |
+
elif i == 3:
|
| 1034 |
+
idx = i
|
| 1035 |
+
#################################
|
| 1036 |
+
# pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() - error / 2. + 0.5
|
| 1037 |
+
# # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
|
| 1038 |
+
# gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) - gt_combined_errors_1024[:, 1:].to(self.device) + 0.5
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
|
| 1042 |
+
gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
|
| 1043 |
+
|
| 1044 |
+
# self._save_voxel_ply(gt_edge_voxels_list[idx][:, 1:].float().to(self.device) / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res_wooffset")
|
| 1045 |
+
# self._save_voxel_ply(decoded_results[i]['edge']['coords'].float() / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res_wooffset")
|
| 1046 |
+
|
| 1047 |
+
else:
|
| 1048 |
+
idx = i
|
| 1049 |
+
pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
|
| 1050 |
+
# gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
|
| 1051 |
+
gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
|
| 1052 |
+
|
| 1053 |
+
# self._save_voxel_ply(gt_edge_coords_res / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res")
|
| 1054 |
+
# self._save_voxel_ply(pred_edge_coords_res / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res")
|
| 1055 |
+
|
| 1056 |
+
e_metrics = self._compute_vertex_metrics(pred_edge_coords_res, gt_edge_coords_res, threshold=threshold)
|
| 1057 |
+
|
| 1058 |
+
per_sample_metrics['edge'][res].append({
|
| 1059 |
+
'recall': e_metrics['recall'],
|
| 1060 |
+
'precision': e_metrics['precision'],
|
| 1061 |
+
'f1': e_metrics['f1'],
|
| 1062 |
+
'num_pred': len(pred_edge_coords_res),
|
| 1063 |
+
'num_gt': len(gt_edge_coords_res)
|
| 1064 |
+
})
|
| 1065 |
+
|
| 1066 |
+
avg_metrics['edge'][res]['recall'].append(e_metrics['recall'])
|
| 1067 |
+
avg_metrics['edge'][res]['precision'].append(e_metrics['precision'])
|
| 1068 |
+
avg_metrics['edge'][res]['f1'].append(e_metrics['f1'])
|
| 1069 |
+
|
| 1070 |
+
avg_metrics_processed = {}
|
| 1071 |
+
for category, res_dict in avg_metrics.items():
|
| 1072 |
+
avg_metrics_processed[category] = {}
|
| 1073 |
+
for resolution, metric_dict in res_dict.items():
|
| 1074 |
+
avg_metrics_processed[category][resolution] = {
|
| 1075 |
+
metric_name: np.mean(values) if values else float('nan')
|
| 1076 |
+
for metric_name, values in metric_dict.items()
|
| 1077 |
+
}
|
| 1078 |
+
|
| 1079 |
+
result_data = {
|
| 1080 |
+
'config': self.config,
|
| 1081 |
+
'checkpoint': self.ckpt_path,
|
| 1082 |
+
'dataset': self.dataset_path,
|
| 1083 |
+
'num_samples': eval_samples,
|
| 1084 |
+
'per_sample_metrics': per_sample_metrics,
|
| 1085 |
+
'avg_metrics': avg_metrics_processed
|
| 1086 |
+
}
|
| 1087 |
+
|
| 1088 |
+
results_file_path = os.path.join(self.result_dir, f"evaluation_results_epoch{self.epoch}.yaml")
|
| 1089 |
+
with open(results_file_path, 'w') as f:
|
| 1090 |
+
yaml.dump(result_data, f, default_flow_style=False)
|
| 1091 |
+
|
| 1092 |
+
return result_data
|
| 1093 |
+
|
| 1094 |
+
def _generate_line_voxels(
|
| 1095 |
+
self,
|
| 1096 |
+
p1: torch.Tensor,
|
| 1097 |
+
p2: torch.Tensor
|
| 1098 |
+
) -> Tuple[
|
| 1099 |
+
List[Tuple[int, int, int]],
|
| 1100 |
+
List[Tuple[torch.Tensor, torch.Tensor]],
|
| 1101 |
+
List[np.ndarray]
|
| 1102 |
+
]:
|
| 1103 |
+
"""
|
| 1104 |
+
Improved version using better sampling strategy
|
| 1105 |
+
"""
|
| 1106 |
+
p1_np = p1 #.cpu().numpy()
|
| 1107 |
+
p2_np = p2 #.cpu().numpy()
|
| 1108 |
+
voxel_dict = OrderedDict()
|
| 1109 |
+
|
| 1110 |
+
# Use proper 3D line voxelization algorithm
|
| 1111 |
+
def bresenham_3d(p1, p2):
|
| 1112 |
+
"""3D Bresenham's line algorithm"""
|
| 1113 |
+
x1, y1, z1 = np.round(p1).astype(int)
|
| 1114 |
+
x2, y2, z2 = np.round(p2).astype(int)
|
| 1115 |
+
|
| 1116 |
+
points = []
|
| 1117 |
+
dx = abs(x2 - x1)
|
| 1118 |
+
dy = abs(y2 - y1)
|
| 1119 |
+
dz = abs(z2 - z1)
|
| 1120 |
+
|
| 1121 |
+
xs = 1 if x2 > x1 else -1
|
| 1122 |
+
ys = 1 if y2 > y1 else -1
|
| 1123 |
+
zs = 1 if z2 > z1 else -1
|
| 1124 |
+
|
| 1125 |
+
# Driving axis is X
|
| 1126 |
+
if dx >= dy and dx >= dz:
|
| 1127 |
+
err_1 = 2 * dy - dx
|
| 1128 |
+
err_2 = 2 * dz - dx
|
| 1129 |
+
for i in range(dx + 1):
|
| 1130 |
+
points.append((x1, y1, z1))
|
| 1131 |
+
if err_1 > 0:
|
| 1132 |
+
y1 += ys
|
| 1133 |
+
err_1 -= 2 * dx
|
| 1134 |
+
if err_2 > 0:
|
| 1135 |
+
z1 += zs
|
| 1136 |
+
err_2 -= 2 * dx
|
| 1137 |
+
err_1 += 2 * dy
|
| 1138 |
+
err_2 += 2 * dz
|
| 1139 |
+
x1 += xs
|
| 1140 |
+
|
| 1141 |
+
# Driving axis is Y
|
| 1142 |
+
elif dy >= dx and dy >= dz:
|
| 1143 |
+
err_1 = 2 * dx - dy
|
| 1144 |
+
err_2 = 2 * dz - dy
|
| 1145 |
+
for i in range(dy + 1):
|
| 1146 |
+
points.append((x1, y1, z1))
|
| 1147 |
+
if err_1 > 0:
|
| 1148 |
+
x1 += xs
|
| 1149 |
+
err_1 -= 2 * dy
|
| 1150 |
+
if err_2 > 0:
|
| 1151 |
+
z1 += zs
|
| 1152 |
+
err_2 -= 2 * dy
|
| 1153 |
+
err_1 += 2 * dx
|
| 1154 |
+
err_2 += 2 * dz
|
| 1155 |
+
y1 += ys
|
| 1156 |
+
|
| 1157 |
+
# Driving axis is Z
|
| 1158 |
+
else:
|
| 1159 |
+
err_1 = 2 * dx - dz
|
| 1160 |
+
err_2 = 2 * dy - dz
|
| 1161 |
+
for i in range(dz + 1):
|
| 1162 |
+
points.append((x1, y1, z1))
|
| 1163 |
+
if err_1 > 0:
|
| 1164 |
+
x1 += xs
|
| 1165 |
+
err_1 -= 2 * dz
|
| 1166 |
+
if err_2 > 0:
|
| 1167 |
+
y1 += ys
|
| 1168 |
+
err_2 -= 2 * dz
|
| 1169 |
+
err_1 += 2 * dx
|
| 1170 |
+
err_2 += 2 * dy
|
| 1171 |
+
z1 += zs
|
| 1172 |
+
|
| 1173 |
+
return points
|
| 1174 |
+
|
| 1175 |
+
# Get all voxels using Bresenham algorithm
|
| 1176 |
+
voxel_coords = bresenham_3d(p1_np, p2_np)
|
| 1177 |
+
|
| 1178 |
+
# Add all voxels to dictionary
|
| 1179 |
+
for coord in voxel_coords:
|
| 1180 |
+
voxel_dict[tuple(coord)] = (p1, p2)
|
| 1181 |
+
|
| 1182 |
+
voxel_coords = list(voxel_dict.keys())
|
| 1183 |
+
endpoint_pairs = list(voxel_dict.values())
|
| 1184 |
+
|
| 1185 |
+
# --- compute error vectors ---
|
| 1186 |
+
error_vectors = []
|
| 1187 |
+
diff = p2_np - p1_np
|
| 1188 |
+
d_norm_sq = np.dot(diff, diff)
|
| 1189 |
+
|
| 1190 |
+
for v in voxel_coords:
|
| 1191 |
+
v_center = np.array(v, dtype=float) + 0.5
|
| 1192 |
+
if d_norm_sq == 0: # degenerate line
|
| 1193 |
+
closest = p1_np
|
| 1194 |
+
else:
|
| 1195 |
+
t = np.dot(v_center - p1_np, diff) / d_norm_sq
|
| 1196 |
+
t = np.clip(t, 0.0, 1.0)
|
| 1197 |
+
closest = p1_np + t * diff
|
| 1198 |
+
error_vectors.append(v_center - closest)
|
| 1199 |
+
|
| 1200 |
+
return voxel_coords, endpoint_pairs, error_vectors
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
# 使用示例
|
| 1204 |
+
def set_seed(seed: int):
|
| 1205 |
+
random.seed(seed)
|
| 1206 |
+
np.random.seed(seed)
|
| 1207 |
+
torch.manual_seed(seed)
|
| 1208 |
+
if torch.cuda.is_available():
|
| 1209 |
+
torch.cuda.manual_seed(seed)
|
| 1210 |
+
torch.cuda.manual_seed_all(seed)
|
| 1211 |
+
torch.backends.cudnn.deterministic = True
|
| 1212 |
+
torch.backends.cudnn.benchmark = False
|
| 1213 |
+
|
| 1214 |
+
def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
|
| 1215 |
+
set_seed(42)
|
| 1216 |
+
tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
|
| 1217 |
+
result_data = tester.evaluate(num_samples=NUM_SAMPLES, visualize=VISUALIZE, chamfer_threshold=CHAMFER_EDGE_THRESHOLD, threshold=THRESHOLD)
|
| 1218 |
+
|
| 1219 |
+
# 生成文件名
|
| 1220 |
+
epoch_str = os.path.basename(ckpt_path).split('_')[1].split('.')[0]
|
| 1221 |
+
dataset_name = os.path.basename(os.path.normpath(dataset_path))
|
| 1222 |
+
|
| 1223 |
+
# 保存简版报告(TXT)
|
| 1224 |
+
summary_path = os.path.join(eval_dir, f"epoch{epoch_str}_{dataset_name}_summary_threshold{THRESHOLD}_one2one.txt")
|
| 1225 |
+
with open(summary_path, 'w') as f:
|
| 1226 |
+
# 头部信息
|
| 1227 |
+
f.write(f"Checkpoint: {os.path.basename(ckpt_path)}\n")
|
| 1228 |
+
f.write(f"Dataset: {dataset_name}\n")
|
| 1229 |
+
f.write(f"Evaluation Samples: {result_data['num_samples']}\n\n")
|
| 1230 |
+
|
| 1231 |
+
# 平均指标
|
| 1232 |
+
f.write("=== Average Metrics ===\n")
|
| 1233 |
+
for category, data in result_data['avg_metrics'].items():
|
| 1234 |
+
if isinstance(data, dict): # 处理多分辨率情况
|
| 1235 |
+
f.write(f"\n{category.upper()}:\n")
|
| 1236 |
+
for res, metrics in data.items():
|
| 1237 |
+
f.write(f" Resolution {res}:\n")
|
| 1238 |
+
for k, v in metrics.items():
|
| 1239 |
+
# 确保值是数字类型后再格式化
|
| 1240 |
+
if isinstance(v, (int, float)):
|
| 1241 |
+
f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
|
| 1242 |
+
else:
|
| 1243 |
+
f.write(f" {str(k).ljust(15)}: {str(v)}\n")
|
| 1244 |
+
else: # 处理非多分辨率情况
|
| 1245 |
+
f.write(f"\n{category.upper()}:\n")
|
| 1246 |
+
for k, v in data.items():
|
| 1247 |
+
if isinstance(v, (int, float)):
|
| 1248 |
+
f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
|
| 1249 |
+
else:
|
| 1250 |
+
f.write(f" {str(k).ljust(15)}: {str(v)}\n")
|
| 1251 |
+
|
| 1252 |
+
# 样本级详细统计
|
| 1253 |
+
f.write("\n\n=== Detailed Per-Sample Metrics ===\n")
|
| 1254 |
+
for name, vertex_metrics, edge_metrics in zip(
|
| 1255 |
+
result_data['per_sample_metrics']['sample_names'],
|
| 1256 |
+
zip(*[result_data['per_sample_metrics']['vertex'][res] for res in [128, 256, 512, 1024]]),
|
| 1257 |
+
zip(*[result_data['per_sample_metrics']['edge'][res] for res in [128, 256, 512, 1024]])
|
| 1258 |
+
):
|
| 1259 |
+
# 样本标题
|
| 1260 |
+
f.write(f"\n◆ Sample: {name}\n")
|
| 1261 |
+
|
| 1262 |
+
# 顶点指标
|
| 1263 |
+
f.write(f"[Vertex Prediction]\n")
|
| 1264 |
+
f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
|
| 1265 |
+
for res, metrics in zip([128, 256, 512, 1024], vertex_metrics):
|
| 1266 |
+
f.write(f" {str(res).ljust(10)} "
|
| 1267 |
+
f"{metrics['recall']:.4f} "
|
| 1268 |
+
f"{metrics['precision']:.4f} "
|
| 1269 |
+
f"{metrics['f1']:.4f} "
|
| 1270 |
+
f"{metrics['num_pred']}/{metrics['num_gt']}\n")
|
| 1271 |
+
|
| 1272 |
+
# Edge指标
|
| 1273 |
+
f.write(f"[Edge Prediction]\n")
|
| 1274 |
+
f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
|
| 1275 |
+
for res, metrics in zip([128, 256, 512, 1024], edge_metrics):
|
| 1276 |
+
f.write(f" {str(res).ljust(10)} "
|
| 1277 |
+
f"{metrics['recall']:.4f} "
|
| 1278 |
+
f"{metrics['precision']:.4f} "
|
| 1279 |
+
f"{metrics['f1']:.4f} "
|
| 1280 |
+
f"{metrics['num_pred']}/{metrics['num_gt']}\n")
|
| 1281 |
+
|
| 1282 |
+
f.write("-"*60 + "\n")
|
| 1283 |
+
|
| 1284 |
+
print(f"Saved summary to: {summary_path}")
|
| 1285 |
+
return result_data
|
| 1286 |
+
|
| 1287 |
+
|
| 1288 |
+
if __name__ == '__main__':
|
| 1289 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 1290 |
+
evaluate_all_checkpoints = True # 设置为 True 启用范围过滤
|
| 1291 |
+
EPOCH_START = 14
|
| 1292 |
+
EPOCH_END = 460
|
| 1293 |
+
CHAMFER_EDGE_THRESHOLD=0.5
|
| 1294 |
+
NUM_SAMPLES=50
|
| 1295 |
+
VISUALIZE=True
|
| 1296 |
+
THRESHOLD=1.5
|
| 1297 |
+
VISUAL_FIELD=False
|
| 1298 |
+
|
| 1299 |
+
ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to1024_dir_sorted_dora_head_small_right/checkpoint_epoch14_batch5216_loss0.2745.pt'
|
| 1300 |
+
dataset_path = '/gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/test'
|
| 1301 |
+
dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized'
|
| 1302 |
+
|
| 1303 |
+
if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
|
| 1304 |
+
RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
|
| 1305 |
+
else:
|
| 1306 |
+
RENDERS_DIR = ''
|
| 1307 |
+
|
| 1308 |
+
|
| 1309 |
+
ckpt_dir = os.path.dirname(ckpt_path)
|
| 1310 |
+
eval_dir = os.path.join(ckpt_dir, "evaluate")
|
| 1311 |
+
os.makedirs(eval_dir, exist_ok=True)
|
| 1312 |
+
|
| 1313 |
+
if False:
|
| 1314 |
+
for i in range(NUM_SAMPLES):
|
| 1315 |
+
print("--- Starting Latent Space PCA Visualization ---")
|
| 1316 |
+
tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
|
| 1317 |
+
tester.visualize_latent_space_pca(sample_idx=i)
|
| 1318 |
+
print("--- PCA Visualization Finished ---")
|
| 1319 |
+
|
| 1320 |
+
if not evaluate_all_checkpoints:
|
| 1321 |
+
evaluate_checkpoint(ckpt_path, dataset_path, eval_dir)
|
| 1322 |
+
else:
|
| 1323 |
+
pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.endswith('.pt')])
|
| 1324 |
+
|
| 1325 |
+
filtered_pt_files = []
|
| 1326 |
+
for f in pt_files:
|
| 1327 |
+
try:
|
| 1328 |
+
parts = f.split('_')
|
| 1329 |
+
epoch_str = parts[1].replace('epoch', '')
|
| 1330 |
+
epoch = int(epoch_str)
|
| 1331 |
+
if EPOCH_START <= epoch <= EPOCH_END:
|
| 1332 |
+
filtered_pt_files.append(f)
|
| 1333 |
+
except Exception as e:
|
| 1334 |
+
print(f"Warning: Could not parse epoch from {f}: {e}")
|
| 1335 |
+
continue
|
| 1336 |
+
|
| 1337 |
+
for pt_file in filtered_pt_files:
|
| 1338 |
+
full_ckpt_path = os.path.join(ckpt_dir, pt_file)
|
| 1339 |
+
evaluate_checkpoint(full_ckpt_path, dataset_path, eval_dir)
|
test_slat_vae_128to1024_pointnet_vae_head_woca.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test_slat_vae_128to256_pointnet_vae_head.py
ADDED
|
@@ -0,0 +1,1349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
import plotly.graph_objects as go
|
| 9 |
+
from plotly.subplots import make_subplots
|
| 10 |
+
from torch.utils.data import DataLoader, Subset
|
| 11 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 12 |
+
|
| 13 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 14 |
+
|
| 15 |
+
from vertex_encoder import VoxelFeatureEncoder_edge, VoxelFeatureEncoder_vtx, VoxelFeatureEncoder_active, VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 16 |
+
from utils import load_pretrained_woself
|
| 17 |
+
|
| 18 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
from functools import partial
|
| 22 |
+
import itertools
|
| 23 |
+
from typing import List, Tuple, Set
|
| 24 |
+
from collections import OrderedDict
|
| 25 |
+
from scipy.spatial import cKDTree
|
| 26 |
+
from sklearn.neighbors import KDTree
|
| 27 |
+
|
| 28 |
+
import trimesh
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
import time
|
| 33 |
+
|
| 34 |
+
from sklearn.decomposition import PCA
|
| 35 |
+
import matplotlib.pyplot as plt
|
| 36 |
+
|
| 37 |
+
import networkx as nx
|
| 38 |
+
|
| 39 |
+
def predict_mesh_connectivity(
|
| 40 |
+
connection_head,
|
| 41 |
+
vtx_feats,
|
| 42 |
+
vtx_coords,
|
| 43 |
+
batch_size=10000,
|
| 44 |
+
threshold=0.5,
|
| 45 |
+
k_neighbors=64, # 限制每个点只检测最近的 K 个邻居,设为 -1 则全连接检测
|
| 46 |
+
device='cuda'
|
| 47 |
+
):
|
| 48 |
+
"""
|
| 49 |
+
Args:
|
| 50 |
+
connection_head: 训练好的 MLP 模型
|
| 51 |
+
vtx_feats: [N, C] 顶点特征
|
| 52 |
+
vtx_coords: [N, 3] 顶点坐标 (用于 KNN 筛选候选边)
|
| 53 |
+
batch_size: MLP 推理的 batch size
|
| 54 |
+
threshold: 判定连接的概率阈值
|
| 55 |
+
k_neighbors: K-NN 数量。如果是 None 或 -1,则检测所有 N*(N-1)/2 对。
|
| 56 |
+
"""
|
| 57 |
+
num_verts = vtx_feats.shape[0]
|
| 58 |
+
if num_verts < 3:
|
| 59 |
+
return [], [] # 无法构成三角形
|
| 60 |
+
|
| 61 |
+
connection_head.eval()
|
| 62 |
+
|
| 63 |
+
# --- 1. 生成候选边 (Candidate Edges) ---
|
| 64 |
+
if k_neighbors is not None and k_neighbors > 0 and k_neighbors < num_verts:
|
| 65 |
+
# 策略 A: 局部 KNN (推荐)
|
| 66 |
+
# 计算距离矩阵可能会 OOM,使用分块或 KDTree/Faiss,这里用 PyTorch 的 cdist 分块简化版
|
| 67 |
+
# 或者直接暴力 cdist 如果 N < 10000
|
| 68 |
+
|
| 69 |
+
# 为了简单且高效,这里演示简单的 cdist (注意显存)
|
| 70 |
+
# 如果 N 很大 (>5000),建议使用 faiss 或 scipy.spatial.cKDTree
|
| 71 |
+
dist_mat = torch.cdist(vtx_coords.float(), vtx_coords.float()) # [N, N]
|
| 72 |
+
|
| 73 |
+
# 取 topk (smallest distance),排除自己
|
| 74 |
+
# values: [N, K], indices: [N, K]
|
| 75 |
+
_, indices = torch.topk(dist_mat, k=k_neighbors + 1, dim=1, largest=False)
|
| 76 |
+
neighbor_indices = indices[:, 1:] # 去掉第一列(自己)
|
| 77 |
+
|
| 78 |
+
# 构建 source, target 索引
|
| 79 |
+
src = torch.arange(num_verts, device=device).unsqueeze(1).repeat(1, k_neighbors).flatten()
|
| 80 |
+
dst = neighbor_indices.flatten()
|
| 81 |
+
|
| 82 |
+
# 此时得到的边是双向的 (u->v 和 v->u 可能都存在),为了效率可以去重
|
| 83 |
+
# 但为了利用你的 symmetric MLP,保留双向或者只保留 u < v 均可
|
| 84 |
+
# 这里为了简单,我们生成 u < v 的 mask
|
| 85 |
+
mask = src < dst
|
| 86 |
+
u_indices = src[mask]
|
| 87 |
+
v_indices = dst[mask]
|
| 88 |
+
|
| 89 |
+
else:
|
| 90 |
+
# 策略 B: 全连接 (O(N^2)) - 仅当 N 较小时使用
|
| 91 |
+
u_indices, v_indices = torch.triu_indices(num_verts, num_verts, offset=1, device=device)
|
| 92 |
+
|
| 93 |
+
# --- 2. 批量推理 ---
|
| 94 |
+
all_probs = []
|
| 95 |
+
num_candidates = u_indices.shape[0]
|
| 96 |
+
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
for i in range(0, num_candidates, batch_size):
|
| 99 |
+
end = min(i + batch_size, num_candidates)
|
| 100 |
+
batch_u = u_indices[i:end]
|
| 101 |
+
batch_v = v_indices[i:end]
|
| 102 |
+
|
| 103 |
+
feat_u = vtx_feats[batch_u]
|
| 104 |
+
feat_v = vtx_feats[batch_v]
|
| 105 |
+
|
| 106 |
+
# Symmetric Forward (和你训练时保持一致)
|
| 107 |
+
# A -> B
|
| 108 |
+
input_uv = torch.cat([feat_u, feat_v], dim=-1)
|
| 109 |
+
logits_uv = connection_head(input_uv)
|
| 110 |
+
|
| 111 |
+
# B -> A
|
| 112 |
+
input_vu = torch.cat([feat_v, feat_u], dim=-1)
|
| 113 |
+
logits_vu = connection_head(input_vu)
|
| 114 |
+
|
| 115 |
+
# Sum logits
|
| 116 |
+
logits = logits_uv + logits_vu
|
| 117 |
+
probs = torch.sigmoid(logits)
|
| 118 |
+
all_probs.append(probs)
|
| 119 |
+
|
| 120 |
+
all_probs = torch.cat(all_probs).squeeze() # [M]
|
| 121 |
+
|
| 122 |
+
# --- 3. 筛选连接边 ---
|
| 123 |
+
connected_mask = all_probs > threshold
|
| 124 |
+
final_u = u_indices[connected_mask].cpu().numpy()
|
| 125 |
+
final_v = v_indices[connected_mask].cpu().numpy()
|
| 126 |
+
|
| 127 |
+
edges = np.stack([final_u, final_v], axis=1) # [E, 2]
|
| 128 |
+
|
| 129 |
+
return edges
|
| 130 |
+
|
| 131 |
+
def build_triangles_from_edges(edges, num_verts):
|
| 132 |
+
"""
|
| 133 |
+
从边列表构建三角形。
|
| 134 |
+
寻找图中所有的 3-Cliques (三元环)。
|
| 135 |
+
这在图论中是一个经典问题,可以使用 networkx 库。
|
| 136 |
+
"""
|
| 137 |
+
if len(edges) == 0:
|
| 138 |
+
return np.empty((0, 3), dtype=int)
|
| 139 |
+
|
| 140 |
+
G = nx.Graph()
|
| 141 |
+
G.add_nodes_from(range(num_verts))
|
| 142 |
+
G.add_edges_from(edges)
|
| 143 |
+
|
| 144 |
+
# 寻找所有的 3-cliques (三角形)
|
| 145 |
+
# enumerate_all_cliques 返回所有大小的 clique,我们需要过滤大小为 3 的
|
| 146 |
+
# 或者使用 nx.triangles ? 不,那个只返回数量
|
| 147 |
+
# 使用 nx.enumerate_all_cliques 效率可能较低,对于稀疏图还可以
|
| 148 |
+
|
| 149 |
+
# 更快的方法:迭代每条边 (u, v),查找 u 和 v 的公共邻居 w
|
| 150 |
+
triangles = []
|
| 151 |
+
adj = [set(G.neighbors(n)) for n in range(num_verts)]
|
| 152 |
+
|
| 153 |
+
# 为了避免重复 (u, v, w), (v, w, u)... 我们可以强制 u < v < w
|
| 154 |
+
# 既然 edges 已经是 u < v (如果我们之前做了 triu),则只需要找 w > v 且 w in adj[u]
|
| 155 |
+
|
| 156 |
+
# 优化算法:
|
| 157 |
+
for u, v in edges:
|
| 158 |
+
if u > v: u, v = v, u # 确保有序
|
| 159 |
+
|
| 160 |
+
# 找公共邻居
|
| 161 |
+
common = adj[u].intersection(adj[v])
|
| 162 |
+
for w in common:
|
| 163 |
+
if w > v: # 强制顺序 u < v < w 防止重复
|
| 164 |
+
triangles.append([u, v, w])
|
| 165 |
+
|
| 166 |
+
return np.array(triangles)
|
| 167 |
+
|
| 168 |
+
def downsample_voxels(
|
| 169 |
+
voxels: torch.Tensor,
|
| 170 |
+
input_resolution: int,
|
| 171 |
+
output_resolution: int
|
| 172 |
+
) -> torch.Tensor:
|
| 173 |
+
if input_resolution % output_resolution != 0:
|
| 174 |
+
raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
|
| 175 |
+
f"by output_resolution ({output_resolution}).")
|
| 176 |
+
|
| 177 |
+
factor = input_resolution // output_resolution
|
| 178 |
+
|
| 179 |
+
downsampled_voxels = voxels.clone().to(torch.long)
|
| 180 |
+
|
| 181 |
+
downsampled_voxels[:, 1:] //= factor
|
| 182 |
+
|
| 183 |
+
unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
|
| 184 |
+
return unique_downsampled_voxels
|
| 185 |
+
|
| 186 |
+
def visualize_colored_points_ply(coords, vectors, filename):
|
| 187 |
+
"""
|
| 188 |
+
可视化点云,并用向量方向的颜色来表示,保存为 PLY 文件。
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
coords (torch.Tensor or np.ndarray): 3D坐标,形状为 (N, 3)。
|
| 192 |
+
vectors (torch.Tensor or np.ndarray): 方向向量,形状为 (N, 3)。
|
| 193 |
+
filename (str): 保存输出文件的名称,必须是 .ply 格式。
|
| 194 |
+
"""
|
| 195 |
+
# 确保输入是 numpy 数组
|
| 196 |
+
if isinstance(coords, torch.Tensor):
|
| 197 |
+
coords = coords.detach().cpu().numpy()
|
| 198 |
+
if isinstance(vectors, torch.Tensor):
|
| 199 |
+
vectors = vectors.detach().cpu().to(torch.float32).numpy()
|
| 200 |
+
|
| 201 |
+
# 检查输入数据是否为空,防止崩溃
|
| 202 |
+
if coords.size == 0 or vectors.size == 0:
|
| 203 |
+
print(f"警告:输入数据为空,未生成 {filename} 文件。")
|
| 204 |
+
return
|
| 205 |
+
|
| 206 |
+
# 将向量分量从 [-1, 1] 映射到 [0, 255]
|
| 207 |
+
# np.clip 用于将数值限制在 -1 和 1 之间,防止颜色溢出
|
| 208 |
+
# (vectors + 1) 将范围从 [-1, 1] 移动到 [0, 2]
|
| 209 |
+
# * 127.5 将范围从 [0, 2] 缩放到 [0, 255]
|
| 210 |
+
colors = np.clip((vectors + 1) * 127.5, 0, 255).astype(np.uint8)
|
| 211 |
+
|
| 212 |
+
# 创建一个点云对象,并传入颜色信息
|
| 213 |
+
# trimesh.PointCloud 能够自动处理带颜色的点
|
| 214 |
+
points = trimesh.points.PointCloud(coords, colors=colors)
|
| 215 |
+
# 导出为 PLY 文件
|
| 216 |
+
points.export(filename, file_type='ply')
|
| 217 |
+
print(f"可视化文件已成功保存为: {filename}")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def compute_vertex_matching(pred_coords, gt_coords, threshold=1.0):
|
| 221 |
+
"""
|
| 222 |
+
使用 KDTree 最近邻 + 贪心匹配算法计算顶点匹配 (欧式距离)
|
| 223 |
+
|
| 224 |
+
参数:
|
| 225 |
+
pred_coords: 预测坐标 (Tensor)
|
| 226 |
+
gt_coords: 真实坐标 (Tensor)
|
| 227 |
+
threshold: 匹配误差阈值 (默认1.0)
|
| 228 |
+
|
| 229 |
+
返回:
|
| 230 |
+
matches: 匹配成功的顶点数量
|
| 231 |
+
match_rate: 匹配率 (基于真实顶点数)
|
| 232 |
+
pred_total: 预测顶点总数
|
| 233 |
+
gt_total: 真实顶点总数
|
| 234 |
+
"""
|
| 235 |
+
# 转换为整数坐标并去重
|
| 236 |
+
print('len(pred_coords)', len(pred_coords))
|
| 237 |
+
|
| 238 |
+
pred_array = np.unique(pred_coords.detach().to(torch.float32).cpu().numpy(), axis=0)
|
| 239 |
+
gt_array = np.unique(gt_coords.detach().cpu().to(torch.float32).numpy(), axis=0)
|
| 240 |
+
print('len(pred_array)', len(pred_array))
|
| 241 |
+
pred_total = len(pred_array)
|
| 242 |
+
gt_total = len(gt_array)
|
| 243 |
+
|
| 244 |
+
# 如果没有点,直接返回
|
| 245 |
+
if pred_total == 0 or gt_total == 0:
|
| 246 |
+
return 0, 0.0, pred_total, gt_total
|
| 247 |
+
|
| 248 |
+
# 建立 KDTree(以 gt 为基准)
|
| 249 |
+
tree = KDTree(gt_array)
|
| 250 |
+
|
| 251 |
+
# 查找预测点到最近的 gt 点
|
| 252 |
+
dist, indices = tree.query(pred_array, k=1)
|
| 253 |
+
dist = dist.squeeze()
|
| 254 |
+
indices = indices.squeeze()
|
| 255 |
+
|
| 256 |
+
# 贪心去重:确保 1 对 1
|
| 257 |
+
matches = 0
|
| 258 |
+
used_gt = set()
|
| 259 |
+
for d, idx in zip(dist, indices):
|
| 260 |
+
if d <= threshold and idx not in used_gt:
|
| 261 |
+
matches += 1
|
| 262 |
+
used_gt.add(idx)
|
| 263 |
+
|
| 264 |
+
match_rate = matches / max(gt_total, 1)
|
| 265 |
+
|
| 266 |
+
return matches, match_rate, pred_total, gt_total
|
| 267 |
+
|
| 268 |
+
def flatten_coords_4d(coords_4d: torch.Tensor):
|
| 269 |
+
coords_4d_long = coords_4d.long()
|
| 270 |
+
|
| 271 |
+
base_x = 256
|
| 272 |
+
base_y = 256 * 256
|
| 273 |
+
base_z = 256 * 256 * 256
|
| 274 |
+
|
| 275 |
+
flat_coords = coords_4d_long[:, 0] * base_z + \
|
| 276 |
+
coords_4d_long[:, 1] * base_y + \
|
| 277 |
+
coords_4d_long[:, 2] * base_x + \
|
| 278 |
+
coords_4d_long[:, 3]
|
| 279 |
+
return flat_coords
|
| 280 |
+
|
| 281 |
+
def flatten_coords_3d(coords_3d: torch.Tensor):
|
| 282 |
+
coords_3d_long = coords_3d #.long()
|
| 283 |
+
|
| 284 |
+
base_x = 256
|
| 285 |
+
base_y = 256 * 256
|
| 286 |
+
base_z = 256 * 256 * 256
|
| 287 |
+
|
| 288 |
+
flat_coords = coords_3d_long[:, 0] * base_z + \
|
| 289 |
+
coords_3d_long[:, 1] * base_y + \
|
| 290 |
+
coords_3d_long[:, 2] * base_x
|
| 291 |
+
return flat_coords
|
| 292 |
+
|
| 293 |
+
class Tester:
|
| 294 |
+
def __init__(self, ckpt_path, config_path=None, dataset_path=None):
|
| 295 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 296 |
+
self.ckpt_path = ckpt_path
|
| 297 |
+
|
| 298 |
+
self.config = self._load_config(config_path)
|
| 299 |
+
self.dataset_path = dataset_path # or self.config['dataset']['path']
|
| 300 |
+
checkpoint = torch.load(self.ckpt_path, map_location='cpu')
|
| 301 |
+
self.epoch = checkpoint.get('epoch', 0)
|
| 302 |
+
|
| 303 |
+
self._init_models()
|
| 304 |
+
self._init_dataset()
|
| 305 |
+
|
| 306 |
+
self.result_dir = os.path.join(os.path.dirname(ckpt_path), "evaluation_results")
|
| 307 |
+
os.makedirs(self.result_dir, exist_ok=True)
|
| 308 |
+
|
| 309 |
+
dataset_name_clean = os.path.basename(self.dataset_path).replace('.npz', '').replace('.npy', '')
|
| 310 |
+
self.output_voxel_dir = os.path.join(os.path.dirname(ckpt_path),
|
| 311 |
+
f"epoch_{self.epoch}_{dataset_name_clean}_voxels_0_gs")
|
| 312 |
+
os.makedirs(self.output_voxel_dir, exist_ok=True)
|
| 313 |
+
|
| 314 |
+
self.output_obj_dir = os.path.join(os.path.dirname(ckpt_path),
|
| 315 |
+
f"epoch_{self.epoch}_{dataset_name_clean}_obj_0_gs")
|
| 316 |
+
os.makedirs(self.output_obj_dir, exist_ok=True)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def _save_voxel_ply(self, coords: torch.Tensor, labels: torch.Tensor, filename: str):
|
| 320 |
+
if coords.numel() == 0:
|
| 321 |
+
return
|
| 322 |
+
|
| 323 |
+
coords_np = coords.cpu().to(torch.float32).numpy()
|
| 324 |
+
labels_np = labels.cpu().to(torch.float32).numpy()
|
| 325 |
+
|
| 326 |
+
colors = np.zeros((coords_np.shape[0], 3), dtype=np.uint8)
|
| 327 |
+
colors[labels_np == 0] = [255, 0, 0]
|
| 328 |
+
colors[labels_np == 1] = [0, 0, 255]
|
| 329 |
+
|
| 330 |
+
try:
|
| 331 |
+
import trimesh
|
| 332 |
+
point_cloud = trimesh.PointCloud(vertices=coords_np, colors=colors)
|
| 333 |
+
ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
|
| 334 |
+
point_cloud.export(ply_path)
|
| 335 |
+
except ImportError:
|
| 336 |
+
ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
|
| 337 |
+
with open(ply_path, 'w') as f:
|
| 338 |
+
f.write("ply\n")
|
| 339 |
+
f.write("format ascii 1.0\n")
|
| 340 |
+
f.write(f"element vertex {coords_np.shape[0]}\n")
|
| 341 |
+
f.write("property float x\n")
|
| 342 |
+
f.write("property float y\n")
|
| 343 |
+
f.write("property float z\n")
|
| 344 |
+
f.write("property uchar red\n")
|
| 345 |
+
f.write("property uchar green\n")
|
| 346 |
+
f.write("property uchar blue\n")
|
| 347 |
+
f.write("end_header\n")
|
| 348 |
+
for i in range(coords_np.shape[0]):
|
| 349 |
+
f.write(f"{coords_np[i,0]} {coords_np[i,1]} {coords_np[i,2]} {colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
|
| 350 |
+
|
| 351 |
+
def _load_config(self, config_path=None):
|
| 352 |
+
if config_path and os.path.exists(config_path):
|
| 353 |
+
with open(config_path) as f:
|
| 354 |
+
return yaml.safe_load(f)
|
| 355 |
+
|
| 356 |
+
ckpt_dir = os.path.dirname(self.ckpt_path)
|
| 357 |
+
possible_configs = [
|
| 358 |
+
os.path.join(ckpt_dir, "config.yaml"),
|
| 359 |
+
os.path.join(os.path.dirname(ckpt_dir), "config.yaml")
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
+
for config_file in possible_configs:
|
| 363 |
+
if os.path.exists(config_file):
|
| 364 |
+
with open(config_file) as f:
|
| 365 |
+
print(f"Loaded config from: {config_file}")
|
| 366 |
+
return yaml.safe_load(f)
|
| 367 |
+
|
| 368 |
+
checkpoint = torch.load(self.ckpt_path, map_location='cpu')
|
| 369 |
+
if 'config' in checkpoint:
|
| 370 |
+
print("Loaded config from checkpoint")
|
| 371 |
+
return checkpoint['config']
|
| 372 |
+
|
| 373 |
+
raise FileNotFoundError("Could not find config_edge.yaml in checkpoint directory or parent, and config not saved in checkpoint.")
|
| 374 |
+
|
| 375 |
+
def _init_models(self):
|
| 376 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 377 |
+
in_channels=15,
|
| 378 |
+
hidden_dim=256,
|
| 379 |
+
out_channels=1024,
|
| 380 |
+
scatter_type='mean',
|
| 381 |
+
n_blocks=5,
|
| 382 |
+
resolution=128,
|
| 383 |
+
|
| 384 |
+
).to(self.device)
|
| 385 |
+
|
| 386 |
+
self.connection_head = ConnectionHead(
|
| 387 |
+
channels=128 * 2,
|
| 388 |
+
out_channels=1,
|
| 389 |
+
mlp_ratio=4,
|
| 390 |
+
).to(self.device)
|
| 391 |
+
|
| 392 |
+
self.vae = VoxelVAE( # abalation: VoxelVAE_1volume_dilation
|
| 393 |
+
in_channels=self.config['model']['in_channels'],
|
| 394 |
+
latent_dim=self.config['model']['latent_dim'],
|
| 395 |
+
encoder_blocks=self.config['model']['encoder_blocks'],
|
| 396 |
+
# decoder_blocks=self.config['model']['decoder_blocks'],
|
| 397 |
+
decoder_blocks_vtx=self.config['model']['decoder_blocks_vtx'],
|
| 398 |
+
decoder_blocks_edge=self.config['model']['decoder_blocks_edge'],
|
| 399 |
+
num_heads=8,
|
| 400 |
+
num_head_channels=64,
|
| 401 |
+
mlp_ratio=4.0,
|
| 402 |
+
attn_mode="swin",
|
| 403 |
+
window_size=8,
|
| 404 |
+
pe_mode="ape",
|
| 405 |
+
use_fp16=False,
|
| 406 |
+
use_checkpoint=False,
|
| 407 |
+
qk_rms_norm=False,
|
| 408 |
+
using_subdivide=True,
|
| 409 |
+
using_attn=self.config['model']['using_attn'],
|
| 410 |
+
attn_first=self.config['model'].get('attn_first', True),
|
| 411 |
+
pred_direction=self.config['model'].get('pred_direction', False),
|
| 412 |
+
).to(self.device)
|
| 413 |
+
|
| 414 |
+
load_pretrained_woself(
|
| 415 |
+
checkpoint_path=self.ckpt_path,
|
| 416 |
+
voxel_encoder=self.voxel_encoder,
|
| 417 |
+
connection_head=self.connection_head,
|
| 418 |
+
vae=self.vae,
|
| 419 |
+
)
|
| 420 |
+
# --- 【新增】在这里添加权重检查逻辑 ---
|
| 421 |
+
print(f"--- 正在检查权重文件中的 NaN/Inf 值... ---")
|
| 422 |
+
has_nan_inf = False
|
| 423 |
+
if self._check_weights_for_nan_inf(self.vae, "VoxelVAE"):
|
| 424 |
+
has_nan_inf = True
|
| 425 |
+
|
| 426 |
+
if self._check_weights_for_nan_inf(self.voxel_encoder, "Vertex Encoder"):
|
| 427 |
+
has_nan_inf = True
|
| 428 |
+
|
| 429 |
+
if self._check_weights_for_nan_inf(self.connection_head, "Connection Head"):
|
| 430 |
+
has_nan_inf = True
|
| 431 |
+
|
| 432 |
+
if not has_nan_inf:
|
| 433 |
+
print("--- 权重检查通过。未发现 NaN/Inf 值。 ---")
|
| 434 |
+
else:
|
| 435 |
+
# 如果发现坏值,直接抛出异常,因为评估无法继续
|
| 436 |
+
raise ValueError(f"在检查点 '{self.ckpt_path}' 中发现了 NaN 或 Inf 值。请检查导致训练不稳定的权重文件。")
|
| 437 |
+
# --- 检查逻辑结束 ---
|
| 438 |
+
|
| 439 |
+
self.vae.eval()
|
| 440 |
+
self.voxel_encoder.eval()
|
| 441 |
+
self.connection_head.eval()
|
| 442 |
+
|
| 443 |
+
def _init_dataset(self):
|
| 444 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 445 |
+
root_dir=self.dataset_path,
|
| 446 |
+
base_resolution=self.config['dataset']['base_resolution'],
|
| 447 |
+
min_resolution=self.config['dataset']['min_resolution'],
|
| 448 |
+
# cache_dir='./dataset_cache/test_15c_dora',
|
| 449 |
+
cache_dir=self.config['dataset']['cache_dir'],
|
| 450 |
+
renders_dir=self.config['dataset']['renders_dir'],
|
| 451 |
+
|
| 452 |
+
# filter_active_voxels=self.config['dataset']['filter_active_voxels'],
|
| 453 |
+
filter_active_voxels=False,
|
| 454 |
+
cache_filter_path=self.config['dataset']['cache_filter_path'],
|
| 455 |
+
|
| 456 |
+
sample_type=self.config['dataset']['sample_type'],
|
| 457 |
+
active_voxel_res=128,
|
| 458 |
+
pc_sample_number=819200,
|
| 459 |
+
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
self.dataloader = DataLoader(
|
| 463 |
+
self.dataset,
|
| 464 |
+
batch_size=1,
|
| 465 |
+
shuffle=False,
|
| 466 |
+
collate_fn=partial(collate_fn_pointnet),
|
| 467 |
+
num_workers=0,
|
| 468 |
+
pin_memory=True,
|
| 469 |
+
# prefetch_factor=4,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
def _check_weights_for_nan_inf(self, model: torch.nn.Module, model_name: str) -> bool:
|
| 473 |
+
"""
|
| 474 |
+
检查模型的所有参数中是否存在 NaN 或 Inf 值。
|
| 475 |
+
|
| 476 |
+
Args:
|
| 477 |
+
model (torch.nn.Module): 要检查的模型。
|
| 478 |
+
model_name (str): 模型的名称,用于打印日志。
|
| 479 |
+
|
| 480 |
+
Returns:
|
| 481 |
+
bool: 如果找到 NaN 或 Inf,则返回 True,否则返回 False。
|
| 482 |
+
"""
|
| 483 |
+
found_issue = False
|
| 484 |
+
for name, param in model.named_parameters():
|
| 485 |
+
if torch.isnan(param.data).any():
|
| 486 |
+
print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 NaN 值!")
|
| 487 |
+
found_issue = True
|
| 488 |
+
if torch.isinf(param.data).any():
|
| 489 |
+
print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 Inf 值!")
|
| 490 |
+
found_issue = True
|
| 491 |
+
return found_issue
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
|
| 495 |
+
"""
|
| 496 |
+
修改后的函数,确保一对一匹配,并优先匹配最近的点对。
|
| 497 |
+
"""
|
| 498 |
+
pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
|
| 499 |
+
gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
|
| 500 |
+
|
| 501 |
+
pred_total = len(pred_array)
|
| 502 |
+
gt_total = len(gt_array)
|
| 503 |
+
|
| 504 |
+
if pred_total == 0 or gt_total == 0:
|
| 505 |
+
return {
|
| 506 |
+
'recall': 0.0,
|
| 507 |
+
'precision': 0.0,
|
| 508 |
+
'f1': 0.0,
|
| 509 |
+
'matches': 0,
|
| 510 |
+
'pred_count': pred_total,
|
| 511 |
+
'gt_count': gt_total
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
# 依然在预测点上构建KD-Tree,为每个真实点查找最近的预测点
|
| 515 |
+
tree = cKDTree(pred_array)
|
| 516 |
+
dists, pred_idxs = tree.query(gt_array, k=1)
|
| 517 |
+
|
| 518 |
+
# --- 核心修改部分 ---
|
| 519 |
+
|
| 520 |
+
# 1. 创建一个列表,包含 (距离, 真实点索引, 预测点索引)
|
| 521 |
+
# 这样我们就可以按距离对所有可能的匹配进行排序
|
| 522 |
+
possible_matches = []
|
| 523 |
+
for gt_idx, (dist, pred_idx) in enumerate(zip(dists, pred_idxs)):
|
| 524 |
+
if dist <= threshold:
|
| 525 |
+
possible_matches.append((dist, gt_idx, pred_idx))
|
| 526 |
+
|
| 527 |
+
# 2. 按距离从小到大排序(贪心策略)
|
| 528 |
+
possible_matches.sort(key=lambda x: x[0])
|
| 529 |
+
|
| 530 |
+
matches = 0
|
| 531 |
+
# 使用集合来跟踪已经使用过的预测点和真实点,确保一对一匹配
|
| 532 |
+
used_pred_indices = set()
|
| 533 |
+
used_gt_indices = set() # 虽然当前逻辑下gt不会重复,但加上更严谨
|
| 534 |
+
|
| 535 |
+
# 3. 遍历排序后的可能匹配,进行一对一分配
|
| 536 |
+
for dist, gt_idx, pred_idx in possible_matches:
|
| 537 |
+
# 如果这个预测点和这个真实点都还没有被使用过
|
| 538 |
+
if pred_idx not in used_pred_indices and gt_idx not in used_gt_indices:
|
| 539 |
+
matches += 1
|
| 540 |
+
used_pred_indices.add(pred_idx)
|
| 541 |
+
used_gt_indices.add(gt_idx)
|
| 542 |
+
|
| 543 |
+
# --- 修改结束 ---
|
| 544 |
+
|
| 545 |
+
# matches 现在是真正的 True Positives 数量,它绝不会超过 pred_total 或 gt_total
|
| 546 |
+
recall = matches / gt_total if gt_total > 0 else 0.0
|
| 547 |
+
precision = matches / pred_total if pred_total > 0 else 0.0
|
| 548 |
+
|
| 549 |
+
# 计算F1时,使用标准的 Precision 和 Recall 定义
|
| 550 |
+
if (precision + recall) == 0:
|
| 551 |
+
f1 = 0.0
|
| 552 |
+
else:
|
| 553 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 554 |
+
|
| 555 |
+
return {
|
| 556 |
+
'recall': recall,
|
| 557 |
+
'precision': precision,
|
| 558 |
+
'f1': f1,
|
| 559 |
+
'matches': matches,
|
| 560 |
+
'pred_count': pred_total,
|
| 561 |
+
'gt_count': gt_total
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
|
| 565 |
+
"""
|
| 566 |
+
一个折衷的顶点指标计算方案。
|
| 567 |
+
它沿用“为每个真实点寻找最近预测点”的逻辑,
|
| 568 |
+
但通过修正计算方式,确保Precision和F1值不会超过1.0。
|
| 569 |
+
"""
|
| 570 |
+
# 假设 pred_coords 和 gt_coords 是 PyTorch 张量
|
| 571 |
+
pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
|
| 572 |
+
gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
|
| 573 |
+
|
| 574 |
+
pred_total = len(pred_array)
|
| 575 |
+
gt_total = len(gt_array)
|
| 576 |
+
|
| 577 |
+
if pred_total == 0 or gt_total == 0:
|
| 578 |
+
return {
|
| 579 |
+
'recall': 0.0,
|
| 580 |
+
'precision': 0.0,
|
| 581 |
+
'f1': 0.0,
|
| 582 |
+
'matches': 0,
|
| 583 |
+
'pred_count': pred_total,
|
| 584 |
+
'gt_count': gt_total
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
# 在预测点上构建KD-Tree,为每个真实点查找最近的预测点
|
| 588 |
+
tree = cKDTree(pred_array)
|
| 589 |
+
dists, _ = tree.query(gt_array, k=1) # 我们在这里其实不需要 pred 的索引
|
| 590 |
+
|
| 591 |
+
# 1. 计算从 gt 角度出发的匹配数 (True Positives for Recall)
|
| 592 |
+
# 这和您的第一个函数完全一样。
|
| 593 |
+
# 这个值代表了“有多少个真实点被成功找到了”。
|
| 594 |
+
matches_from_gt = np.sum(dists <= threshold)
|
| 595 |
+
|
| 596 |
+
# 2. 计算 Recall (召回率)
|
| 597 |
+
# 召回率的分母是真实点的总数,所以这里的计算是合理的。
|
| 598 |
+
recall = matches_from_gt / gt_total if gt_total > 0 else 0.0
|
| 599 |
+
|
| 600 |
+
# 3. 计算 Precision (精确率) - ✅ 这是核心修正点
|
| 601 |
+
# 精确率的分母是预测点的总数。
|
| 602 |
+
# 分子(True Positives)不能超过预测点的总数。
|
| 603 |
+
# 因此,我们取 matches_from_gt 和 pred_total 中的较小值。
|
| 604 |
+
# 这解决了 Precision > 1 的问题。
|
| 605 |
+
tp_for_precision = min(matches_from_gt, pred_total)
|
| 606 |
+
precision = tp_for_precision / pred_total if pred_total > 0 else 0.0
|
| 607 |
+
|
| 608 |
+
# 4. 使用标准的F1分数公式
|
| 609 |
+
# 您原来的 F1 公式 `2 * matches / (pred + gt)` 是 L1-Score,
|
| 610 |
+
# 更常用的是基于 Precision 和 Recall 的调和平均数。
|
| 611 |
+
if (precision + recall) == 0:
|
| 612 |
+
f1 = 0.0
|
| 613 |
+
else:
|
| 614 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 615 |
+
|
| 616 |
+
return {
|
| 617 |
+
'recall': recall,
|
| 618 |
+
'precision': precision,
|
| 619 |
+
'f1': f1,
|
| 620 |
+
'matches': matches_from_gt, # 仍然报告原始的匹配数,便于观察
|
| 621 |
+
'pred_count': pred_total,
|
| 622 |
+
'gt_count': gt_total
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
def _compute_chamfer_distance(self, p1: torch.Tensor, p2: torch.Tensor, one_sided: bool = False):
|
| 626 |
+
if len(p1) == 0 or len(p2) == 0:
|
| 627 |
+
return float('nan')
|
| 628 |
+
|
| 629 |
+
dist_p1_p2 = torch.min(torch.cdist(p1, p2), dim=1)[0].mean()
|
| 630 |
+
|
| 631 |
+
if one_sided:
|
| 632 |
+
return dist_p1_p2.item()
|
| 633 |
+
else:
|
| 634 |
+
dist_p2_p1 = torch.min(torch.cdist(p2, p1), dim=1)[0].mean()
|
| 635 |
+
return (dist_p1_p2 + dist_p2_p1).item() / 2
|
| 636 |
+
|
| 637 |
+
def visualize_latent_space_pca(self, sample_idx: int):
|
| 638 |
+
"""
|
| 639 |
+
Encodes a sample, performs PCA on its latent features, and saves a
|
| 640 |
+
colored PLY file for visualization.
|
| 641 |
+
|
| 642 |
+
The position of each point in the PLY file corresponds to the spatial
|
| 643 |
+
location in the latent grid.
|
| 644 |
+
|
| 645 |
+
The color of each point represents the first three principal components
|
| 646 |
+
of its feature vector.
|
| 647 |
+
"""
|
| 648 |
+
print(f"--- Starting Latent Space PCA Visualization for Sample {sample_idx} ---")
|
| 649 |
+
self.vae.eval()
|
| 650 |
+
|
| 651 |
+
try:
|
| 652 |
+
# 1. Get the latent representation for the sample
|
| 653 |
+
latent = self._get_latent_for_sample(sample_idx)
|
| 654 |
+
except ValueError as e:
|
| 655 |
+
print(f"Error: {e}")
|
| 656 |
+
return
|
| 657 |
+
|
| 658 |
+
latent_coords = latent.coords.detach().cpu().numpy()
|
| 659 |
+
latent_feats = latent.feats.detach().cpu().numpy()
|
| 660 |
+
|
| 661 |
+
if latent_feats.shape[0] < 3:
|
| 662 |
+
print(f"Warning: Not enough latent points ({latent_feats.shape[0]}) to perform PCA. Skipping.")
|
| 663 |
+
return
|
| 664 |
+
|
| 665 |
+
print(f"--> Performing PCA on {latent_feats.shape[0]} latent vectors of dimension {latent_feats.shape[1]}...")
|
| 666 |
+
|
| 667 |
+
# 2. Perform PCA to reduce feature dimensions to 3
|
| 668 |
+
pca = PCA(n_components=3)
|
| 669 |
+
pca_features = pca.fit_transform(latent_feats)
|
| 670 |
+
|
| 671 |
+
print(f" Explained variance ratio by 3 components: {pca.explained_variance_ratio_}")
|
| 672 |
+
print(f" Total explained variance: {np.sum(pca.explained_variance_ratio_):.4f}")
|
| 673 |
+
|
| 674 |
+
# 3. Normalize the PCA components to be used as RGB colors [0, 255]
|
| 675 |
+
# We normalize each component independently to maximize color contrast
|
| 676 |
+
normalized_colors = np.zeros_like(pca_features)
|
| 677 |
+
for i in range(3):
|
| 678 |
+
min_val = pca_features[:, i].min()
|
| 679 |
+
max_val = pca_features[:, i].max()
|
| 680 |
+
if max_val - min_val > 1e-6:
|
| 681 |
+
normalized_colors[:, i] = (pca_features[:, i] - min_val) / (max_val - min_val)
|
| 682 |
+
else:
|
| 683 |
+
normalized_colors[:, i] = 0.5 # Handle case of constant value
|
| 684 |
+
|
| 685 |
+
colors_uint8 = (normalized_colors * 255).astype(np.uint8)
|
| 686 |
+
|
| 687 |
+
# 4. Prepare spatial coordinates for the point cloud
|
| 688 |
+
# latent_coords is (batch_idx, x, y, z), we want the xyz part
|
| 689 |
+
spatial_coords = latent_coords[:, 1:]
|
| 690 |
+
|
| 691 |
+
# 5. Create and save the colored PLY file
|
| 692 |
+
try:
|
| 693 |
+
# Create a Trimesh PointCloud object
|
| 694 |
+
point_cloud = trimesh.points.PointCloud(vertices=spatial_coords, colors=colors_uint8)
|
| 695 |
+
|
| 696 |
+
# Define the output filename
|
| 697 |
+
filename = f"sample_{sample_idx}_latent_pca.ply"
|
| 698 |
+
ply_path = os.path.join(self.output_voxel_dir, filename)
|
| 699 |
+
|
| 700 |
+
# Export the file
|
| 701 |
+
point_cloud.export(ply_path)
|
| 702 |
+
print(f"--> Successfully saved PCA visualization to: {ply_path}")
|
| 703 |
+
|
| 704 |
+
except Exception as e:
|
| 705 |
+
print(f"Error during Trimesh export: {e}")
|
| 706 |
+
print("Please ensure 'trimesh' is installed correctly.")
|
| 707 |
+
|
| 708 |
+
def _get_latent_for_sample(self, sample_idx: int) -> SparseTensor:
|
| 709 |
+
"""
|
| 710 |
+
Encodes a single sample and returns its latent representation.
|
| 711 |
+
"""
|
| 712 |
+
print(f"--> Encoding sample {sample_idx} to get its latent vector...")
|
| 713 |
+
# Get data for the specified sample
|
| 714 |
+
batch_data = self.dataset[sample_idx]
|
| 715 |
+
if batch_data is None:
|
| 716 |
+
raise ValueError(f"Sample at index {sample_idx} could not be loaded.")
|
| 717 |
+
|
| 718 |
+
# Use the collate function to form a batch
|
| 719 |
+
batch_data = collate_fn_pointnet([batch_data])
|
| 720 |
+
|
| 721 |
+
with torch.no_grad():
|
| 722 |
+
# 1. Get input data and move to device
|
| 723 |
+
gt_vertex_voxels_256 = batch_data['gt_vertex_voxels_256'].to(self.device)
|
| 724 |
+
combined_voxels_256 = batch_data['combined_voxels_256'].to(self.device)
|
| 725 |
+
active_coords = batch_data['active_voxels_128'].to(self.device)
|
| 726 |
+
point_cloud = batch_data['point_cloud_128'].to(self.device)
|
| 727 |
+
|
| 728 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_256, input_resolution=256, output_resolution=128)
|
| 729 |
+
edge_128 = downsample_voxels(combined_voxels_256, input_resolution=256, output_resolution=128)
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
active_voxel_feats = self.voxel_encoder(
|
| 733 |
+
p=point_cloud,
|
| 734 |
+
sparse_coords=active_coords,
|
| 735 |
+
res=128,
|
| 736 |
+
bbox_size=(-0.5, 0.5),
|
| 737 |
+
|
| 738 |
+
# voxel_label=active_labels,
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
sparse_input = SparseTensor(
|
| 742 |
+
feats=active_voxel_feats,
|
| 743 |
+
coords=active_coords.int()
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
# 2. Encode to get the latent representation
|
| 747 |
+
latent_128, posterior = self.vae.encode(sparse_input, sample_posterior=True,)
|
| 748 |
+
print(f" Latent for sample {sample_idx} obtained. Shape: {latent_128.feats.shape}")
|
| 749 |
+
return latent_128
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def evaluate(self, num_samples=None, visualize=False, chamfer_threshold=0.9, threshold=1.):
|
| 753 |
+
total_samples = len(self.dataset)
|
| 754 |
+
eval_samples = min(num_samples or total_samples, total_samples)
|
| 755 |
+
# sample_indices = random.sample(range(total_samples), eval_samples) if num_samples else range(total_samples)
|
| 756 |
+
sample_indices = range(eval_samples)
|
| 757 |
+
|
| 758 |
+
eval_dataset = Subset(self.dataset, sample_indices)
|
| 759 |
+
eval_loader = DataLoader(
|
| 760 |
+
eval_dataset,
|
| 761 |
+
batch_size=1,
|
| 762 |
+
shuffle=False,
|
| 763 |
+
collate_fn=partial(collate_fn_pointnet),
|
| 764 |
+
num_workers=self.config['training']['num_workers'],
|
| 765 |
+
pin_memory=True,
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
per_sample_metrics = {
|
| 769 |
+
'vertex': {res: [] for res in [128, 256]},
|
| 770 |
+
'edge': {res: [] for res in [128, 256]},
|
| 771 |
+
'sample_names': []
|
| 772 |
+
}
|
| 773 |
+
avg_metrics = {
|
| 774 |
+
'vertex': {res: defaultdict(list) for res in [128, 256]},
|
| 775 |
+
'edge': {res: defaultdict(list) for res in [128, 256]},
|
| 776 |
+
}
|
| 777 |
+
|
| 778 |
+
self.vae.eval()
|
| 779 |
+
|
| 780 |
+
for batch_idx, batch_data in enumerate(tqdm(eval_loader, desc="Evaluating")):
|
| 781 |
+
if batch_data is None:
|
| 782 |
+
continue
|
| 783 |
+
sample_idx = sample_indices[batch_idx]
|
| 784 |
+
sample_name = f'sample_{sample_idx}'
|
| 785 |
+
per_sample_metrics['sample_names'].append(sample_name)
|
| 786 |
+
|
| 787 |
+
# batch_save_path = f"/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/215000_sample_active_vis_42seed_1000complex/gt_data_batch_{batch_idx}.pt"
|
| 788 |
+
# if not os.path.exists(batch_save_path):
|
| 789 |
+
# print(f"Warning: Saved batch file not found: {batch_save_path}")
|
| 790 |
+
# continue
|
| 791 |
+
# batch_data = torch.load(batch_save_path, map_location=self.device)
|
| 792 |
+
|
| 793 |
+
with torch.no_grad():
|
| 794 |
+
# 1. Get input data
|
| 795 |
+
combined_voxels_256 = batch_data['combined_voxels_256'].to(self.device)
|
| 796 |
+
combined_voxel_labels_256 = batch_data['combined_voxel_labels_256'].to(self.device)
|
| 797 |
+
gt_combined_endpoints_256 = batch_data['gt_combined_endpoints_256'].to(self.device)
|
| 798 |
+
gt_combined_errors_256 = batch_data['gt_combined_errors_256'].to(self.device)
|
| 799 |
+
|
| 800 |
+
edge_mask = (combined_voxel_labels_256 == 1)
|
| 801 |
+
|
| 802 |
+
gt_edge_endpoints_256 = gt_combined_endpoints_256[edge_mask].to(self.device)
|
| 803 |
+
gt_edge_errors_256 = gt_combined_errors_256[edge_mask].to(self.device)
|
| 804 |
+
gt_edge_voxels_256 = combined_voxels_256[edge_mask].to(self.device)
|
| 805 |
+
|
| 806 |
+
p1 = gt_edge_endpoints_256[:, 1:4].float()
|
| 807 |
+
p2 = gt_edge_endpoints_256[:, 4:7].float()
|
| 808 |
+
|
| 809 |
+
mask = ( (p1[:,0] < p2[:,0]) |
|
| 810 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
|
| 811 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
|
| 812 |
+
|
| 813 |
+
pA = torch.where(mask[:, None], p1, p2) # smaller one
|
| 814 |
+
pB = torch.where(mask[:, None], p2, p1) # larger one
|
| 815 |
+
|
| 816 |
+
d = pB - pA
|
| 817 |
+
dir_gt = F.normalize(d, dim=-1, eps=1e-6)
|
| 818 |
+
|
| 819 |
+
gt_vertex_voxels_256 = batch_data['gt_vertex_voxels_256'].to(self.device).int()
|
| 820 |
+
|
| 821 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_256, input_resolution=256, output_resolution=128)
|
| 822 |
+
|
| 823 |
+
edge_128 = downsample_voxels(combined_voxels_256, input_resolution=256, output_resolution=128)
|
| 824 |
+
edge_256 = combined_voxels_256
|
| 825 |
+
|
| 826 |
+
print('vtx_128.shape', vtx_128.shape)
|
| 827 |
+
print('edge_128.shape', edge_128.shape)
|
| 828 |
+
|
| 829 |
+
gt_edge_voxels_list = [
|
| 830 |
+
edge_128,
|
| 831 |
+
edge_256,
|
| 832 |
+
]
|
| 833 |
+
|
| 834 |
+
active_coords = batch_data['active_voxels_128'].to(self.device)
|
| 835 |
+
point_cloud = batch_data['point_cloud_128'].to(self.device)
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
active_voxel_feats = self.voxel_encoder(
|
| 840 |
+
p=point_cloud,
|
| 841 |
+
sparse_coords=active_coords,
|
| 842 |
+
res=128,
|
| 843 |
+
bbox_size=(-0.5, 0.5),
|
| 844 |
+
|
| 845 |
+
# voxel_label=active_labels,
|
| 846 |
+
)
|
| 847 |
+
|
| 848 |
+
sparse_input = SparseTensor(
|
| 849 |
+
feats=active_voxel_feats,
|
| 850 |
+
coords=active_coords.int()
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
latent_128, posterior = self.vae.encode(sparse_input)
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
# load_path = f'/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/215000_sample_active_vis_42seed_1000complex/sample_latent_{batch_idx}.pt'
|
| 857 |
+
# latent_128 = torch.load(load_path, map_location=self.device)
|
| 858 |
+
|
| 859 |
+
print('latent_128.feats.mean()', latent_128.feats.mean(), 'latent_128.feats.std()', latent_128.feats.std())
|
| 860 |
+
print('posterior.mean', posterior.mean.mean(), 'posterior.std', posterior.std.mean(), 'posterior.var', posterior.var.mean())
|
| 861 |
+
print('latent_128.coords.shape', latent_128.coords.shape)
|
| 862 |
+
|
| 863 |
+
# latent_128 = torch.load(f"/root/Trisf/output_slat_flow_matching/ckpts/1100_chair_sample/110000step_sample/sample_results_samples_{batch_idx}.pt", map_location=self.device)
|
| 864 |
+
latent_128 = SparseTensor(
|
| 865 |
+
coords=latent_128.coords,
|
| 866 |
+
feats=latent_128.feats + 0. * torch.randn_like(latent_128.feats),
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
# self.output_voxel_dir = os.path.dirname(load_path)
|
| 870 |
+
# self.output_obj_dir = os.path.dirname(load_path)
|
| 871 |
+
|
| 872 |
+
# 7. Decoding with separate vertex and edge processing
|
| 873 |
+
decoded_results = self.vae.decode(
|
| 874 |
+
latent_128,
|
| 875 |
+
gt_vertex_voxels_list=[],
|
| 876 |
+
gt_edge_voxels_list=[],
|
| 877 |
+
training=False,
|
| 878 |
+
|
| 879 |
+
inference_threshold=0.5,
|
| 880 |
+
vis_last_layer=False,
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
error = 0 # decoded_results[-1]['edge']['predicted_offset_feats']
|
| 884 |
+
|
| 885 |
+
if self.config['model'].get('pred_direction', False):
|
| 886 |
+
pred_dir = decoded_results[-1]['edge']['predicted_direction_feats']
|
| 887 |
+
zero_mask = (pred_dir == 0).all(dim=1) # [N],True 表示这一行全为0
|
| 888 |
+
num_zeros = zero_mask.sum().item()
|
| 889 |
+
print("Number of zero vectors:", num_zeros)
|
| 890 |
+
|
| 891 |
+
pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
|
| 892 |
+
print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
|
| 893 |
+
print('pred_dir.shape', pred_dir.shape)
|
| 894 |
+
if pred_edge_coords_3d.shape[-1] == 4:
|
| 895 |
+
pred_edge_coords_3d = pred_edge_coords_3d[:, 1:]
|
| 896 |
+
# visualize_directions(pred_edge_coords_3d, pred_dir, sample_ratio=0.02)
|
| 897 |
+
save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction.ply")
|
| 898 |
+
# visualize_colored_points_ply(pred_edge_coords_3d - error / 2. + 0.5, pred_dir, save_pth)
|
| 899 |
+
visualize_colored_points_ply(pred_edge_coords_3d, pred_dir, save_pth)
|
| 900 |
+
|
| 901 |
+
save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction_gt.ply")
|
| 902 |
+
# visualize_colored_points_ply((gt_edge_voxels_256[:, 1:] - gt_edge_errors_256[:, 1:] + 0.5), dir_gt, save_pth)
|
| 903 |
+
visualize_colored_points_ply((gt_edge_voxels_256[:, 1:]), dir_gt, save_pth)
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
pred_vtx_coords_3d = decoded_results[-1]['vertex']['coords']
|
| 907 |
+
pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
|
| 908 |
+
|
| 909 |
+
pred_edge_coords_np = np.round(pred_edge_coords_3d.cpu().numpy()).astype(int)
|
| 910 |
+
|
| 911 |
+
gt_vertex_voxels_256 = batch_data['gt_vertex_voxels_256'][:, 1:].to(self.device)
|
| 912 |
+
gt_edge_voxels_256 = batch_data['gt_edge_voxels_256'][:, 1:].to(self.device)
|
| 913 |
+
gt_edge_coords_np = np.round(gt_edge_voxels_256.cpu().numpy()).astype(int)
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
# Calculate metrics and save results
|
| 917 |
+
matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_vtx_coords_3d, gt_vertex_voxels_256, threshold=threshold,)
|
| 918 |
+
print(f"\n----- Resolution {256} vtx -----")
|
| 919 |
+
print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
|
| 920 |
+
print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
|
| 921 |
+
|
| 922 |
+
self._save_voxel_ply(pred_vtx_coords_3d / 256., torch.zeros(len(pred_vtx_coords_3d)), f"{sample_name}_pred_vtx")
|
| 923 |
+
self._save_voxel_ply((pred_edge_coords_3d - error / 2. + 0.5) / 256, torch.zeros(len(pred_edge_coords_3d)), f"{sample_name}_pred_edge")
|
| 924 |
+
|
| 925 |
+
self._save_voxel_ply(gt_vertex_voxels_256 / 256, torch.zeros(len(gt_vertex_voxels_256)), f"{sample_name}_gt_vertex")
|
| 926 |
+
self._save_voxel_ply((combined_voxels_256[:, 1:] - gt_combined_errors_256[:, 1:] + 0.5) / 256., torch.zeros(len(gt_combined_errors_256)), f"{sample_name}_gt_edge")
|
| 927 |
+
|
| 928 |
+
|
| 929 |
+
# Calculate vertex-specific metrics
|
| 930 |
+
matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_edge_coords_3d, combined_voxels_256[:, 1:], threshold=threshold,)
|
| 931 |
+
print(f"\n----- Resolution {256} edge -----")
|
| 932 |
+
print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
|
| 933 |
+
print('gt_edge_voxels_256.shape', gt_edge_voxels_256.shape)
|
| 934 |
+
|
| 935 |
+
print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
|
| 936 |
+
print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
|
| 937 |
+
|
| 938 |
+
pred_vertex_coords_np = np.round(pred_vtx_coords_3d.cpu().numpy()).astype(int)
|
| 939 |
+
pred_edges = []
|
| 940 |
+
gt_vertex_coords_np = np.round(gt_vertex_voxels_256.cpu().numpy()).astype(int)
|
| 941 |
+
if visualize:
|
| 942 |
+
if pred_vtx_coords_3d.shape[-1] == 4:
|
| 943 |
+
pred_vtx_coords_float = pred_vtx_coords_3d[:, 1:].float()
|
| 944 |
+
else:
|
| 945 |
+
pred_vtx_coords_float = pred_vtx_coords_3d.float()
|
| 946 |
+
|
| 947 |
+
pred_vtx_feats = decoded_results[-1]['vertex']['feats']
|
| 948 |
+
|
| 949 |
+
# ==========================================
|
| 950 |
+
# Link Prediction & Mesh Generation
|
| 951 |
+
# ==========================================
|
| 952 |
+
print("Predicting connectivity...")
|
| 953 |
+
|
| 954 |
+
# 1. 预测边
|
| 955 |
+
# 注意:K_neighbors 的设置。如果是物体,64 足够了。
|
| 956 |
+
# 如果点非常稀疏,可能需要更大。
|
| 957 |
+
pred_edges = predict_mesh_connectivity(
|
| 958 |
+
connection_head=self.connection_head, # 或者是 self.connection_head,取决于你在哪里定义的
|
| 959 |
+
vtx_feats=pred_vtx_feats,
|
| 960 |
+
vtx_coords=pred_vtx_coords_float,
|
| 961 |
+
batch_size=4096,
|
| 962 |
+
threshold=0.5,
|
| 963 |
+
k_neighbors=None,
|
| 964 |
+
device=self.device
|
| 965 |
+
)
|
| 966 |
+
print(f"Predicted {len(pred_edges)} edges.")
|
| 967 |
+
|
| 968 |
+
# 2. 构建三角形
|
| 969 |
+
num_verts = pred_vtx_coords_float.shape[0]
|
| 970 |
+
pred_faces = build_triangles_from_edges(pred_edges, num_verts)
|
| 971 |
+
print(f"Constructed {len(pred_faces)} triangles.")
|
| 972 |
+
|
| 973 |
+
# 3. 保存 OBJ
|
| 974 |
+
import trimesh
|
| 975 |
+
|
| 976 |
+
# 坐标归一化/还原 (根据你的需求,这里假设你是 0-256 的体素坐标)
|
| 977 |
+
# 如果想保存为归一化坐标:
|
| 978 |
+
mesh_verts = pred_vtx_coords_float.cpu().numpy() / 256.0
|
| 979 |
+
|
| 980 |
+
# 如果有 error offset,记得加上!
|
| 981 |
+
# 你之前的代码好像没有对 vertex 加 offset,只对 edge 加了
|
| 982 |
+
# 如果 vertex 也有 offset (如 dual contouring),在这里加上
|
| 983 |
+
|
| 984 |
+
# 移动到中心 (可选)
|
| 985 |
+
mesh_verts = mesh_verts - 0.5
|
| 986 |
+
|
| 987 |
+
mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
|
| 988 |
+
|
| 989 |
+
# 过滤孤立点 (可选)
|
| 990 |
+
# mesh.remove_unreferenced_vertices()
|
| 991 |
+
|
| 992 |
+
output_obj_path = os.path.join(self.output_voxel_dir, f"{sample_name}_recon.obj")
|
| 993 |
+
mesh.export(output_obj_path)
|
| 994 |
+
print(f"Saved mesh to {output_obj_path}")
|
| 995 |
+
|
| 996 |
+
# 保存边线 (用于 Debug)
|
| 997 |
+
# 有时候三角形很难形成,只看边也很有用
|
| 998 |
+
edges_path = os.path.join(self.output_voxel_dir, f"{sample_name}_edges.ply")
|
| 999 |
+
# self._visualize_vertices(pred_edge_coords_np, gt_edge_coords_np, f"{sample_name}_edge_comparison")
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
# Process results at different resolutions
|
| 1003 |
+
for i, res in enumerate([128, 256]):
|
| 1004 |
+
if i >= len(decoded_results):
|
| 1005 |
+
continue
|
| 1006 |
+
|
| 1007 |
+
gt_key = f'gt_vertex_voxels_{res}'
|
| 1008 |
+
if gt_key not in batch_data:
|
| 1009 |
+
continue
|
| 1010 |
+
if i == 0:
|
| 1011 |
+
pred_coords_res = decoded_results[i]['vtx_sp'].coords[:, 1:].float()
|
| 1012 |
+
gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
|
| 1013 |
+
else:
|
| 1014 |
+
pred_coords_res = decoded_results[i]['vertex']['coords'].float()
|
| 1015 |
+
gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
v_metrics = self._compute_vertex_metrics(pred_coords_res, gt_coords_res, threshold=threshold)
|
| 1019 |
+
|
| 1020 |
+
per_sample_metrics['vertex'][res].append({
|
| 1021 |
+
'recall': v_metrics['recall'],
|
| 1022 |
+
'precision': v_metrics['precision'],
|
| 1023 |
+
'f1': v_metrics['f1'],
|
| 1024 |
+
'num_pred': len(pred_coords_res),
|
| 1025 |
+
'num_gt': len(gt_coords_res)
|
| 1026 |
+
})
|
| 1027 |
+
|
| 1028 |
+
avg_metrics['vertex'][res]['recall'].append(v_metrics['recall'])
|
| 1029 |
+
avg_metrics['vertex'][res]['precision'].append(v_metrics['precision'])
|
| 1030 |
+
avg_metrics['vertex'][res]['f1'].append(v_metrics['f1'])
|
| 1031 |
+
|
| 1032 |
+
gt_edge_key = f'gt_edge_voxels_{res}'
|
| 1033 |
+
if gt_edge_key not in batch_data:
|
| 1034 |
+
continue
|
| 1035 |
+
|
| 1036 |
+
if i == 0:
|
| 1037 |
+
pred_edge_coords_res = decoded_results[i]['edge_sp'].coords[:, 1:].float()
|
| 1038 |
+
# gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
|
| 1039 |
+
idx = i
|
| 1040 |
+
gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
|
| 1041 |
+
elif i == 1:
|
| 1042 |
+
idx = i
|
| 1043 |
+
#################################
|
| 1044 |
+
# pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() - error / 2. + 0.5
|
| 1045 |
+
# # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
|
| 1046 |
+
# gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) - gt_combined_errors_256[:, 1:].to(self.device) + 0.5
|
| 1047 |
+
|
| 1048 |
+
|
| 1049 |
+
pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
|
| 1050 |
+
gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
|
| 1051 |
+
|
| 1052 |
+
# self._save_voxel_ply(gt_edge_voxels_list[idx][:, 1:].float().to(self.device) / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res_wooffset")
|
| 1053 |
+
# self._save_voxel_ply(decoded_results[i]['edge']['coords'].float() / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res_wooffset")
|
| 1054 |
+
|
| 1055 |
+
else:
|
| 1056 |
+
idx = i
|
| 1057 |
+
pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
|
| 1058 |
+
# gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
|
| 1059 |
+
gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
|
| 1060 |
+
|
| 1061 |
+
# self._save_voxel_ply(gt_edge_coords_res / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res")
|
| 1062 |
+
# self._save_voxel_ply(pred_edge_coords_res / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res")
|
| 1063 |
+
|
| 1064 |
+
e_metrics = self._compute_vertex_metrics(pred_edge_coords_res, gt_edge_coords_res, threshold=threshold)
|
| 1065 |
+
|
| 1066 |
+
per_sample_metrics['edge'][res].append({
|
| 1067 |
+
'recall': e_metrics['recall'],
|
| 1068 |
+
'precision': e_metrics['precision'],
|
| 1069 |
+
'f1': e_metrics['f1'],
|
| 1070 |
+
'num_pred': len(pred_edge_coords_res),
|
| 1071 |
+
'num_gt': len(gt_edge_coords_res)
|
| 1072 |
+
})
|
| 1073 |
+
|
| 1074 |
+
avg_metrics['edge'][res]['recall'].append(e_metrics['recall'])
|
| 1075 |
+
avg_metrics['edge'][res]['precision'].append(e_metrics['precision'])
|
| 1076 |
+
avg_metrics['edge'][res]['f1'].append(e_metrics['f1'])
|
| 1077 |
+
|
| 1078 |
+
avg_metrics_processed = {}
|
| 1079 |
+
for category, res_dict in avg_metrics.items():
|
| 1080 |
+
avg_metrics_processed[category] = {}
|
| 1081 |
+
for resolution, metric_dict in res_dict.items():
|
| 1082 |
+
avg_metrics_processed[category][resolution] = {
|
| 1083 |
+
metric_name: np.mean(values) if values else float('nan')
|
| 1084 |
+
for metric_name, values in metric_dict.items()
|
| 1085 |
+
}
|
| 1086 |
+
|
| 1087 |
+
result_data = {
|
| 1088 |
+
'config': self.config,
|
| 1089 |
+
'checkpoint': self.ckpt_path,
|
| 1090 |
+
'dataset': self.dataset_path,
|
| 1091 |
+
'num_samples': eval_samples,
|
| 1092 |
+
'per_sample_metrics': per_sample_metrics,
|
| 1093 |
+
'avg_metrics': avg_metrics_processed
|
| 1094 |
+
}
|
| 1095 |
+
|
| 1096 |
+
results_file_path = os.path.join(self.result_dir, f"evaluation_results_epoch{self.epoch}.yaml")
|
| 1097 |
+
with open(results_file_path, 'w') as f:
|
| 1098 |
+
yaml.dump(result_data, f, default_flow_style=False)
|
| 1099 |
+
|
| 1100 |
+
return result_data
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
def _generate_line_voxels(
|
| 1105 |
+
self,
|
| 1106 |
+
p1: torch.Tensor,
|
| 1107 |
+
p2: torch.Tensor
|
| 1108 |
+
) -> Tuple[
|
| 1109 |
+
List[Tuple[int, int, int]],
|
| 1110 |
+
List[Tuple[torch.Tensor, torch.Tensor]],
|
| 1111 |
+
List[np.ndarray]
|
| 1112 |
+
]:
|
| 1113 |
+
"""
|
| 1114 |
+
Improved version using better sampling strategy
|
| 1115 |
+
"""
|
| 1116 |
+
p1_np = p1 #.cpu().numpy()
|
| 1117 |
+
p2_np = p2 #.cpu().numpy()
|
| 1118 |
+
voxel_dict = OrderedDict()
|
| 1119 |
+
|
| 1120 |
+
# Use proper 3D line voxelization algorithm
|
| 1121 |
+
def bresenham_3d(p1, p2):
|
| 1122 |
+
"""3D Bresenham's line algorithm"""
|
| 1123 |
+
x1, y1, z1 = np.round(p1).astype(int)
|
| 1124 |
+
x2, y2, z2 = np.round(p2).astype(int)
|
| 1125 |
+
|
| 1126 |
+
points = []
|
| 1127 |
+
dx = abs(x2 - x1)
|
| 1128 |
+
dy = abs(y2 - y1)
|
| 1129 |
+
dz = abs(z2 - z1)
|
| 1130 |
+
|
| 1131 |
+
xs = 1 if x2 > x1 else -1
|
| 1132 |
+
ys = 1 if y2 > y1 else -1
|
| 1133 |
+
zs = 1 if z2 > z1 else -1
|
| 1134 |
+
|
| 1135 |
+
# Driving axis is X
|
| 1136 |
+
if dx >= dy and dx >= dz:
|
| 1137 |
+
err_1 = 2 * dy - dx
|
| 1138 |
+
err_2 = 2 * dz - dx
|
| 1139 |
+
for i in range(dx + 1):
|
| 1140 |
+
points.append((x1, y1, z1))
|
| 1141 |
+
if err_1 > 0:
|
| 1142 |
+
y1 += ys
|
| 1143 |
+
err_1 -= 2 * dx
|
| 1144 |
+
if err_2 > 0:
|
| 1145 |
+
z1 += zs
|
| 1146 |
+
err_2 -= 2 * dx
|
| 1147 |
+
err_1 += 2 * dy
|
| 1148 |
+
err_2 += 2 * dz
|
| 1149 |
+
x1 += xs
|
| 1150 |
+
|
| 1151 |
+
# Driving axis is Y
|
| 1152 |
+
elif dy >= dx and dy >= dz:
|
| 1153 |
+
err_1 = 2 * dx - dy
|
| 1154 |
+
err_2 = 2 * dz - dy
|
| 1155 |
+
for i in range(dy + 1):
|
| 1156 |
+
points.append((x1, y1, z1))
|
| 1157 |
+
if err_1 > 0:
|
| 1158 |
+
x1 += xs
|
| 1159 |
+
err_1 -= 2 * dy
|
| 1160 |
+
if err_2 > 0:
|
| 1161 |
+
z1 += zs
|
| 1162 |
+
err_2 -= 2 * dy
|
| 1163 |
+
err_1 += 2 * dx
|
| 1164 |
+
err_2 += 2 * dz
|
| 1165 |
+
y1 += ys
|
| 1166 |
+
|
| 1167 |
+
# Driving axis is Z
|
| 1168 |
+
else:
|
| 1169 |
+
err_1 = 2 * dx - dz
|
| 1170 |
+
err_2 = 2 * dy - dz
|
| 1171 |
+
for i in range(dz + 1):
|
| 1172 |
+
points.append((x1, y1, z1))
|
| 1173 |
+
if err_1 > 0:
|
| 1174 |
+
x1 += xs
|
| 1175 |
+
err_1 -= 2 * dz
|
| 1176 |
+
if err_2 > 0:
|
| 1177 |
+
y1 += ys
|
| 1178 |
+
err_2 -= 2 * dz
|
| 1179 |
+
err_1 += 2 * dx
|
| 1180 |
+
err_2 += 2 * dy
|
| 1181 |
+
z1 += zs
|
| 1182 |
+
|
| 1183 |
+
return points
|
| 1184 |
+
|
| 1185 |
+
# Get all voxels using Bresenham algorithm
|
| 1186 |
+
voxel_coords = bresenham_3d(p1_np, p2_np)
|
| 1187 |
+
|
| 1188 |
+
# Add all voxels to dictionary
|
| 1189 |
+
for coord in voxel_coords:
|
| 1190 |
+
voxel_dict[tuple(coord)] = (p1, p2)
|
| 1191 |
+
|
| 1192 |
+
voxel_coords = list(voxel_dict.keys())
|
| 1193 |
+
endpoint_pairs = list(voxel_dict.values())
|
| 1194 |
+
|
| 1195 |
+
# --- compute error vectors ---
|
| 1196 |
+
error_vectors = []
|
| 1197 |
+
diff = p2_np - p1_np
|
| 1198 |
+
d_norm_sq = np.dot(diff, diff)
|
| 1199 |
+
|
| 1200 |
+
for v in voxel_coords:
|
| 1201 |
+
v_center = np.array(v, dtype=float) + 0.5
|
| 1202 |
+
if d_norm_sq == 0: # degenerate line
|
| 1203 |
+
closest = p1_np
|
| 1204 |
+
else:
|
| 1205 |
+
t = np.dot(v_center - p1_np, diff) / d_norm_sq
|
| 1206 |
+
t = np.clip(t, 0.0, 1.0)
|
| 1207 |
+
closest = p1_np + t * diff
|
| 1208 |
+
error_vectors.append(v_center - closest)
|
| 1209 |
+
|
| 1210 |
+
return voxel_coords, endpoint_pairs, error_vectors
|
| 1211 |
+
|
| 1212 |
+
|
| 1213 |
+
|
| 1214 |
+
# 使用示例
|
| 1215 |
+
def set_seed(seed: int):
|
| 1216 |
+
random.seed(seed)
|
| 1217 |
+
np.random.seed(seed)
|
| 1218 |
+
torch.manual_seed(seed)
|
| 1219 |
+
if torch.cuda.is_available():
|
| 1220 |
+
torch.cuda.manual_seed(seed)
|
| 1221 |
+
torch.cuda.manual_seed_all(seed)
|
| 1222 |
+
torch.backends.cudnn.deterministic = True
|
| 1223 |
+
torch.backends.cudnn.benchmark = False
|
| 1224 |
+
|
| 1225 |
+
def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
|
| 1226 |
+
set_seed(42)
|
| 1227 |
+
tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
|
| 1228 |
+
result_data = tester.evaluate(num_samples=NUM_SAMPLES, visualize=VISUALIZE, chamfer_threshold=CHAMFER_EDGE_THRESHOLD, threshold=THRESHOLD)
|
| 1229 |
+
|
| 1230 |
+
# 生成文件名
|
| 1231 |
+
epoch_str = os.path.basename(ckpt_path).split('_')[1].split('.')[0]
|
| 1232 |
+
dataset_name = os.path.basename(os.path.normpath(dataset_path))
|
| 1233 |
+
|
| 1234 |
+
# 保存简版报告(TXT)
|
| 1235 |
+
summary_path = os.path.join(eval_dir, f"epoch{epoch_str}_{dataset_name}_summary_threshold{THRESHOLD}_one2one.txt")
|
| 1236 |
+
with open(summary_path, 'w') as f:
|
| 1237 |
+
# 头部信息
|
| 1238 |
+
f.write(f"Checkpoint: {os.path.basename(ckpt_path)}\n")
|
| 1239 |
+
f.write(f"Dataset: {dataset_name}\n")
|
| 1240 |
+
f.write(f"Evaluation Samples: {result_data['num_samples']}\n\n")
|
| 1241 |
+
|
| 1242 |
+
# 平均指标
|
| 1243 |
+
f.write("=== Average Metrics ===\n")
|
| 1244 |
+
for category, data in result_data['avg_metrics'].items():
|
| 1245 |
+
if isinstance(data, dict): # 处理多分辨率情况
|
| 1246 |
+
f.write(f"\n{category.upper()}:\n")
|
| 1247 |
+
for res, metrics in data.items():
|
| 1248 |
+
f.write(f" Resolution {res}:\n")
|
| 1249 |
+
for k, v in metrics.items():
|
| 1250 |
+
# 确保值是数字类型后再格式化
|
| 1251 |
+
if isinstance(v, (int, float)):
|
| 1252 |
+
f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
|
| 1253 |
+
else:
|
| 1254 |
+
f.write(f" {str(k).ljust(15)}: {str(v)}\n")
|
| 1255 |
+
else: # 处理非多分辨率情况
|
| 1256 |
+
f.write(f"\n{category.upper()}:\n")
|
| 1257 |
+
for k, v in data.items():
|
| 1258 |
+
if isinstance(v, (int, float)):
|
| 1259 |
+
f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
|
| 1260 |
+
else:
|
| 1261 |
+
f.write(f" {str(k).ljust(15)}: {str(v)}\n")
|
| 1262 |
+
|
| 1263 |
+
# 样本级详细统计
|
| 1264 |
+
f.write("\n\n=== Detailed Per-Sample Metrics ===\n")
|
| 1265 |
+
for name, vertex_metrics, edge_metrics in zip(
|
| 1266 |
+
result_data['per_sample_metrics']['sample_names'],
|
| 1267 |
+
zip(*[result_data['per_sample_metrics']['vertex'][res] for res in [128, 256]]),
|
| 1268 |
+
zip(*[result_data['per_sample_metrics']['edge'][res] for res in [128, 256]])
|
| 1269 |
+
):
|
| 1270 |
+
# 样本标题
|
| 1271 |
+
f.write(f"\n◆ Sample: {name}\n")
|
| 1272 |
+
|
| 1273 |
+
# 顶点指标
|
| 1274 |
+
f.write(f"[Vertex Prediction]\n")
|
| 1275 |
+
f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
|
| 1276 |
+
for res, metrics in zip([128, 256], vertex_metrics):
|
| 1277 |
+
f.write(f" {str(res).ljust(10)} "
|
| 1278 |
+
f"{metrics['recall']:.4f} "
|
| 1279 |
+
f"{metrics['precision']:.4f} "
|
| 1280 |
+
f"{metrics['f1']:.4f} "
|
| 1281 |
+
f"{metrics['num_pred']}/{metrics['num_gt']}\n")
|
| 1282 |
+
|
| 1283 |
+
# Edge指标
|
| 1284 |
+
f.write(f"[Edge Prediction]\n")
|
| 1285 |
+
f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
|
| 1286 |
+
for res, metrics in zip([128, 256], edge_metrics):
|
| 1287 |
+
f.write(f" {str(res).ljust(10)} "
|
| 1288 |
+
f"{metrics['recall']:.4f} "
|
| 1289 |
+
f"{metrics['precision']:.4f} "
|
| 1290 |
+
f"{metrics['f1']:.4f} "
|
| 1291 |
+
f"{metrics['num_pred']}/{metrics['num_gt']}\n")
|
| 1292 |
+
|
| 1293 |
+
f.write("-"*60 + "\n")
|
| 1294 |
+
|
| 1295 |
+
print(f"Saved summary to: {summary_path}")
|
| 1296 |
+
return result_data
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
if __name__ == '__main__':
|
| 1300 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 1301 |
+
evaluate_all_checkpoints = True # 设置为 True 启用范围过滤
|
| 1302 |
+
EPOCH_START = 12
|
| 1303 |
+
EPOCH_END = 12
|
| 1304 |
+
CHAMFER_EDGE_THRESHOLD=0.5
|
| 1305 |
+
NUM_SAMPLES=50
|
| 1306 |
+
VISUALIZE=True
|
| 1307 |
+
THRESHOLD=1.5
|
| 1308 |
+
VISUAL_FIELD=False
|
| 1309 |
+
|
| 1310 |
+
ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to256_dir_sorted_dora_head_small/checkpoint_epoch13_batch6000_loss0.1381.pt'
|
| 1311 |
+
dataset_path = '/gemini/user/private/zhaotianhao/data/test_mesh'
|
| 1312 |
+
|
| 1313 |
+
if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
|
| 1314 |
+
RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
|
| 1315 |
+
else:
|
| 1316 |
+
RENDERS_DIR = ''
|
| 1317 |
+
|
| 1318 |
+
|
| 1319 |
+
ckpt_dir = os.path.dirname(ckpt_path)
|
| 1320 |
+
eval_dir = os.path.join(ckpt_dir, "evaluate")
|
| 1321 |
+
os.makedirs(eval_dir, exist_ok=True)
|
| 1322 |
+
|
| 1323 |
+
if False:
|
| 1324 |
+
for i in range(NUM_SAMPLES):
|
| 1325 |
+
print("--- Starting Latent Space PCA Visualization ---")
|
| 1326 |
+
tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
|
| 1327 |
+
tester.visualize_latent_space_pca(sample_idx=i)
|
| 1328 |
+
print("--- PCA Visualization Finished ---")
|
| 1329 |
+
|
| 1330 |
+
if not evaluate_all_checkpoints:
|
| 1331 |
+
evaluate_checkpoint(ckpt_path, dataset_path, eval_dir)
|
| 1332 |
+
else:
|
| 1333 |
+
pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.endswith('.pt')])
|
| 1334 |
+
|
| 1335 |
+
filtered_pt_files = []
|
| 1336 |
+
for f in pt_files:
|
| 1337 |
+
try:
|
| 1338 |
+
parts = f.split('_')
|
| 1339 |
+
epoch_str = parts[1].replace('epoch', '')
|
| 1340 |
+
epoch = int(epoch_str)
|
| 1341 |
+
if EPOCH_START <= epoch <= EPOCH_END:
|
| 1342 |
+
filtered_pt_files.append(f)
|
| 1343 |
+
except Exception as e:
|
| 1344 |
+
print(f"Warning: Could not parse epoch from {f}: {e}")
|
| 1345 |
+
continue
|
| 1346 |
+
|
| 1347 |
+
for pt_file in filtered_pt_files:
|
| 1348 |
+
full_ckpt_path = os.path.join(ckpt_dir, pt_file)
|
| 1349 |
+
evaluate_checkpoint(full_ckpt_path, dataset_path, eval_dir)
|
test_slat_vae_128to512_pointnet_vae_head.py
ADDED
|
@@ -0,0 +1,1636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
import plotly.graph_objects as go
|
| 9 |
+
from plotly.subplots import make_subplots
|
| 10 |
+
from torch.utils.data import DataLoader, Subset
|
| 11 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 12 |
+
|
| 13 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 14 |
+
|
| 15 |
+
from vertex_encoder import VoxelFeatureEncoder_edge, VoxelFeatureEncoder_vtx, VoxelFeatureEncoder_active, VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 16 |
+
from utils import load_pretrained_woself
|
| 17 |
+
|
| 18 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
from functools import partial
|
| 22 |
+
import itertools
|
| 23 |
+
from typing import List, Tuple, Set
|
| 24 |
+
from collections import OrderedDict
|
| 25 |
+
from scipy.spatial import cKDTree
|
| 26 |
+
from sklearn.neighbors import KDTree
|
| 27 |
+
|
| 28 |
+
import trimesh
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
import time
|
| 33 |
+
|
| 34 |
+
from sklearn.decomposition import PCA
|
| 35 |
+
import matplotlib.pyplot as plt
|
| 36 |
+
|
| 37 |
+
import networkx as nx
|
| 38 |
+
|
| 39 |
+
def predict_mesh_connectivity(
|
| 40 |
+
connection_head,
|
| 41 |
+
vtx_feats,
|
| 42 |
+
vtx_coords,
|
| 43 |
+
batch_size=10000,
|
| 44 |
+
threshold=0.5,
|
| 45 |
+
k_neighbors=64, # 限制每个点只检测最近的 K 个邻居,设为 -1 则全连接检测
|
| 46 |
+
device='cuda'
|
| 47 |
+
):
|
| 48 |
+
"""
|
| 49 |
+
Args:
|
| 50 |
+
connection_head: 训练好的 MLP 模型
|
| 51 |
+
vtx_feats: [N, C] 顶点特征
|
| 52 |
+
vtx_coords: [N, 3] 顶点坐标 (用于 KNN 筛选候选边)
|
| 53 |
+
batch_size: MLP 推理的 batch size
|
| 54 |
+
threshold: 判定连接的概率阈值
|
| 55 |
+
k_neighbors: K-NN 数量。如果是 None 或 -1,则检测所有 N*(N-1)/2 对。
|
| 56 |
+
"""
|
| 57 |
+
num_verts = vtx_feats.shape[0]
|
| 58 |
+
if num_verts < 3:
|
| 59 |
+
return [], [] # 无法构成三角形
|
| 60 |
+
|
| 61 |
+
connection_head.eval()
|
| 62 |
+
|
| 63 |
+
# --- 1. 生成候选边 (Candidate Edges) ---
|
| 64 |
+
if k_neighbors is not None and k_neighbors > 0 and k_neighbors < num_verts:
|
| 65 |
+
# 策略 A: 局部 KNN (推荐)
|
| 66 |
+
# 计算距离矩阵可能会 OOM,使用分块或 KDTree/Faiss,这里用 PyTorch 的 cdist 分块简化版
|
| 67 |
+
# 或者直接暴力 cdist 如果 N < 10000
|
| 68 |
+
|
| 69 |
+
# 为了简单且高效,这里演示简单的 cdist (注意显存)
|
| 70 |
+
# 如果 N 很大 (>5000),建议使用 faiss 或 scipy.spatial.cKDTree
|
| 71 |
+
dist_mat = torch.cdist(vtx_coords.float(), vtx_coords.float()) # [N, N]
|
| 72 |
+
|
| 73 |
+
# 取 topk (smallest distance),排除自己
|
| 74 |
+
# values: [N, K], indices: [N, K]
|
| 75 |
+
_, indices = torch.topk(dist_mat, k=k_neighbors + 1, dim=1, largest=False)
|
| 76 |
+
neighbor_indices = indices[:, 1:] # 去掉第一列(自己)
|
| 77 |
+
|
| 78 |
+
# 构建 source, target 索引
|
| 79 |
+
src = torch.arange(num_verts, device=device).unsqueeze(1).repeat(1, k_neighbors).flatten()
|
| 80 |
+
dst = neighbor_indices.flatten()
|
| 81 |
+
|
| 82 |
+
# 此时得到的边是双向的 (u->v 和 v->u 可能都存在),为了效率可以去重
|
| 83 |
+
# 但为了利用你的 symmetric MLP,保留双向或者只保留 u < v 均可
|
| 84 |
+
# 这里为了简单,我们生成 u < v 的 mask
|
| 85 |
+
mask = src < dst
|
| 86 |
+
u_indices = src[mask]
|
| 87 |
+
v_indices = dst[mask]
|
| 88 |
+
|
| 89 |
+
else:
|
| 90 |
+
# 策略 B: 全连接 (O(N^2)) - 仅当 N 较小时使用
|
| 91 |
+
u_indices, v_indices = torch.triu_indices(num_verts, num_verts, offset=1, device=device)
|
| 92 |
+
|
| 93 |
+
# --- 2. 批量推理 ---
|
| 94 |
+
all_probs = []
|
| 95 |
+
num_candidates = u_indices.shape[0]
|
| 96 |
+
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
for i in range(0, num_candidates, batch_size):
|
| 99 |
+
end = min(i + batch_size, num_candidates)
|
| 100 |
+
batch_u = u_indices[i:end]
|
| 101 |
+
batch_v = v_indices[i:end]
|
| 102 |
+
|
| 103 |
+
feat_u = vtx_feats[batch_u]
|
| 104 |
+
feat_v = vtx_feats[batch_v]
|
| 105 |
+
|
| 106 |
+
# Symmetric Forward (和你训练时保持一致)
|
| 107 |
+
# A -> B
|
| 108 |
+
input_uv = torch.cat([feat_u, feat_v], dim=-1)
|
| 109 |
+
logits_uv = connection_head(input_uv)
|
| 110 |
+
|
| 111 |
+
# B -> A
|
| 112 |
+
input_vu = torch.cat([feat_v, feat_u], dim=-1)
|
| 113 |
+
logits_vu = connection_head(input_vu)
|
| 114 |
+
|
| 115 |
+
# Sum logits
|
| 116 |
+
logits = (logits_uv + logits_vu) / 2.
|
| 117 |
+
probs = torch.sigmoid(logits)
|
| 118 |
+
all_probs.append(probs)
|
| 119 |
+
|
| 120 |
+
all_probs = torch.cat(all_probs).squeeze() # [M]
|
| 121 |
+
|
| 122 |
+
# --- 3. 筛选连接边 ---
|
| 123 |
+
connected_mask = all_probs > threshold
|
| 124 |
+
final_u = u_indices[connected_mask].cpu().numpy()
|
| 125 |
+
final_v = v_indices[connected_mask].cpu().numpy()
|
| 126 |
+
|
| 127 |
+
edges = np.stack([final_u, final_v], axis=1) # [E, 2]
|
| 128 |
+
|
| 129 |
+
return edges
|
| 130 |
+
|
| 131 |
+
def build_triangles_from_edges(edges, num_verts):
|
| 132 |
+
"""
|
| 133 |
+
从边列表构建三角形。
|
| 134 |
+
寻找图中所有的 3-Cliques (三元环)。
|
| 135 |
+
这在图论中是一个经典问题,可以使用 networkx 库。
|
| 136 |
+
"""
|
| 137 |
+
if len(edges) == 0:
|
| 138 |
+
return np.empty((0, 3), dtype=int)
|
| 139 |
+
|
| 140 |
+
G = nx.Graph()
|
| 141 |
+
G.add_nodes_from(range(num_verts))
|
| 142 |
+
G.add_edges_from(edges)
|
| 143 |
+
|
| 144 |
+
# 寻找所有的 3-cliques (三角形)
|
| 145 |
+
# enumerate_all_cliques 返回所有大小的 clique,我们需要过滤大小为 3 的
|
| 146 |
+
# 或者使用 nx.triangles ? 不,那个只返回数量
|
| 147 |
+
# 使用 nx.enumerate_all_cliques 效率可能较低,对于稀疏图还可以
|
| 148 |
+
|
| 149 |
+
# 更快的方法:迭代每条边 (u, v),查找 u 和 v 的公共邻居 w
|
| 150 |
+
triangles = []
|
| 151 |
+
adj = [set(G.neighbors(n)) for n in range(num_verts)]
|
| 152 |
+
|
| 153 |
+
# 为了避免重复 (u, v, w), (v, w, u)... 我们可以强制 u < v < w
|
| 154 |
+
# 既然 edges 已经是 u < v (如果我们之前做了 triu),则只需要找 w > v 且 w in adj[u]
|
| 155 |
+
|
| 156 |
+
# 优化算法:
|
| 157 |
+
for u, v in edges:
|
| 158 |
+
if u > v: u, v = v, u # 确保有序
|
| 159 |
+
|
| 160 |
+
# 找公共邻居
|
| 161 |
+
common = adj[u].intersection(adj[v])
|
| 162 |
+
for w in common:
|
| 163 |
+
if w > v: # 强制顺序 u < v < w 防止重复
|
| 164 |
+
triangles.append([u, v, w])
|
| 165 |
+
|
| 166 |
+
return np.array(triangles)
|
| 167 |
+
|
| 168 |
+
def downsample_voxels(
|
| 169 |
+
voxels: torch.Tensor,
|
| 170 |
+
input_resolution: int,
|
| 171 |
+
output_resolution: int
|
| 172 |
+
) -> torch.Tensor:
|
| 173 |
+
if input_resolution % output_resolution != 0:
|
| 174 |
+
raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
|
| 175 |
+
f"by output_resolution ({output_resolution}).")
|
| 176 |
+
|
| 177 |
+
factor = input_resolution // output_resolution
|
| 178 |
+
|
| 179 |
+
downsampled_voxels = voxels.clone().to(torch.long)
|
| 180 |
+
|
| 181 |
+
downsampled_voxels[:, 1:] //= factor
|
| 182 |
+
|
| 183 |
+
unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
|
| 184 |
+
return unique_downsampled_voxels
|
| 185 |
+
|
| 186 |
+
def visualize_colored_points_ply(coords, vectors, filename):
|
| 187 |
+
"""
|
| 188 |
+
可视化点云,并用向量方向的颜色来表示,保存为 PLY 文件。
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
coords (torch.Tensor or np.ndarray): 3D坐标,形状为 (N, 3)。
|
| 192 |
+
vectors (torch.Tensor or np.ndarray): 方向向量,形状为 (N, 3)。
|
| 193 |
+
filename (str): 保存输出文件的名称,必须是 .ply 格式。
|
| 194 |
+
"""
|
| 195 |
+
# 确保输入是 numpy 数组
|
| 196 |
+
if isinstance(coords, torch.Tensor):
|
| 197 |
+
coords = coords.detach().cpu().numpy()
|
| 198 |
+
if isinstance(vectors, torch.Tensor):
|
| 199 |
+
vectors = vectors.detach().cpu().to(torch.float32).numpy()
|
| 200 |
+
|
| 201 |
+
# 检查输入数据是否为空,防止崩溃
|
| 202 |
+
if coords.size == 0 or vectors.size == 0:
|
| 203 |
+
print(f"警告:输入数据为空,未生成 {filename} 文件。")
|
| 204 |
+
return
|
| 205 |
+
|
| 206 |
+
# 将向量分量从 [-1, 1] 映射到 [0, 255]
|
| 207 |
+
# np.clip 用于将数值限制在 -1 和 1 之间,防止颜色溢出
|
| 208 |
+
# (vectors + 1) 将范围从 [-1, 1] 移动到 [0, 2]
|
| 209 |
+
# * 127.5 将范围从 [0, 2] 缩放到 [0, 255]
|
| 210 |
+
colors = np.clip((vectors + 1) * 127.5, 0, 255).astype(np.uint8)
|
| 211 |
+
|
| 212 |
+
# 创建一个点云对象,并传入颜色信息
|
| 213 |
+
# trimesh.PointCloud 能够自动处理带颜色的点
|
| 214 |
+
points = trimesh.points.PointCloud(coords, colors=colors)
|
| 215 |
+
# 导出为 PLY 文件
|
| 216 |
+
points.export(filename, file_type='ply')
|
| 217 |
+
print(f"可视化文件已成功保存为: {filename}")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def compute_vertex_matching(pred_coords, gt_coords, threshold=1.0):
|
| 221 |
+
# 转换为整数坐标并去重
|
| 222 |
+
print('len(pred_coords)', len(pred_coords))
|
| 223 |
+
|
| 224 |
+
pred_array = np.unique(pred_coords.detach().to(torch.float32).cpu().numpy(), axis=0)
|
| 225 |
+
gt_array = np.unique(gt_coords.detach().cpu().to(torch.float32).numpy(), axis=0)
|
| 226 |
+
print('len(pred_array)', len(pred_array))
|
| 227 |
+
pred_total = len(pred_array)
|
| 228 |
+
gt_total = len(gt_array)
|
| 229 |
+
|
| 230 |
+
# 如果没有点,直接返回
|
| 231 |
+
if pred_total == 0 or gt_total == 0:
|
| 232 |
+
return 0, 0.0, pred_total, gt_total
|
| 233 |
+
|
| 234 |
+
# 建立 KDTree(以 gt 为基准)
|
| 235 |
+
tree = KDTree(gt_array)
|
| 236 |
+
|
| 237 |
+
# 查找预测点到最近的 gt 点
|
| 238 |
+
dist, indices = tree.query(pred_array, k=1)
|
| 239 |
+
dist = dist.squeeze()
|
| 240 |
+
indices = indices.squeeze()
|
| 241 |
+
|
| 242 |
+
# 贪心去重:确保 1 对 1
|
| 243 |
+
matches = 0
|
| 244 |
+
used_gt = set()
|
| 245 |
+
for d, idx in zip(dist, indices):
|
| 246 |
+
if d <= threshold and idx not in used_gt:
|
| 247 |
+
matches += 1
|
| 248 |
+
used_gt.add(idx)
|
| 249 |
+
|
| 250 |
+
match_rate = matches / max(gt_total, 1)
|
| 251 |
+
|
| 252 |
+
return matches, match_rate, pred_total, gt_total
|
| 253 |
+
|
| 254 |
+
def flatten_coords_4d(coords_4d: torch.Tensor):
|
| 255 |
+
coords_4d_long = coords_4d.long()
|
| 256 |
+
|
| 257 |
+
base_x = 512
|
| 258 |
+
base_y = 512 * 512
|
| 259 |
+
base_z = 512 * 512 * 512
|
| 260 |
+
|
| 261 |
+
flat_coords = coords_4d_long[:, 0] * base_z + \
|
| 262 |
+
coords_4d_long[:, 1] * base_y + \
|
| 263 |
+
coords_4d_long[:, 2] * base_x + \
|
| 264 |
+
coords_4d_long[:, 3]
|
| 265 |
+
return flat_coords
|
| 266 |
+
|
| 267 |
+
class Tester:
|
| 268 |
+
def __init__(self, ckpt_path, config_path=None, dataset_path=None):
|
| 269 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 270 |
+
self.ckpt_path = ckpt_path
|
| 271 |
+
|
| 272 |
+
self.config = self._load_config(config_path)
|
| 273 |
+
self.dataset_path = dataset_path # or self.config['dataset']['path']
|
| 274 |
+
checkpoint = torch.load(self.ckpt_path, map_location='cpu')
|
| 275 |
+
self.epoch = checkpoint.get('epoch', 0)
|
| 276 |
+
|
| 277 |
+
self._init_models()
|
| 278 |
+
self._init_dataset()
|
| 279 |
+
|
| 280 |
+
self.result_dir = os.path.join(os.path.dirname(ckpt_path), "evaluation_results")
|
| 281 |
+
os.makedirs(self.result_dir, exist_ok=True)
|
| 282 |
+
|
| 283 |
+
dataset_name_clean = os.path.basename(self.dataset_path).replace('.npz', '').replace('.npy', '')
|
| 284 |
+
self.output_voxel_dir = os.path.join(os.path.dirname(ckpt_path),
|
| 285 |
+
f"epoch_{self.epoch}_{dataset_name_clean}_voxels_0_gs")
|
| 286 |
+
os.makedirs(self.output_voxel_dir, exist_ok=True)
|
| 287 |
+
|
| 288 |
+
self.output_obj_dir = os.path.join(os.path.dirname(ckpt_path),
|
| 289 |
+
f"epoch_{self.epoch}_{dataset_name_clean}_obj_0_gs")
|
| 290 |
+
os.makedirs(self.output_obj_dir, exist_ok=True)
|
| 291 |
+
|
| 292 |
+
def _save_logit_visualization(self, dense_vol, name, sample_name, ply_threshold=0.01):
|
| 293 |
+
"""
|
| 294 |
+
保存 Logit 的 3D .npy 文件、2D 最大投影热力图,以及带颜色和透明度的 3D .ply 点云
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
dense_vol: (H, W, D) numpy array, values in [0, 1]
|
| 298 |
+
name: str (e.g., "edge" or "vertex")
|
| 299 |
+
sample_name: str
|
| 300 |
+
ply_threshold: float, 只有概率大于此值的点才会被保存
|
| 301 |
+
"""
|
| 302 |
+
# 1. 保存原始 Dense 数据 (可选)
|
| 303 |
+
npy_path = os.path.join(self.output_voxel_dir, f"{sample_name}_{name}_logits.npy")
|
| 304 |
+
# np.save(npy_path, dense_vol)
|
| 305 |
+
|
| 306 |
+
# 2. 生成 2D 投影热力图 (保持不变)
|
| 307 |
+
proj_x = np.max(dense_vol, axis=0)
|
| 308 |
+
proj_y = np.max(dense_vol, axis=1)
|
| 309 |
+
proj_z = np.max(dense_vol, axis=2)
|
| 310 |
+
|
| 311 |
+
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 312 |
+
im0 = axes[0].imshow(proj_x, cmap='turbo', vmin=0, vmax=1, origin='lower')
|
| 313 |
+
axes[0].set_title(f"{name} Max-Proj (YZ)")
|
| 314 |
+
im1 = axes[1].imshow(proj_y, cmap='turbo', vmin=0, vmax=1, origin='lower')
|
| 315 |
+
axes[1].set_title(f"{name} Max-Proj (XZ)")
|
| 316 |
+
im2 = axes[2].imshow(proj_z, cmap='turbo', vmin=0, vmax=1, origin='lower')
|
| 317 |
+
axes[2].set_title(f"{name} Max-Proj (XY)")
|
| 318 |
+
|
| 319 |
+
fig.colorbar(im2, ax=axes, orientation='vertical', fraction=0.02, pad=0.04)
|
| 320 |
+
plt.suptitle(f"{sample_name} - {name} Occupancy Probability")
|
| 321 |
+
|
| 322 |
+
png_path = os.path.join(self.output_voxel_dir, f"{sample_name}_{name}_heatmap.png")
|
| 323 |
+
plt.savefig(png_path, dpi=150)
|
| 324 |
+
plt.close(fig)
|
| 325 |
+
|
| 326 |
+
# ------------------------------------------------------------------
|
| 327 |
+
# 3. 保存为带颜色和透明度(RGBA)的 PLY 点云
|
| 328 |
+
# ------------------------------------------------------------------
|
| 329 |
+
# 筛选出概率大于阈值的点坐标
|
| 330 |
+
indices = np.argwhere(dense_vol > ply_threshold)
|
| 331 |
+
|
| 332 |
+
if len(indices) > 0:
|
| 333 |
+
# 获取这些点的概率值 [0, 1]
|
| 334 |
+
values = dense_vol[indices[:, 0], indices[:, 1], indices[:, 2]]
|
| 335 |
+
|
| 336 |
+
# 使用 matplotlib 的 colormap 进行颜色映射
|
| 337 |
+
import matplotlib.cm as cm
|
| 338 |
+
cmap = cm.get_cmap('turbo')
|
| 339 |
+
|
| 340 |
+
# map values [0, 1] to RGBA [0, 1] (N, 4)
|
| 341 |
+
colors_float = cmap(values)
|
| 342 |
+
|
| 343 |
+
# -------------------------------------------------------
|
| 344 |
+
# 【核心修改】:修改 Alpha 通道 (透明度)
|
| 345 |
+
# -------------------------------------------------------
|
| 346 |
+
# 让透明度直接等于概率值。
|
| 347 |
+
# 概率 1.0 -> Alpha 1.0 (完全不透明/颜色深)
|
| 348 |
+
# 概率 0.1 -> Alpha 0.1 (非常透明/颜色浅)
|
| 349 |
+
colors_float[:, 3] = values
|
| 350 |
+
|
| 351 |
+
# 转换为 uint8 [0, 255],保留 4 个通道 (R, G, B, A)
|
| 352 |
+
colors_uint8 = (colors_float * 255).astype(np.uint8)
|
| 353 |
+
|
| 354 |
+
# 坐标转换
|
| 355 |
+
vertices = indices
|
| 356 |
+
|
| 357 |
+
ply_filename = f"{sample_name}_{name}_logits_colored.ply"
|
| 358 |
+
ply_save_path = os.path.join(self.output_voxel_dir, ply_filename)
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
# 使用 Trimesh 保存 (Trimesh 支持 (N, 4) 的 colors)
|
| 362 |
+
pcd = trimesh.points.PointCloud(vertices=vertices, colors=colors_uint8)
|
| 363 |
+
pcd.export(ply_save_path)
|
| 364 |
+
print(f"Saved colored RGBA logit PLY to {ply_save_path}")
|
| 365 |
+
except Exception as e:
|
| 366 |
+
print(f"Failed to save PLY with trimesh: {e}")
|
| 367 |
+
# Fallback: 手动写入 PLY (需要添加 alpha 属性)
|
| 368 |
+
with open(ply_save_path, 'w') as f:
|
| 369 |
+
f.write("ply\n")
|
| 370 |
+
f.write("format ascii 1.0\n")
|
| 371 |
+
f.write(f"element vertex {len(vertices)}\n")
|
| 372 |
+
f.write("property float x\n")
|
| 373 |
+
f.write("property float y\n")
|
| 374 |
+
f.write("property float z\n")
|
| 375 |
+
f.write("property uchar red\n")
|
| 376 |
+
f.write("property uchar green\n")
|
| 377 |
+
f.write("property uchar blue\n")
|
| 378 |
+
f.write("property uchar alpha\n") # 新增 Alpha 属性
|
| 379 |
+
f.write("end_header\n")
|
| 380 |
+
for i in range(len(vertices)):
|
| 381 |
+
v = vertices[i]
|
| 382 |
+
c = colors_uint8[i] # c is now (R, G, B, A)
|
| 383 |
+
f.write(f"{v[0]} {v[1]} {v[2]} {c[0]} {c[1]} {c[2]} {c[3]}\n")
|
| 384 |
+
|
| 385 |
+
def _point_line_segment_distance(self, px, py, pz, x1, y1, z1, x2, y2, z2):
|
| 386 |
+
"""
|
| 387 |
+
计算点 (px,py,pz) 到线段 (x1,y1,z1)-(x2,y2,z2) 的最短距离的平方。
|
| 388 |
+
全部输入为 Tensor,支持广播。
|
| 389 |
+
"""
|
| 390 |
+
# 线段向量 AB
|
| 391 |
+
ABx = x2 - x1
|
| 392 |
+
ABy = y2 - y1
|
| 393 |
+
ABz = z2 - z1
|
| 394 |
+
|
| 395 |
+
# 向量 AP
|
| 396 |
+
APx = px - x1
|
| 397 |
+
APy = py - y1
|
| 398 |
+
APz = pz - z1
|
| 399 |
+
|
| 400 |
+
# AB 的长度平方
|
| 401 |
+
AB_sq = ABx**2 + ABy**2 + ABz**2
|
| 402 |
+
|
| 403 |
+
# 避免除以0 (如果两端点重合)
|
| 404 |
+
AB_sq = torch.clamp(AB_sq, min=1e-6)
|
| 405 |
+
|
| 406 |
+
# 投影系数 t = (AP · AB) / |AB|^2
|
| 407 |
+
t = (APx * ABx + APy * ABy + APz * ABz) / AB_sq
|
| 408 |
+
|
| 409 |
+
# 限制 t 在 [0, 1] 之间(线段约束)
|
| 410 |
+
t = torch.clamp(t, 0.0, 1.0)
|
| 411 |
+
|
| 412 |
+
# 最近点 (Projection)
|
| 413 |
+
closestX = x1 + t * ABx
|
| 414 |
+
closestY = y1 + t * ABy
|
| 415 |
+
closestZ = z1 + t * ABz
|
| 416 |
+
|
| 417 |
+
# 距离平方
|
| 418 |
+
dx = px - closestX
|
| 419 |
+
dy = py - closestY
|
| 420 |
+
dz = pz - closestZ
|
| 421 |
+
|
| 422 |
+
return dx**2 + dy**2 + dz**2
|
| 423 |
+
|
| 424 |
+
def _extract_mesh_projection_based(
|
| 425 |
+
self,
|
| 426 |
+
vtx_result: dict,
|
| 427 |
+
edge_result: dict,
|
| 428 |
+
resolution: int = 1024,
|
| 429 |
+
vtx_prob_threshold: float = 0.5,
|
| 430 |
+
|
| 431 |
+
# --- 你的新逻辑参数 ---
|
| 432 |
+
search_radius: float = 128.0, # 1. 候选边最大长度
|
| 433 |
+
project_dist_thresh: float = 1.5, # 2. 投影距离阈值 (管子半径,单位:voxel)
|
| 434 |
+
dir_align_threshold: float = 0.6, # 3. 方向相似度阈值 (cos theta)
|
| 435 |
+
connect_ratio_threshold: float = 0.4, # 4. 最终连接阈值 (匹配点数 / 理论长度)
|
| 436 |
+
|
| 437 |
+
edge_prob_threshold: float = 0.1, # 仅仅用于提取"存在的"体素
|
| 438 |
+
):
|
| 439 |
+
t_start = time.perf_counter()
|
| 440 |
+
|
| 441 |
+
# ---------------------------------------------------------------------
|
| 442 |
+
# 1. 准备全局数据:提取所有"活着"的 Edge Voxels (作为点云处理)
|
| 443 |
+
# ---------------------------------------------------------------------
|
| 444 |
+
e_probs = torch.sigmoid(edge_result['occ_probs'][:, 0])
|
| 445 |
+
e_coords = edge_result['coords_4d'][:, 1:].float() # (N, 3)
|
| 446 |
+
|
| 447 |
+
# 获取方向向量
|
| 448 |
+
if 'predicted_direction_feats' in edge_result:
|
| 449 |
+
e_dirs = edge_result['predicted_direction_feats'] # (N, 3)
|
| 450 |
+
# 归一化方向
|
| 451 |
+
e_dirs = F.normalize(e_dirs, p=2, dim=1)
|
| 452 |
+
else:
|
| 453 |
+
print("Warning: No direction features, using dummy.")
|
| 454 |
+
e_dirs = torch.zeros_like(e_coords)
|
| 455 |
+
|
| 456 |
+
# 筛选有效的 Edge Voxels (Global Point Cloud)
|
| 457 |
+
valid_mask = e_probs > edge_prob_threshold
|
| 458 |
+
|
| 459 |
+
cloud_coords = e_coords[valid_mask] # (M, 3)
|
| 460 |
+
cloud_dirs = e_dirs[valid_mask] # (M, 3)
|
| 461 |
+
|
| 462 |
+
num_cloud = cloud_coords.shape[0]
|
| 463 |
+
print(f"[Projection] Global active edge voxels: {num_cloud}")
|
| 464 |
+
|
| 465 |
+
if num_cloud == 0:
|
| 466 |
+
return [], []
|
| 467 |
+
|
| 468 |
+
# ---------------------------------------------------------------------
|
| 469 |
+
# 2. 准备顶点和候选边
|
| 470 |
+
# ---------------------------------------------------------------------
|
| 471 |
+
v_probs = torch.sigmoid(vtx_result['occ_probs'][:, 0])
|
| 472 |
+
v_coords = vtx_result['coords_4d'][:, 1:].float()
|
| 473 |
+
v_mask = v_probs > vtx_prob_threshold
|
| 474 |
+
valid_v_coords = v_coords[v_mask] # (V, 3)
|
| 475 |
+
|
| 476 |
+
if valid_v_coords.shape[0] < 2:
|
| 477 |
+
return valid_v_coords.cpu().numpy() / resolution, []
|
| 478 |
+
|
| 479 |
+
# 生成所有可能的候选边 (基于距离粗筛)
|
| 480 |
+
dists = torch.cdist(valid_v_coords, valid_v_coords)
|
| 481 |
+
triu_mask = torch.triu(torch.ones_like(dists), diagonal=1).bool()
|
| 482 |
+
cand_mask = (dists < search_radius) & triu_mask
|
| 483 |
+
cand_indices = torch.nonzero(cand_mask, as_tuple=False) # (E_cand, 2)
|
| 484 |
+
|
| 485 |
+
p1s = valid_v_coords[cand_indices[:, 0]] # (E, 3)
|
| 486 |
+
p2s = valid_v_coords[cand_indices[:, 1]] # (E, 3)
|
| 487 |
+
|
| 488 |
+
num_candidates = p1s.shape[0]
|
| 489 |
+
print(f"[Projection] Checking {num_candidates} candidate pairs...")
|
| 490 |
+
|
| 491 |
+
# ---------------------------------------------------------------------
|
| 492 |
+
# 3. 循环处理候选边 (使用 Bounding Box 快速裁剪)
|
| 493 |
+
# ---------------------------------------------------------------------
|
| 494 |
+
final_edges = []
|
| 495 |
+
|
| 496 |
+
# 预计算所有候选边的方向和长度
|
| 497 |
+
edge_vecs = p2s - p1s
|
| 498 |
+
edge_lengths = torch.norm(edge_vecs, dim=1)
|
| 499 |
+
edge_dirs = F.normalize(edge_vecs, p=2, dim=1)
|
| 500 |
+
|
| 501 |
+
# 为了避免显存爆炸,也不要在 Python 里做太慢的循环
|
| 502 |
+
# 我们对点云进行操作太慢,对每一条边去遍历整个点云也太慢。
|
| 503 |
+
# 策略:
|
| 504 |
+
# 我们循环“边”,但在循环内部利用 mask 快速筛选点云。
|
| 505 |
+
# 由于 Python 循环 10000 次会很慢,我们只处理那些有希望的边。
|
| 506 |
+
# 这里为了演示逻辑的准确性,我们使用简单的循环,但在 GPU 上做计算。
|
| 507 |
+
|
| 508 |
+
# 将全局点云拆分到各个坐标轴,便于快速 BBox 筛选
|
| 509 |
+
cx, cy, cz = cloud_coords[:, 0], cloud_coords[:, 1], cloud_coords[:, 2]
|
| 510 |
+
|
| 511 |
+
# 优化:如果候选边太多,可以分块。这里假设边在 5万以内,点在 10万以内,可以处理。
|
| 512 |
+
|
| 513 |
+
# 这一步是瓶颈,我们尝试用 Python 循环,但只对局部点计算
|
| 514 |
+
# 为了加速,我们可以将点云放入 HashGrid 或者只是简单的 BBox Check。
|
| 515 |
+
|
| 516 |
+
# 让我们用简单的逻辑:对于每条边,找出 BBox 内的点,算距离。
|
| 517 |
+
# 这里的 batch_size 是指一次并行处理多少条边
|
| 518 |
+
|
| 519 |
+
batch_size = 128 # 每次处理 128 条边
|
| 520 |
+
|
| 521 |
+
for i in range(0, num_candidates, batch_size):
|
| 522 |
+
end = min(i + batch_size, num_candidates)
|
| 523 |
+
|
| 524 |
+
# 当前批次的边数据
|
| 525 |
+
b_p1 = p1s[i:end] # (B, 3)
|
| 526 |
+
b_p2 = p2s[i:end] # (B, 3)
|
| 527 |
+
b_dirs = edge_dirs[i:end] # (B, 3)
|
| 528 |
+
b_lens = edge_lengths[i:end] # (B,)
|
| 529 |
+
|
| 530 |
+
# --- 步骤 A: 投影 & 距离检查 ---
|
| 531 |
+
# 这是一个 (B, M) 的大矩阵计算,容易 OOM。
|
| 532 |
+
# M (点云数) 可能很大。
|
| 533 |
+
# 解决方法:我们反过来思考。
|
| 534 |
+
# 不计算矩阵,我们只对单个边进行循环?太慢。
|
| 535 |
+
|
| 536 |
+
# 实用优化:只对 bounding box 内的点进行距离计算。
|
| 537 |
+
# 由于 GPU 难以动态索引不规则数据,我们还是逐个边循环比较稳妥,
|
| 538 |
+
# 但为了 Python 速度,必须尽可能向量化。
|
| 539 |
+
|
| 540 |
+
# 这里我采用一种折中方案:逐个处理边,但是利用 torch.where 快速定位。
|
| 541 |
+
# 实际上,对于 Python 里的 for loop,几千次是可以接受的。
|
| 542 |
+
|
| 543 |
+
current_edges_indices = cand_indices[i:end]
|
| 544 |
+
|
| 545 |
+
for j in range(len(b_p1)):
|
| 546 |
+
# 单条边处理
|
| 547 |
+
p1 = b_p1[j]
|
| 548 |
+
p2 = b_p2[j]
|
| 549 |
+
e_dir = b_dirs[j]
|
| 550 |
+
e_len = b_lens[j].item()
|
| 551 |
+
|
| 552 |
+
# 1. Bounding Box Filter (快速大幅裁剪)
|
| 553 |
+
# 找出这条边 BBox 范围内的所有点 (+ padding)
|
| 554 |
+
padding = project_dist_thresh + 2.0
|
| 555 |
+
min_xyz = torch.min(p1, p2) - padding
|
| 556 |
+
max_xyz = torch.max(p1, p2) + padding
|
| 557 |
+
|
| 558 |
+
# 利用 boolean mask 筛选
|
| 559 |
+
mask_x = (cx >= min_xyz[0]) & (cx <= max_xyz[0])
|
| 560 |
+
mask_y = (cy >= min_xyz[1]) & (cy <= max_xyz[1])
|
| 561 |
+
mask_z = (cz >= min_xyz[2]) & (cz <= max_xyz[2])
|
| 562 |
+
bbox_mask = mask_x & mask_y & mask_z
|
| 563 |
+
|
| 564 |
+
subset_coords = cloud_coords[bbox_mask]
|
| 565 |
+
subset_dirs = cloud_dirs[bbox_mask]
|
| 566 |
+
|
| 567 |
+
if subset_coords.shape[0] == 0:
|
| 568 |
+
continue
|
| 569 |
+
|
| 570 |
+
# 2. 精确距离计算 (Projection Distance)
|
| 571 |
+
# 计算 subset 中每个点到线段 p1-p2 的距离平方
|
| 572 |
+
dist_sq = self._point_line_segment_distance(
|
| 573 |
+
subset_coords[:, 0], subset_coords[:, 1], subset_coords[:, 2],
|
| 574 |
+
p1[0], p1[1], p1[2],
|
| 575 |
+
p2[0], p2[1], p2[2]
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
# 3. 距离阈值过滤 (Keep voxels inside the tube)
|
| 579 |
+
dist_mask = dist_sq < (project_dist_thresh ** 2)
|
| 580 |
+
|
| 581 |
+
# 获取在管子内部的体素
|
| 582 |
+
tube_dirs = subset_dirs[dist_mask]
|
| 583 |
+
|
| 584 |
+
if tube_dirs.shape[0] == 0:
|
| 585 |
+
continue
|
| 586 |
+
|
| 587 |
+
# 4. 方向一致性检查 (Direction Check)
|
| 588 |
+
# 计算点积 (cos theta)
|
| 589 |
+
# e_dir 是 (3,), tube_dirs 是 (K, 3)
|
| 590 |
+
dot_prod = torch.matmul(tube_dirs, e_dir)
|
| 591 |
+
|
| 592 |
+
# 这里使用 abs,因为边可能是无向的,或者网络预测可能反向
|
| 593 |
+
# 如果你的网络严格预测流向,可以去掉 abs
|
| 594 |
+
dir_sim = torch.abs(dot_prod)
|
| 595 |
+
|
| 596 |
+
# 统计方向符合要求的体素数量
|
| 597 |
+
valid_voxel_count = (dir_sim > dir_align_threshold).sum().item()
|
| 598 |
+
|
| 599 |
+
# 5. 比值判决 (Ratio Check)
|
| 600 |
+
# 量化出的 Voxel 数目 ≈ 边的长度 (e_len)
|
| 601 |
+
# 如果 e_len 很小(比如<1),我们设为1防止除以0
|
| 602 |
+
theoretical_count = max(e_len, 1.0)
|
| 603 |
+
|
| 604 |
+
ratio = valid_voxel_count / theoretical_count
|
| 605 |
+
|
| 606 |
+
if ratio > connect_ratio_threshold:
|
| 607 |
+
# 找到了!
|
| 608 |
+
global_idx = i + j
|
| 609 |
+
edge_tuple = cand_indices[global_idx].cpu().numpy().tolist()
|
| 610 |
+
final_edges.append(edge_tuple)
|
| 611 |
+
|
| 612 |
+
t_end = time.perf_counter()
|
| 613 |
+
print(f"[Projection] Logic finished. Accepted {len(final_edges)} edges. Time={t_end - t_start:.4f}s")
|
| 614 |
+
|
| 615 |
+
out_vertices = valid_v_coords.cpu().numpy() / resolution
|
| 616 |
+
return out_vertices, final_edges
|
| 617 |
+
|
| 618 |
+
def _save_voxel_ply(self, coords: torch.Tensor, labels: torch.Tensor, filename: str):
|
| 619 |
+
if coords.numel() == 0:
|
| 620 |
+
return
|
| 621 |
+
|
| 622 |
+
coords_np = coords.cpu().to(torch.float32).numpy()
|
| 623 |
+
labels_np = labels.cpu().to(torch.float32).numpy()
|
| 624 |
+
|
| 625 |
+
colors = np.zeros((coords_np.shape[0], 3), dtype=np.uint8)
|
| 626 |
+
colors[labels_np == 0] = [255, 0, 0]
|
| 627 |
+
colors[labels_np == 1] = [0, 0, 255]
|
| 628 |
+
|
| 629 |
+
try:
|
| 630 |
+
import trimesh
|
| 631 |
+
point_cloud = trimesh.PointCloud(vertices=coords_np, colors=colors)
|
| 632 |
+
ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
|
| 633 |
+
point_cloud.export(ply_path)
|
| 634 |
+
except ImportError:
|
| 635 |
+
ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
|
| 636 |
+
with open(ply_path, 'w') as f:
|
| 637 |
+
f.write("ply\n")
|
| 638 |
+
f.write("format ascii 1.0\n")
|
| 639 |
+
f.write(f"element vertex {coords_np.shape[0]}\n")
|
| 640 |
+
f.write("property float x\n")
|
| 641 |
+
f.write("property float y\n")
|
| 642 |
+
f.write("property float z\n")
|
| 643 |
+
f.write("property uchar red\n")
|
| 644 |
+
f.write("property uchar green\n")
|
| 645 |
+
f.write("property uchar blue\n")
|
| 646 |
+
f.write("end_header\n")
|
| 647 |
+
for i in range(coords_np.shape[0]):
|
| 648 |
+
f.write(f"{coords_np[i,0]} {coords_np[i,1]} {coords_np[i,2]} {colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
|
| 649 |
+
|
| 650 |
+
def _load_config(self, config_path=None):
|
| 651 |
+
if config_path and os.path.exists(config_path):
|
| 652 |
+
with open(config_path) as f:
|
| 653 |
+
return yaml.safe_load(f)
|
| 654 |
+
|
| 655 |
+
ckpt_dir = os.path.dirname(self.ckpt_path)
|
| 656 |
+
possible_configs = [
|
| 657 |
+
os.path.join(ckpt_dir, "config.yaml"),
|
| 658 |
+
os.path.join(os.path.dirname(ckpt_dir), "config.yaml")
|
| 659 |
+
]
|
| 660 |
+
|
| 661 |
+
for config_file in possible_configs:
|
| 662 |
+
if os.path.exists(config_file):
|
| 663 |
+
with open(config_file) as f:
|
| 664 |
+
print(f"Loaded config from: {config_file}")
|
| 665 |
+
return yaml.safe_load(f)
|
| 666 |
+
|
| 667 |
+
checkpoint = torch.load(self.ckpt_path, map_location='cpu')
|
| 668 |
+
if 'config' in checkpoint:
|
| 669 |
+
print("Loaded config from checkpoint")
|
| 670 |
+
return checkpoint['config']
|
| 671 |
+
|
| 672 |
+
raise FileNotFoundError("Could not find config_edge.yaml in checkpoint directory or parent, and config not saved in checkpoint.")
|
| 673 |
+
|
| 674 |
+
def _init_models(self):
|
| 675 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 676 |
+
in_channels=15,
|
| 677 |
+
hidden_dim=256,
|
| 678 |
+
out_channels=1024,
|
| 679 |
+
scatter_type='mean',
|
| 680 |
+
n_blocks=5,
|
| 681 |
+
resolution=128,
|
| 682 |
+
|
| 683 |
+
).to(self.device)
|
| 684 |
+
|
| 685 |
+
self.connection_head = ConnectionHead(
|
| 686 |
+
channels=128 * 2,
|
| 687 |
+
out_channels=1,
|
| 688 |
+
mlp_ratio=4,
|
| 689 |
+
).to(self.device)
|
| 690 |
+
|
| 691 |
+
self.vae = VoxelVAE( # abalation: VoxelVAE_1volume_dilation
|
| 692 |
+
in_channels=self.config['model']['in_channels'],
|
| 693 |
+
latent_dim=self.config['model']['latent_dim'],
|
| 694 |
+
encoder_blocks=self.config['model']['encoder_blocks'],
|
| 695 |
+
# decoder_blocks=self.config['model']['decoder_blocks'],
|
| 696 |
+
decoder_blocks_vtx=self.config['model']['decoder_blocks_vtx'],
|
| 697 |
+
decoder_blocks_edge=self.config['model']['decoder_blocks_edge'],
|
| 698 |
+
num_heads=8,
|
| 699 |
+
num_head_channels=64,
|
| 700 |
+
mlp_ratio=4.0,
|
| 701 |
+
attn_mode="swin",
|
| 702 |
+
window_size=8,
|
| 703 |
+
pe_mode="ape",
|
| 704 |
+
use_fp16=False,
|
| 705 |
+
use_checkpoint=False,
|
| 706 |
+
qk_rms_norm=False,
|
| 707 |
+
using_subdivide=True,
|
| 708 |
+
using_attn=self.config['model']['using_attn'],
|
| 709 |
+
attn_first=self.config['model'].get('attn_first', True),
|
| 710 |
+
pred_direction=self.config['model'].get('pred_direction', False),
|
| 711 |
+
).to(self.device)
|
| 712 |
+
|
| 713 |
+
load_pretrained_woself(
|
| 714 |
+
checkpoint_path=self.ckpt_path,
|
| 715 |
+
voxel_encoder=self.voxel_encoder,
|
| 716 |
+
connection_head=self.connection_head,
|
| 717 |
+
vae=self.vae,
|
| 718 |
+
)
|
| 719 |
+
# --- 【新增】在这里添加权重检查逻辑 ---
|
| 720 |
+
print(f"--- 正在检查权重文件中的 NaN/Inf 值... ---")
|
| 721 |
+
has_nan_inf = False
|
| 722 |
+
if self._check_weights_for_nan_inf(self.vae, "VoxelVAE"):
|
| 723 |
+
has_nan_inf = True
|
| 724 |
+
|
| 725 |
+
if self._check_weights_for_nan_inf(self.voxel_encoder, "Vertex Encoder"):
|
| 726 |
+
has_nan_inf = True
|
| 727 |
+
|
| 728 |
+
if self._check_weights_for_nan_inf(self.connection_head, "Connection Head"):
|
| 729 |
+
has_nan_inf = True
|
| 730 |
+
|
| 731 |
+
if not has_nan_inf:
|
| 732 |
+
print("--- 权重检查通过。未发现 NaN/Inf 值。 ---")
|
| 733 |
+
else:
|
| 734 |
+
# 如果发现坏值,直接抛出异常,因为评估无法继续
|
| 735 |
+
raise ValueError(f"在检查点 '{self.ckpt_path}' 中发现了 NaN 或 Inf 值。请检查导致训练不稳定的权重文件。")
|
| 736 |
+
# --- 检查逻辑结束 ---
|
| 737 |
+
|
| 738 |
+
self.vae.eval()
|
| 739 |
+
self.voxel_encoder.eval()
|
| 740 |
+
self.connection_head.eval()
|
| 741 |
+
|
| 742 |
+
def _init_dataset(self):
|
| 743 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 744 |
+
root_dir=self.dataset_path,
|
| 745 |
+
base_resolution=self.config['dataset']['base_resolution'],
|
| 746 |
+
min_resolution=self.config['dataset']['min_resolution'],
|
| 747 |
+
cache_dir='/gemini/user/private/zhaotianhao/dataset_cache/test_15c_dora',
|
| 748 |
+
# cache_dir=self.config['dataset']['cache_dir'],
|
| 749 |
+
renders_dir=self.config['dataset']['renders_dir'],
|
| 750 |
+
|
| 751 |
+
# filter_active_voxels=self.config['dataset']['filter_active_voxels'],
|
| 752 |
+
filter_active_voxels=False,
|
| 753 |
+
cache_filter_path=self.config['dataset']['cache_filter_path'],
|
| 754 |
+
|
| 755 |
+
sample_type=self.config['dataset']['sample_type'],
|
| 756 |
+
active_voxel_res=128,
|
| 757 |
+
pc_sample_number=819200,
|
| 758 |
+
|
| 759 |
+
)
|
| 760 |
+
|
| 761 |
+
self.dataloader = DataLoader(
|
| 762 |
+
self.dataset,
|
| 763 |
+
batch_size=1,
|
| 764 |
+
shuffle=False,
|
| 765 |
+
collate_fn=partial(collate_fn_pointnet),
|
| 766 |
+
num_workers=0,
|
| 767 |
+
pin_memory=True,
|
| 768 |
+
# prefetch_factor=4,
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
def _check_weights_for_nan_inf(self, model: torch.nn.Module, model_name: str) -> bool:
|
| 772 |
+
"""
|
| 773 |
+
检查模型的所有参数中是否存在 NaN 或 Inf 值。
|
| 774 |
+
|
| 775 |
+
Args:
|
| 776 |
+
model (torch.nn.Module): 要检查的模型。
|
| 777 |
+
model_name (str): 模型的名称,用于打印日志。
|
| 778 |
+
|
| 779 |
+
Returns:
|
| 780 |
+
bool: 如果找到 NaN 或 Inf,则返回 True,否则返回 False。
|
| 781 |
+
"""
|
| 782 |
+
found_issue = False
|
| 783 |
+
for name, param in model.named_parameters():
|
| 784 |
+
if torch.isnan(param.data).any():
|
| 785 |
+
print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 NaN 值!")
|
| 786 |
+
found_issue = True
|
| 787 |
+
if torch.isinf(param.data).any():
|
| 788 |
+
print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 Inf 值!")
|
| 789 |
+
found_issue = True
|
| 790 |
+
return found_issue
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
|
| 794 |
+
"""
|
| 795 |
+
修改后的函数,确保一对一匹配,并优先匹配最近的点对。
|
| 796 |
+
"""
|
| 797 |
+
pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
|
| 798 |
+
gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
|
| 799 |
+
|
| 800 |
+
pred_total = len(pred_array)
|
| 801 |
+
gt_total = len(gt_array)
|
| 802 |
+
|
| 803 |
+
if pred_total == 0 or gt_total == 0:
|
| 804 |
+
return {
|
| 805 |
+
'recall': 0.0,
|
| 806 |
+
'precision': 0.0,
|
| 807 |
+
'f1': 0.0,
|
| 808 |
+
'matches': 0,
|
| 809 |
+
'pred_count': pred_total,
|
| 810 |
+
'gt_count': gt_total
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
# 依然在预测点上构建KD-Tree,为每个真实点查找最近的预测点
|
| 814 |
+
tree = cKDTree(pred_array)
|
| 815 |
+
dists, pred_idxs = tree.query(gt_array, k=1)
|
| 816 |
+
|
| 817 |
+
# --- 核心修改部分 ---
|
| 818 |
+
|
| 819 |
+
# 1. 创建一个列表,包含 (距离, 真实点索引, 预测点索引)
|
| 820 |
+
# 这样我们就可以按距离对所有可能的匹配进行排序
|
| 821 |
+
possible_matches = []
|
| 822 |
+
for gt_idx, (dist, pred_idx) in enumerate(zip(dists, pred_idxs)):
|
| 823 |
+
if dist <= threshold:
|
| 824 |
+
possible_matches.append((dist, gt_idx, pred_idx))
|
| 825 |
+
|
| 826 |
+
# 2. 按距离从小到大排序(贪心策略)
|
| 827 |
+
possible_matches.sort(key=lambda x: x[0])
|
| 828 |
+
|
| 829 |
+
matches = 0
|
| 830 |
+
# 使用集合来跟踪已经使用过的预测点和真实点,确保一对一匹配
|
| 831 |
+
used_pred_indices = set()
|
| 832 |
+
used_gt_indices = set() # 虽然当前逻辑下gt不会重复,但加上更严谨
|
| 833 |
+
|
| 834 |
+
# 3. 遍历排序后的可能匹配,进行一对一分配
|
| 835 |
+
for dist, gt_idx, pred_idx in possible_matches:
|
| 836 |
+
# 如果这个预测点和这个真实点都还没有被使用过
|
| 837 |
+
if pred_idx not in used_pred_indices and gt_idx not in used_gt_indices:
|
| 838 |
+
matches += 1
|
| 839 |
+
used_pred_indices.add(pred_idx)
|
| 840 |
+
used_gt_indices.add(gt_idx)
|
| 841 |
+
|
| 842 |
+
# --- 修改结束 ---
|
| 843 |
+
|
| 844 |
+
# matches 现在是真正的 True Positives 数量,它绝不会超过 pred_total 或 gt_total
|
| 845 |
+
recall = matches / gt_total if gt_total > 0 else 0.0
|
| 846 |
+
precision = matches / pred_total if pred_total > 0 else 0.0
|
| 847 |
+
|
| 848 |
+
# 计算F1时,使用标准的 Precision 和 Recall 定义
|
| 849 |
+
if (precision + recall) == 0:
|
| 850 |
+
f1 = 0.0
|
| 851 |
+
else:
|
| 852 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 853 |
+
|
| 854 |
+
return {
|
| 855 |
+
'recall': recall,
|
| 856 |
+
'precision': precision,
|
| 857 |
+
'f1': f1,
|
| 858 |
+
'matches': matches,
|
| 859 |
+
'pred_count': pred_total,
|
| 860 |
+
'gt_count': gt_total
|
| 861 |
+
}
|
| 862 |
+
|
| 863 |
+
def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
|
| 864 |
+
"""
|
| 865 |
+
一个折衷的顶点指标计算方案。
|
| 866 |
+
它沿用“为每个真实点寻找最近预测点”的逻辑,
|
| 867 |
+
但通过修正计算方式,确保Precision和F1值不会超过1.0。
|
| 868 |
+
"""
|
| 869 |
+
# 假设 pred_coords 和 gt_coords 是 PyTorch 张量
|
| 870 |
+
pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
|
| 871 |
+
gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
|
| 872 |
+
|
| 873 |
+
pred_total = len(pred_array)
|
| 874 |
+
gt_total = len(gt_array)
|
| 875 |
+
|
| 876 |
+
if pred_total == 0 or gt_total == 0:
|
| 877 |
+
return {
|
| 878 |
+
'recall': 0.0,
|
| 879 |
+
'precision': 0.0,
|
| 880 |
+
'f1': 0.0,
|
| 881 |
+
'matches': 0,
|
| 882 |
+
'pred_count': pred_total,
|
| 883 |
+
'gt_count': gt_total
|
| 884 |
+
}
|
| 885 |
+
|
| 886 |
+
# 在预测点上构建KD-Tree,为每个真实点查找最近的预测点
|
| 887 |
+
tree = cKDTree(pred_array)
|
| 888 |
+
dists, _ = tree.query(gt_array, k=1) # 我们在这里其实不需要 pred 的索引
|
| 889 |
+
|
| 890 |
+
# 1. 计算从 gt 角度出发的匹配数 (True Positives for Recall)
|
| 891 |
+
# 这和您的第一个函数完全一样。
|
| 892 |
+
# 这个值代表了“有多少个真实点被成功找到了”。
|
| 893 |
+
matches_from_gt = np.sum(dists <= threshold)
|
| 894 |
+
|
| 895 |
+
# 2. 计算 Recall (召回率)
|
| 896 |
+
# 召回率的分母是真实点的总数,所以这里的计算是合理的。
|
| 897 |
+
recall = matches_from_gt / gt_total if gt_total > 0 else 0.0
|
| 898 |
+
|
| 899 |
+
# 3. 计算 Precision (精确率) - ✅ 这是核心修正点
|
| 900 |
+
# 精确率的分母是预测点的总数。
|
| 901 |
+
# 分子(True Positives)不能超过预测点的总数。
|
| 902 |
+
# 因此,我们取 matches_from_gt 和 pred_total 中的较小值。
|
| 903 |
+
# 这解决了 Precision > 1 的问题。
|
| 904 |
+
tp_for_precision = min(matches_from_gt, pred_total)
|
| 905 |
+
precision = tp_for_precision / pred_total if pred_total > 0 else 0.0
|
| 906 |
+
|
| 907 |
+
# 4. 使用标准的F1分数公式
|
| 908 |
+
# 您原来的 F1 公式 `2 * matches / (pred + gt)` 是 L1-Score,
|
| 909 |
+
# 更常用的是基于 Precision 和 Recall 的调和平均数。
|
| 910 |
+
if (precision + recall) == 0:
|
| 911 |
+
f1 = 0.0
|
| 912 |
+
else:
|
| 913 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 914 |
+
|
| 915 |
+
return {
|
| 916 |
+
'recall': recall,
|
| 917 |
+
'precision': precision,
|
| 918 |
+
'f1': f1,
|
| 919 |
+
'matches': matches_from_gt, # 仍然报告原始的匹配数,便于观察
|
| 920 |
+
'pred_count': pred_total,
|
| 921 |
+
'gt_count': gt_total
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
def _compute_chamfer_distance(self, p1: torch.Tensor, p2: torch.Tensor, one_sided: bool = False):
|
| 925 |
+
if len(p1) == 0 or len(p2) == 0:
|
| 926 |
+
return float('nan')
|
| 927 |
+
|
| 928 |
+
dist_p1_p2 = torch.min(torch.cdist(p1, p2), dim=1)[0].mean()
|
| 929 |
+
|
| 930 |
+
if one_sided:
|
| 931 |
+
return dist_p1_p2.item()
|
| 932 |
+
else:
|
| 933 |
+
dist_p2_p1 = torch.min(torch.cdist(p2, p1), dim=1)[0].mean()
|
| 934 |
+
return (dist_p1_p2 + dist_p2_p1).item() / 2
|
| 935 |
+
|
| 936 |
+
def visualize_latent_space_pca(self, sample_idx: int):
|
| 937 |
+
"""
|
| 938 |
+
Encodes a sample, performs PCA on its latent features, and saves a
|
| 939 |
+
colored PLY file for visualization.
|
| 940 |
+
|
| 941 |
+
The position of each point in the PLY file corresponds to the spatial
|
| 942 |
+
location in the latent grid.
|
| 943 |
+
|
| 944 |
+
The color of each point represents the first three principal components
|
| 945 |
+
of its feature vector.
|
| 946 |
+
"""
|
| 947 |
+
print(f"--- Starting Latent Space PCA Visualization for Sample {sample_idx} ---")
|
| 948 |
+
self.vae.eval()
|
| 949 |
+
|
| 950 |
+
try:
|
| 951 |
+
# 1. Get the latent representation for the sample
|
| 952 |
+
latent = self._get_latent_for_sample(sample_idx)
|
| 953 |
+
except ValueError as e:
|
| 954 |
+
print(f"Error: {e}")
|
| 955 |
+
return
|
| 956 |
+
|
| 957 |
+
latent_coords = latent.coords.detach().cpu().numpy()
|
| 958 |
+
latent_feats = latent.feats.detach().cpu().numpy()
|
| 959 |
+
|
| 960 |
+
if latent_feats.shape[0] < 3:
|
| 961 |
+
print(f"Warning: Not enough latent points ({latent_feats.shape[0]}) to perform PCA. Skipping.")
|
| 962 |
+
return
|
| 963 |
+
|
| 964 |
+
print(f"--> Performing PCA on {latent_feats.shape[0]} latent vectors of dimension {latent_feats.shape[1]}...")
|
| 965 |
+
|
| 966 |
+
# 2. Perform PCA to reduce feature dimensions to 3
|
| 967 |
+
pca = PCA(n_components=3)
|
| 968 |
+
pca_features = pca.fit_transform(latent_feats)
|
| 969 |
+
|
| 970 |
+
print(f" Explained variance ratio by 3 components: {pca.explained_variance_ratio_}")
|
| 971 |
+
print(f" Total explained variance: {np.sum(pca.explained_variance_ratio_):.4f}")
|
| 972 |
+
|
| 973 |
+
# 3. Normalize the PCA components to be used as RGB colors [0, 255]
|
| 974 |
+
# We normalize each component independently to maximize color contrast
|
| 975 |
+
normalized_colors = np.zeros_like(pca_features)
|
| 976 |
+
for i in range(3):
|
| 977 |
+
min_val = pca_features[:, i].min()
|
| 978 |
+
max_val = pca_features[:, i].max()
|
| 979 |
+
if max_val - min_val > 1e-6:
|
| 980 |
+
normalized_colors[:, i] = (pca_features[:, i] - min_val) / (max_val - min_val)
|
| 981 |
+
else:
|
| 982 |
+
normalized_colors[:, i] = 0.5 # Handle case of constant value
|
| 983 |
+
|
| 984 |
+
colors_uint8 = (normalized_colors * 255).astype(np.uint8)
|
| 985 |
+
|
| 986 |
+
# 4. Prepare spatial coordinates for the point cloud
|
| 987 |
+
# latent_coords is (batch_idx, x, y, z), we want the xyz part
|
| 988 |
+
spatial_coords = latent_coords[:, 1:]
|
| 989 |
+
|
| 990 |
+
# 5. Create and save the colored PLY file
|
| 991 |
+
try:
|
| 992 |
+
# Create a Trimesh PointCloud object
|
| 993 |
+
point_cloud = trimesh.points.PointCloud(vertices=spatial_coords, colors=colors_uint8)
|
| 994 |
+
|
| 995 |
+
# Define the output filename
|
| 996 |
+
filename = f"sample_{sample_idx}_latent_pca.ply"
|
| 997 |
+
ply_path = os.path.join(self.output_voxel_dir, filename)
|
| 998 |
+
|
| 999 |
+
# Export the file
|
| 1000 |
+
point_cloud.export(ply_path)
|
| 1001 |
+
print(f"--> Successfully saved PCA visualization to: {ply_path}")
|
| 1002 |
+
|
| 1003 |
+
except Exception as e:
|
| 1004 |
+
print(f"Error during Trimesh export: {e}")
|
| 1005 |
+
print("Please ensure 'trimesh' is installed correctly.")
|
| 1006 |
+
|
| 1007 |
+
def _get_latent_for_sample(self, sample_idx: int) -> SparseTensor:
|
| 1008 |
+
"""
|
| 1009 |
+
Encodes a single sample and returns its latent representation.
|
| 1010 |
+
"""
|
| 1011 |
+
print(f"--> Encoding sample {sample_idx} to get its latent vector...")
|
| 1012 |
+
# Get data for the specified sample
|
| 1013 |
+
batch_data = self.dataset[sample_idx]
|
| 1014 |
+
if batch_data is None:
|
| 1015 |
+
raise ValueError(f"Sample at index {sample_idx} could not be loaded.")
|
| 1016 |
+
|
| 1017 |
+
# Use the collate function to form a batch
|
| 1018 |
+
batch_data = collate_fn_pointnet([batch_data])
|
| 1019 |
+
|
| 1020 |
+
with torch.no_grad():
|
| 1021 |
+
active_coords = batch_data['active_voxels_128'].to(self.device)
|
| 1022 |
+
point_cloud = batch_data['point_cloud_128'].to(self.device)
|
| 1023 |
+
|
| 1024 |
+
active_voxel_feats = self.voxel_encoder(
|
| 1025 |
+
p=point_cloud,
|
| 1026 |
+
sparse_coords=active_coords,
|
| 1027 |
+
res=128,
|
| 1028 |
+
bbox_size=(-0.5, 0.5),
|
| 1029 |
+
)
|
| 1030 |
+
|
| 1031 |
+
sparse_input = SparseTensor(
|
| 1032 |
+
feats=active_voxel_feats,
|
| 1033 |
+
coords=active_coords.int()
|
| 1034 |
+
)
|
| 1035 |
+
|
| 1036 |
+
# 2. Encode to get the latent representation
|
| 1037 |
+
latent_128, posterior = self.vae.encode(sparse_input, sample_posterior=True,)
|
| 1038 |
+
print(f" Latent for sample {sample_idx} obtained. Shape: {latent_128.feats.shape}")
|
| 1039 |
+
return latent_128
|
| 1040 |
+
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
def evaluate(self, num_samples=None, visualize=False, chamfer_threshold=0.9, threshold=1.):
|
| 1044 |
+
total_samples = len(self.dataset)
|
| 1045 |
+
eval_samples = min(num_samples or total_samples, total_samples)
|
| 1046 |
+
sample_indices = random.sample(range(total_samples), eval_samples) if num_samples else range(total_samples)
|
| 1047 |
+
# sample_indices = range(eval_samples)
|
| 1048 |
+
|
| 1049 |
+
eval_dataset = Subset(self.dataset, sample_indices)
|
| 1050 |
+
eval_loader = DataLoader(
|
| 1051 |
+
eval_dataset,
|
| 1052 |
+
batch_size=1,
|
| 1053 |
+
shuffle=False,
|
| 1054 |
+
collate_fn=partial(collate_fn_pointnet),
|
| 1055 |
+
num_workers=self.config['training']['num_workers'],
|
| 1056 |
+
pin_memory=True,
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
per_sample_metrics = {
|
| 1060 |
+
'vertex': {res: [] for res in [128, 256, 512]},
|
| 1061 |
+
'edge': {res: [] for res in [128, 256, 512]},
|
| 1062 |
+
'sample_names': []
|
| 1063 |
+
}
|
| 1064 |
+
avg_metrics = {
|
| 1065 |
+
'vertex': {res: defaultdict(list) for res in [128, 256, 512]},
|
| 1066 |
+
'edge': {res: defaultdict(list) for res in [128, 256, 512]},
|
| 1067 |
+
}
|
| 1068 |
+
|
| 1069 |
+
self.vae.eval()
|
| 1070 |
+
|
| 1071 |
+
for batch_idx, batch_data in enumerate(tqdm(eval_loader, desc="Evaluating")):
|
| 1072 |
+
if batch_data is None:
|
| 1073 |
+
continue
|
| 1074 |
+
sample_idx = sample_indices[batch_idx]
|
| 1075 |
+
sample_name = f'sample_{sample_idx}'
|
| 1076 |
+
per_sample_metrics['sample_names'].append(sample_name)
|
| 1077 |
+
|
| 1078 |
+
# batch_save_path = f"/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/215000_sample_active_vis_42seed_1000complex/gt_data_batch_{batch_idx}.pt"
|
| 1079 |
+
# if not os.path.exists(batch_save_path):
|
| 1080 |
+
# print(f"Warning: Saved batch file not found: {batch_save_path}")
|
| 1081 |
+
# continue
|
| 1082 |
+
# batch_data = torch.load(batch_save_path, map_location=self.device)
|
| 1083 |
+
|
| 1084 |
+
with torch.no_grad():
|
| 1085 |
+
# 1. Get input data
|
| 1086 |
+
combined_voxels_512 = batch_data['combined_voxels_512'].to(self.device)
|
| 1087 |
+
combined_voxel_labels_512 = batch_data['combined_voxel_labels_512'].to(self.device)
|
| 1088 |
+
gt_combined_endpoints_512 = batch_data['gt_combined_endpoints_512'].to(self.device)
|
| 1089 |
+
gt_combined_errors_512 = batch_data['gt_combined_errors_512'].to(self.device)
|
| 1090 |
+
|
| 1091 |
+
edge_mask = (combined_voxel_labels_512 == 1)
|
| 1092 |
+
|
| 1093 |
+
gt_edge_endpoints_512 = gt_combined_endpoints_512[edge_mask].to(self.device)
|
| 1094 |
+
|
| 1095 |
+
gt_edge_voxels_512 = combined_voxels_512[edge_mask].to(self.device)
|
| 1096 |
+
|
| 1097 |
+
p1 = gt_edge_endpoints_512[:, 1:4].float()
|
| 1098 |
+
p2 = gt_edge_endpoints_512[:, 4:7].float()
|
| 1099 |
+
|
| 1100 |
+
mask = ( (p1[:,0] < p2[:,0]) |
|
| 1101 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
|
| 1102 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
|
| 1103 |
+
|
| 1104 |
+
pA = torch.where(mask[:, None], p1, p2) # smaller one
|
| 1105 |
+
pB = torch.where(mask[:, None], p2, p1) # larger one
|
| 1106 |
+
|
| 1107 |
+
d = pB - pA
|
| 1108 |
+
dir_gt = F.normalize(d, dim=-1, eps=1e-6)
|
| 1109 |
+
|
| 1110 |
+
gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'].to(self.device).int()
|
| 1111 |
+
|
| 1112 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=128)
|
| 1113 |
+
vtx_256 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=256)
|
| 1114 |
+
|
| 1115 |
+
edge_128 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=128)
|
| 1116 |
+
edge_256 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=256)
|
| 1117 |
+
edge_512 = combined_voxels_512
|
| 1118 |
+
|
| 1119 |
+
|
| 1120 |
+
gt_edge_voxels_list = [
|
| 1121 |
+
edge_128,
|
| 1122 |
+
edge_256,
|
| 1123 |
+
edge_512,
|
| 1124 |
+
]
|
| 1125 |
+
|
| 1126 |
+
active_coords = batch_data['active_voxels_128'].to(self.device)
|
| 1127 |
+
point_cloud = batch_data['point_cloud_128'].to(self.device)
|
| 1128 |
+
|
| 1129 |
+
|
| 1130 |
+
active_voxel_feats = self.voxel_encoder(
|
| 1131 |
+
p=point_cloud,
|
| 1132 |
+
sparse_coords=active_coords,
|
| 1133 |
+
res=128,
|
| 1134 |
+
bbox_size=(-0.5, 0.5),
|
| 1135 |
+
)
|
| 1136 |
+
|
| 1137 |
+
sparse_input = SparseTensor(
|
| 1138 |
+
feats=active_voxel_feats,
|
| 1139 |
+
coords=active_coords.int()
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
latent_128, posterior = self.vae.encode(sparse_input)
|
| 1143 |
+
|
| 1144 |
+
|
| 1145 |
+
# load_path = f'/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/215000_sample_active_vis_42seed_1000complex/sample_latent_{batch_idx}.pt'
|
| 1146 |
+
# latent_128 = torch.load(load_path, map_location=self.device)
|
| 1147 |
+
|
| 1148 |
+
print('latent_128.feats.mean()', latent_128.feats.mean(), 'latent_128.feats.std()', latent_128.feats.std())
|
| 1149 |
+
print('posterior.mean', posterior.mean.mean(), 'posterior.std', posterior.std.mean(), 'posterior.var', posterior.var.mean())
|
| 1150 |
+
print('latent_128.coords.shape', latent_128.coords.shape)
|
| 1151 |
+
|
| 1152 |
+
# latent_128 = torch.load(f"/root/Trisf/output_slat_flow_matching/ckpts/1100_chair_sample/110000step_sample/sample_results_samples_{batch_idx}.pt", map_location=self.device)
|
| 1153 |
+
latent_128 = SparseTensor(
|
| 1154 |
+
coords=latent_128.coords,
|
| 1155 |
+
feats=latent_128.feats + 0. * torch.randn_like(latent_128.feats),
|
| 1156 |
+
)
|
| 1157 |
+
|
| 1158 |
+
# self.output_voxel_dir = os.path.dirname(load_path)
|
| 1159 |
+
# self.output_obj_dir = os.path.dirname(load_path)
|
| 1160 |
+
|
| 1161 |
+
# 7. Decoding with separate vertex and edge processing
|
| 1162 |
+
decoded_results = self.vae.decode(
|
| 1163 |
+
latent_128,
|
| 1164 |
+
gt_vertex_voxels_list=[],
|
| 1165 |
+
gt_edge_voxels_list=[],
|
| 1166 |
+
training=False,
|
| 1167 |
+
|
| 1168 |
+
inference_threshold=0.5,
|
| 1169 |
+
vis_last_layer=False,
|
| 1170 |
+
)
|
| 1171 |
+
|
| 1172 |
+
error = 0 # decoded_results[-1]['edge']['predicted_offset_feats']
|
| 1173 |
+
|
| 1174 |
+
if self.config['model'].get('pred_direction', False):
|
| 1175 |
+
pred_dir = decoded_results[-1]['edge']['predicted_direction_feats']
|
| 1176 |
+
zero_mask = (pred_dir == 0).all(dim=1) # [N],True 表示这一行全为0
|
| 1177 |
+
num_zeros = zero_mask.sum().item()
|
| 1178 |
+
print("Number of zero vectors:", num_zeros)
|
| 1179 |
+
|
| 1180 |
+
pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
|
| 1181 |
+
print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
|
| 1182 |
+
print('pred_dir.shape', pred_dir.shape)
|
| 1183 |
+
if pred_edge_coords_3d.shape[-1] == 4:
|
| 1184 |
+
pred_edge_coords_3d = pred_edge_coords_3d[:, 1:]
|
| 1185 |
+
|
| 1186 |
+
save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction.ply")
|
| 1187 |
+
visualize_colored_points_ply(pred_edge_coords_3d, pred_dir, save_pth)
|
| 1188 |
+
|
| 1189 |
+
save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction_gt.ply")
|
| 1190 |
+
visualize_colored_points_ply((gt_edge_voxels_512[:, 1:]), dir_gt, save_pth)
|
| 1191 |
+
|
| 1192 |
+
|
| 1193 |
+
pred_vtx_coords_3d = decoded_results[-1]['vertex']['coords']
|
| 1194 |
+
pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
|
| 1195 |
+
|
| 1196 |
+
|
| 1197 |
+
gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'][:, 1:].to(self.device)
|
| 1198 |
+
gt_edge_voxels_512 = batch_data['gt_edge_voxels_512'][:, 1:].to(self.device)
|
| 1199 |
+
|
| 1200 |
+
|
| 1201 |
+
# Calculate metrics and save results
|
| 1202 |
+
matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_vtx_coords_3d, gt_vertex_voxels_512, threshold=threshold,)
|
| 1203 |
+
print(f"\n----- Resolution {512} vtx -----")
|
| 1204 |
+
print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
|
| 1205 |
+
print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
|
| 1206 |
+
|
| 1207 |
+
self._save_voxel_ply(pred_vtx_coords_3d / 512., torch.zeros(len(pred_vtx_coords_3d)), f"{sample_name}_pred_vtx")
|
| 1208 |
+
self._save_voxel_ply((pred_edge_coords_3d) / 512, torch.zeros(len(pred_edge_coords_3d)), f"{sample_name}_pred_edge")
|
| 1209 |
+
|
| 1210 |
+
self._save_voxel_ply(gt_vertex_voxels_512 / 512, torch.zeros(len(gt_vertex_voxels_512)), f"{sample_name}_gt_vertex")
|
| 1211 |
+
self._save_voxel_ply((combined_voxels_512[:, 1:]) / 512., torch.zeros(len(gt_combined_errors_512)), f"{sample_name}_gt_edge")
|
| 1212 |
+
|
| 1213 |
+
|
| 1214 |
+
# Calculate vertex-specific metrics
|
| 1215 |
+
matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_edge_coords_3d, combined_voxels_512[:, 1:], threshold=threshold,)
|
| 1216 |
+
print(f"\n----- Resolution {512} edge -----")
|
| 1217 |
+
print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
|
| 1218 |
+
print('gt_edge_voxels_512.shape', gt_edge_voxels_512.shape)
|
| 1219 |
+
|
| 1220 |
+
print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
|
| 1221 |
+
print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
|
| 1222 |
+
|
| 1223 |
+
pred_vertex_coords_np = np.round(pred_vtx_coords_3d.cpu().numpy()).astype(int)
|
| 1224 |
+
pred_edges = []
|
| 1225 |
+
gt_vertex_coords_np = np.round(gt_vertex_voxels_512.cpu().numpy()).astype(int)
|
| 1226 |
+
if visualize:
|
| 1227 |
+
if pred_vtx_coords_3d.shape[-1] == 4:
|
| 1228 |
+
pred_vtx_coords_float = pred_vtx_coords_3d[:, 1:].float()
|
| 1229 |
+
else:
|
| 1230 |
+
pred_vtx_coords_float = pred_vtx_coords_3d.float()
|
| 1231 |
+
|
| 1232 |
+
pred_vtx_feats = decoded_results[-1]['vertex']['feats']
|
| 1233 |
+
|
| 1234 |
+
# ==========================================
|
| 1235 |
+
# Link Prediction & Mesh Generation
|
| 1236 |
+
# ==========================================
|
| 1237 |
+
print("Predicting connectivity...")
|
| 1238 |
+
|
| 1239 |
+
# 1. 预测边
|
| 1240 |
+
# 注意:K_neighbors 的设置。如果是物体,64 足够了。
|
| 1241 |
+
# 如果点非常稀疏,可能需要更大。
|
| 1242 |
+
pred_edges = predict_mesh_connectivity(
|
| 1243 |
+
connection_head=self.connection_head, # 或者是 self.connection_head,取决于你在哪里定义的
|
| 1244 |
+
vtx_feats=pred_vtx_feats,
|
| 1245 |
+
vtx_coords=pred_vtx_coords_float,
|
| 1246 |
+
batch_size=4096,
|
| 1247 |
+
threshold=0.5,
|
| 1248 |
+
k_neighbors=None,
|
| 1249 |
+
device=self.device
|
| 1250 |
+
)
|
| 1251 |
+
print(f"Predicted {len(pred_edges)} edges.")
|
| 1252 |
+
|
| 1253 |
+
# 2. 构建三角形
|
| 1254 |
+
num_verts = pred_vtx_coords_float.shape[0]
|
| 1255 |
+
pred_faces = build_triangles_from_edges(pred_edges, num_verts)
|
| 1256 |
+
print(f"Constructed {len(pred_faces)} triangles.")
|
| 1257 |
+
|
| 1258 |
+
# 3. 保存 OBJ
|
| 1259 |
+
import trimesh
|
| 1260 |
+
|
| 1261 |
+
# 坐标归一化/还原 (根据你的需求,这里假设你是 0-512 的体素坐标)
|
| 1262 |
+
# 如果想保存为归一化坐标:
|
| 1263 |
+
mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0
|
| 1264 |
+
|
| 1265 |
+
# 如果有 error offset,记得加上!
|
| 1266 |
+
# 你之前的代码好像没有对 vertex 加 offset,只对 edge 加了
|
| 1267 |
+
# 如果 vertex 也有 offset (如 dual contouring),在这里加上
|
| 1268 |
+
|
| 1269 |
+
# 移动到中心 (可选)
|
| 1270 |
+
mesh_verts = mesh_verts - 0.5
|
| 1271 |
+
|
| 1272 |
+
mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
|
| 1273 |
+
|
| 1274 |
+
# 过滤孤立点 (可选)
|
| 1275 |
+
# mesh.remove_unreferenced_vertices()
|
| 1276 |
+
|
| 1277 |
+
output_obj_path = os.path.join(self.output_voxel_dir, f"{sample_name}_recon.obj")
|
| 1278 |
+
mesh.export(output_obj_path)
|
| 1279 |
+
print(f"Saved mesh to {output_obj_path}")
|
| 1280 |
+
|
| 1281 |
+
# 保存边线 (用于 Debug)
|
| 1282 |
+
# 有时候三角形很难形成,只看边也很有用
|
| 1283 |
+
edges_path = os.path.join(self.output_voxel_dir, f"{sample_name}_edges.ply")
|
| 1284 |
+
# self._visualize_vertices(pred_edge_coords_np, gt_edge_coords_np, f"{sample_name}_edge_comparison")
|
| 1285 |
+
|
| 1286 |
+
|
| 1287 |
+
# Process results at different resolutions
|
| 1288 |
+
for i, res in enumerate([128, 256, 512]):
|
| 1289 |
+
if i >= len(decoded_results):
|
| 1290 |
+
continue
|
| 1291 |
+
|
| 1292 |
+
gt_key = f'gt_vertex_voxels_{res}'
|
| 1293 |
+
if gt_key not in batch_data:
|
| 1294 |
+
continue
|
| 1295 |
+
if i == 0:
|
| 1296 |
+
pred_coords_res = decoded_results[i]['vtx_sp'].coords[:, 1:].float()
|
| 1297 |
+
gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
|
| 1298 |
+
else:
|
| 1299 |
+
pred_coords_res = decoded_results[i]['vertex']['coords'].float()
|
| 1300 |
+
gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
|
| 1301 |
+
|
| 1302 |
+
|
| 1303 |
+
v_metrics = self._compute_vertex_metrics(pred_coords_res, gt_coords_res, threshold=threshold)
|
| 1304 |
+
|
| 1305 |
+
per_sample_metrics['vertex'][res].append({
|
| 1306 |
+
'recall': v_metrics['recall'],
|
| 1307 |
+
'precision': v_metrics['precision'],
|
| 1308 |
+
'f1': v_metrics['f1'],
|
| 1309 |
+
'num_pred': len(pred_coords_res),
|
| 1310 |
+
'num_gt': len(gt_coords_res)
|
| 1311 |
+
})
|
| 1312 |
+
|
| 1313 |
+
avg_metrics['vertex'][res]['recall'].append(v_metrics['recall'])
|
| 1314 |
+
avg_metrics['vertex'][res]['precision'].append(v_metrics['precision'])
|
| 1315 |
+
avg_metrics['vertex'][res]['f1'].append(v_metrics['f1'])
|
| 1316 |
+
|
| 1317 |
+
gt_edge_key = f'gt_edge_voxels_{res}'
|
| 1318 |
+
if gt_edge_key not in batch_data:
|
| 1319 |
+
continue
|
| 1320 |
+
|
| 1321 |
+
if i == 0:
|
| 1322 |
+
pred_edge_coords_res = decoded_results[i]['edge_sp'].coords[:, 1:].float()
|
| 1323 |
+
# gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
|
| 1324 |
+
idx = i
|
| 1325 |
+
gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
|
| 1326 |
+
elif i == 1:
|
| 1327 |
+
idx = i
|
| 1328 |
+
#################################
|
| 1329 |
+
# pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() - error / 2. + 0.5
|
| 1330 |
+
# # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
|
| 1331 |
+
# gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) - gt_combined_errors_512[:, 1:].to(self.device) + 0.5
|
| 1332 |
+
|
| 1333 |
+
|
| 1334 |
+
pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
|
| 1335 |
+
gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
|
| 1336 |
+
|
| 1337 |
+
# self._save_voxel_ply(gt_edge_voxels_list[idx][:, 1:].float().to(self.device) / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res_wooffset")
|
| 1338 |
+
# self._save_voxel_ply(decoded_results[i]['edge']['coords'].float() / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res_wooffset")
|
| 1339 |
+
|
| 1340 |
+
else:
|
| 1341 |
+
idx = i
|
| 1342 |
+
pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
|
| 1343 |
+
# gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
|
| 1344 |
+
gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
|
| 1345 |
+
|
| 1346 |
+
# self._save_voxel_ply(gt_edge_coords_res / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res")
|
| 1347 |
+
# self._save_voxel_ply(pred_edge_coords_res / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res")
|
| 1348 |
+
|
| 1349 |
+
e_metrics = self._compute_vertex_metrics(pred_edge_coords_res, gt_edge_coords_res, threshold=threshold)
|
| 1350 |
+
|
| 1351 |
+
per_sample_metrics['edge'][res].append({
|
| 1352 |
+
'recall': e_metrics['recall'],
|
| 1353 |
+
'precision': e_metrics['precision'],
|
| 1354 |
+
'f1': e_metrics['f1'],
|
| 1355 |
+
'num_pred': len(pred_edge_coords_res),
|
| 1356 |
+
'num_gt': len(gt_edge_coords_res)
|
| 1357 |
+
})
|
| 1358 |
+
|
| 1359 |
+
avg_metrics['edge'][res]['recall'].append(e_metrics['recall'])
|
| 1360 |
+
avg_metrics['edge'][res]['precision'].append(e_metrics['precision'])
|
| 1361 |
+
avg_metrics['edge'][res]['f1'].append(e_metrics['f1'])
|
| 1362 |
+
|
| 1363 |
+
avg_metrics_processed = {}
|
| 1364 |
+
for category, res_dict in avg_metrics.items():
|
| 1365 |
+
avg_metrics_processed[category] = {}
|
| 1366 |
+
for resolution, metric_dict in res_dict.items():
|
| 1367 |
+
avg_metrics_processed[category][resolution] = {
|
| 1368 |
+
metric_name: np.mean(values) if values else float('nan')
|
| 1369 |
+
for metric_name, values in metric_dict.items()
|
| 1370 |
+
}
|
| 1371 |
+
|
| 1372 |
+
result_data = {
|
| 1373 |
+
'config': self.config,
|
| 1374 |
+
'checkpoint': self.ckpt_path,
|
| 1375 |
+
'dataset': self.dataset_path,
|
| 1376 |
+
'num_samples': eval_samples,
|
| 1377 |
+
'per_sample_metrics': per_sample_metrics,
|
| 1378 |
+
'avg_metrics': avg_metrics_processed
|
| 1379 |
+
}
|
| 1380 |
+
|
| 1381 |
+
results_file_path = os.path.join(self.result_dir, f"evaluation_results_epoch{self.epoch}.yaml")
|
| 1382 |
+
with open(results_file_path, 'w') as f:
|
| 1383 |
+
yaml.dump(result_data, f, default_flow_style=False)
|
| 1384 |
+
|
| 1385 |
+
return result_data
|
| 1386 |
+
|
| 1387 |
+
def _generate_line_voxels(
|
| 1388 |
+
self,
|
| 1389 |
+
p1: torch.Tensor,
|
| 1390 |
+
p2: torch.Tensor
|
| 1391 |
+
) -> Tuple[
|
| 1392 |
+
List[Tuple[int, int, int]],
|
| 1393 |
+
List[Tuple[torch.Tensor, torch.Tensor]],
|
| 1394 |
+
List[np.ndarray]
|
| 1395 |
+
]:
|
| 1396 |
+
"""
|
| 1397 |
+
Improved version using better sampling strategy
|
| 1398 |
+
"""
|
| 1399 |
+
p1_np = p1 #.cpu().numpy()
|
| 1400 |
+
p2_np = p2 #.cpu().numpy()
|
| 1401 |
+
voxel_dict = OrderedDict()
|
| 1402 |
+
|
| 1403 |
+
# Use proper 3D line voxelization algorithm
|
| 1404 |
+
def bresenham_3d(p1, p2):
|
| 1405 |
+
"""3D Bresenham's line algorithm"""
|
| 1406 |
+
x1, y1, z1 = np.round(p1).astype(int)
|
| 1407 |
+
x2, y2, z2 = np.round(p2).astype(int)
|
| 1408 |
+
|
| 1409 |
+
points = []
|
| 1410 |
+
dx = abs(x2 - x1)
|
| 1411 |
+
dy = abs(y2 - y1)
|
| 1412 |
+
dz = abs(z2 - z1)
|
| 1413 |
+
|
| 1414 |
+
xs = 1 if x2 > x1 else -1
|
| 1415 |
+
ys = 1 if y2 > y1 else -1
|
| 1416 |
+
zs = 1 if z2 > z1 else -1
|
| 1417 |
+
|
| 1418 |
+
# Driving axis is X
|
| 1419 |
+
if dx >= dy and dx >= dz:
|
| 1420 |
+
err_1 = 2 * dy - dx
|
| 1421 |
+
err_2 = 2 * dz - dx
|
| 1422 |
+
for i in range(dx + 1):
|
| 1423 |
+
points.append((x1, y1, z1))
|
| 1424 |
+
if err_1 > 0:
|
| 1425 |
+
y1 += ys
|
| 1426 |
+
err_1 -= 2 * dx
|
| 1427 |
+
if err_2 > 0:
|
| 1428 |
+
z1 += zs
|
| 1429 |
+
err_2 -= 2 * dx
|
| 1430 |
+
err_1 += 2 * dy
|
| 1431 |
+
err_2 += 2 * dz
|
| 1432 |
+
x1 += xs
|
| 1433 |
+
|
| 1434 |
+
# Driving axis is Y
|
| 1435 |
+
elif dy >= dx and dy >= dz:
|
| 1436 |
+
err_1 = 2 * dx - dy
|
| 1437 |
+
err_2 = 2 * dz - dy
|
| 1438 |
+
for i in range(dy + 1):
|
| 1439 |
+
points.append((x1, y1, z1))
|
| 1440 |
+
if err_1 > 0:
|
| 1441 |
+
x1 += xs
|
| 1442 |
+
err_1 -= 2 * dy
|
| 1443 |
+
if err_2 > 0:
|
| 1444 |
+
z1 += zs
|
| 1445 |
+
err_2 -= 2 * dy
|
| 1446 |
+
err_1 += 2 * dx
|
| 1447 |
+
err_2 += 2 * dz
|
| 1448 |
+
y1 += ys
|
| 1449 |
+
|
| 1450 |
+
# Driving axis is Z
|
| 1451 |
+
else:
|
| 1452 |
+
err_1 = 2 * dx - dz
|
| 1453 |
+
err_2 = 2 * dy - dz
|
| 1454 |
+
for i in range(dz + 1):
|
| 1455 |
+
points.append((x1, y1, z1))
|
| 1456 |
+
if err_1 > 0:
|
| 1457 |
+
x1 += xs
|
| 1458 |
+
err_1 -= 2 * dz
|
| 1459 |
+
if err_2 > 0:
|
| 1460 |
+
y1 += ys
|
| 1461 |
+
err_2 -= 2 * dz
|
| 1462 |
+
err_1 += 2 * dx
|
| 1463 |
+
err_2 += 2 * dy
|
| 1464 |
+
z1 += zs
|
| 1465 |
+
|
| 1466 |
+
return points
|
| 1467 |
+
|
| 1468 |
+
# Get all voxels using Bresenham algorithm
|
| 1469 |
+
voxel_coords = bresenham_3d(p1_np, p2_np)
|
| 1470 |
+
|
| 1471 |
+
# Add all voxels to dictionary
|
| 1472 |
+
for coord in voxel_coords:
|
| 1473 |
+
voxel_dict[tuple(coord)] = (p1, p2)
|
| 1474 |
+
|
| 1475 |
+
voxel_coords = list(voxel_dict.keys())
|
| 1476 |
+
endpoint_pairs = list(voxel_dict.values())
|
| 1477 |
+
|
| 1478 |
+
# --- compute error vectors ---
|
| 1479 |
+
error_vectors = []
|
| 1480 |
+
diff = p2_np - p1_np
|
| 1481 |
+
d_norm_sq = np.dot(diff, diff)
|
| 1482 |
+
|
| 1483 |
+
for v in voxel_coords:
|
| 1484 |
+
v_center = np.array(v, dtype=float) + 0.5
|
| 1485 |
+
if d_norm_sq == 0: # degenerate line
|
| 1486 |
+
closest = p1_np
|
| 1487 |
+
else:
|
| 1488 |
+
t = np.dot(v_center - p1_np, diff) / d_norm_sq
|
| 1489 |
+
t = np.clip(t, 0.0, 1.0)
|
| 1490 |
+
closest = p1_np + t * diff
|
| 1491 |
+
error_vectors.append(v_center - closest)
|
| 1492 |
+
|
| 1493 |
+
return voxel_coords, endpoint_pairs, error_vectors
|
| 1494 |
+
|
| 1495 |
+
|
| 1496 |
+
# 使用示例
|
| 1497 |
+
def set_seed(seed: int):
|
| 1498 |
+
random.seed(seed)
|
| 1499 |
+
np.random.seed(seed)
|
| 1500 |
+
torch.manual_seed(seed)
|
| 1501 |
+
if torch.cuda.is_available():
|
| 1502 |
+
torch.cuda.manual_seed(seed)
|
| 1503 |
+
torch.cuda.manual_seed_all(seed)
|
| 1504 |
+
torch.backends.cudnn.deterministic = True
|
| 1505 |
+
torch.backends.cudnn.benchmark = False
|
| 1506 |
+
|
| 1507 |
+
def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
|
| 1508 |
+
set_seed(42)
|
| 1509 |
+
tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
|
| 1510 |
+
result_data = tester.evaluate(num_samples=NUM_SAMPLES, visualize=VISUALIZE, chamfer_threshold=CHAMFER_EDGE_THRESHOLD, threshold=THRESHOLD)
|
| 1511 |
+
|
| 1512 |
+
# 生成文件名
|
| 1513 |
+
epoch_str = os.path.basename(ckpt_path).split('_')[1].split('.')[0]
|
| 1514 |
+
dataset_name = os.path.basename(os.path.normpath(dataset_path))
|
| 1515 |
+
|
| 1516 |
+
# 保存简版报告(TXT)
|
| 1517 |
+
summary_path = os.path.join(eval_dir, f"epoch{epoch_str}_{dataset_name}_summary_threshold{THRESHOLD}_one2one.txt")
|
| 1518 |
+
with open(summary_path, 'w') as f:
|
| 1519 |
+
# 头部信息
|
| 1520 |
+
f.write(f"Checkpoint: {os.path.basename(ckpt_path)}\n")
|
| 1521 |
+
f.write(f"Dataset: {dataset_name}\n")
|
| 1522 |
+
f.write(f"Evaluation Samples: {result_data['num_samples']}\n\n")
|
| 1523 |
+
|
| 1524 |
+
# 平均指标
|
| 1525 |
+
f.write("=== Average Metrics ===\n")
|
| 1526 |
+
for category, data in result_data['avg_metrics'].items():
|
| 1527 |
+
if isinstance(data, dict): # 处理多分辨率情况
|
| 1528 |
+
f.write(f"\n{category.upper()}:\n")
|
| 1529 |
+
for res, metrics in data.items():
|
| 1530 |
+
f.write(f" Resolution {res}:\n")
|
| 1531 |
+
for k, v in metrics.items():
|
| 1532 |
+
# 确保值是数字类型后再格式化
|
| 1533 |
+
if isinstance(v, (int, float)):
|
| 1534 |
+
f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
|
| 1535 |
+
else:
|
| 1536 |
+
f.write(f" {str(k).ljust(15)}: {str(v)}\n")
|
| 1537 |
+
else: # 处理非多分辨率情况
|
| 1538 |
+
f.write(f"\n{category.upper()}:\n")
|
| 1539 |
+
for k, v in data.items():
|
| 1540 |
+
if isinstance(v, (int, float)):
|
| 1541 |
+
f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
|
| 1542 |
+
else:
|
| 1543 |
+
f.write(f" {str(k).ljust(15)}: {str(v)}\n")
|
| 1544 |
+
|
| 1545 |
+
# 样本级详细统计
|
| 1546 |
+
f.write("\n\n=== Detailed Per-Sample Metrics ===\n")
|
| 1547 |
+
for name, vertex_metrics, edge_metrics in zip(
|
| 1548 |
+
result_data['per_sample_metrics']['sample_names'],
|
| 1549 |
+
zip(*[result_data['per_sample_metrics']['vertex'][res] for res in [128, 256, 512]]),
|
| 1550 |
+
zip(*[result_data['per_sample_metrics']['edge'][res] for res in [128, 256, 512]])
|
| 1551 |
+
):
|
| 1552 |
+
# 样本标题
|
| 1553 |
+
f.write(f"\n◆ Sample: {name}\n")
|
| 1554 |
+
|
| 1555 |
+
# 顶点指标
|
| 1556 |
+
f.write(f"[Vertex Prediction]\n")
|
| 1557 |
+
f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
|
| 1558 |
+
for res, metrics in zip([128, 256, 512], vertex_metrics):
|
| 1559 |
+
f.write(f" {str(res).ljust(10)} "
|
| 1560 |
+
f"{metrics['recall']:.4f} "
|
| 1561 |
+
f"{metrics['precision']:.4f} "
|
| 1562 |
+
f"{metrics['f1']:.4f} "
|
| 1563 |
+
f"{metrics['num_pred']}/{metrics['num_gt']}\n")
|
| 1564 |
+
|
| 1565 |
+
# Edge指标
|
| 1566 |
+
f.write(f"[Edge Prediction]\n")
|
| 1567 |
+
f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
|
| 1568 |
+
for res, metrics in zip([128, 256, 512], edge_metrics):
|
| 1569 |
+
f.write(f" {str(res).ljust(10)} "
|
| 1570 |
+
f"{metrics['recall']:.4f} "
|
| 1571 |
+
f"{metrics['precision']:.4f} "
|
| 1572 |
+
f"{metrics['f1']:.4f} "
|
| 1573 |
+
f"{metrics['num_pred']}/{metrics['num_gt']}\n")
|
| 1574 |
+
|
| 1575 |
+
f.write("-"*60 + "\n")
|
| 1576 |
+
|
| 1577 |
+
print(f"Saved summary to: {summary_path}")
|
| 1578 |
+
return result_data
|
| 1579 |
+
|
| 1580 |
+
|
| 1581 |
+
if __name__ == '__main__':
|
| 1582 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 1583 |
+
evaluate_all_checkpoints = True # 设置为 True 启用范围过滤
|
| 1584 |
+
EPOCH_START = 0
|
| 1585 |
+
EPOCH_END = 12
|
| 1586 |
+
CHAMFER_EDGE_THRESHOLD=0.5
|
| 1587 |
+
NUM_SAMPLES=20
|
| 1588 |
+
VISUALIZE=True
|
| 1589 |
+
THRESHOLD=1.5
|
| 1590 |
+
VISUAL_FIELD=False
|
| 1591 |
+
|
| 1592 |
+
ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch10433_loss1.2657.pt'
|
| 1593 |
+
ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch2000_loss0.3315.pt'
|
| 1594 |
+
|
| 1595 |
+
dataset_path = '/gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/test'
|
| 1596 |
+
dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized'
|
| 1597 |
+
# dataset_path = '/gemini/user/private/zhaotianhao/data/trellis500k_compress_glb'
|
| 1598 |
+
dataset_path = '/gemini/user/private/zhaotianhao/data/unique_files_glb_under6000face_2degree_30ratio_0.01'
|
| 1599 |
+
|
| 1600 |
+
if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
|
| 1601 |
+
RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
|
| 1602 |
+
else:
|
| 1603 |
+
RENDERS_DIR = ''
|
| 1604 |
+
|
| 1605 |
+
|
| 1606 |
+
ckpt_dir = os.path.dirname(ckpt_path)
|
| 1607 |
+
eval_dir = os.path.join(ckpt_dir, "evaluate")
|
| 1608 |
+
os.makedirs(eval_dir, exist_ok=True)
|
| 1609 |
+
|
| 1610 |
+
if False:
|
| 1611 |
+
for i in range(NUM_SAMPLES):
|
| 1612 |
+
print("--- Starting Latent Space PCA Visualization ---")
|
| 1613 |
+
tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
|
| 1614 |
+
tester.visualize_latent_space_pca(sample_idx=i)
|
| 1615 |
+
print("--- PCA Visualization Finished ---")
|
| 1616 |
+
|
| 1617 |
+
if not evaluate_all_checkpoints:
|
| 1618 |
+
evaluate_checkpoint(ckpt_path, dataset_path, eval_dir)
|
| 1619 |
+
else:
|
| 1620 |
+
pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.endswith('.pt')])
|
| 1621 |
+
|
| 1622 |
+
filtered_pt_files = []
|
| 1623 |
+
for f in pt_files:
|
| 1624 |
+
try:
|
| 1625 |
+
parts = f.split('_')
|
| 1626 |
+
epoch_str = parts[1].replace('epoch', '')
|
| 1627 |
+
epoch = int(epoch_str)
|
| 1628 |
+
if EPOCH_START <= epoch <= EPOCH_END:
|
| 1629 |
+
filtered_pt_files.append(f)
|
| 1630 |
+
except Exception as e:
|
| 1631 |
+
print(f"Warning: Could not parse epoch from {f}: {e}")
|
| 1632 |
+
continue
|
| 1633 |
+
|
| 1634 |
+
for pt_file in filtered_pt_files:
|
| 1635 |
+
full_ckpt_path = os.path.join(ckpt_dir, pt_file)
|
| 1636 |
+
evaluate_checkpoint(full_ckpt_path, dataset_path, eval_dir)
|
train_slat_flow_128to1024_pointnet.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# os.environ['ATTN_BACKEND'] = 'xformers' # xformers is generally compatible with DDP
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import yaml
|
| 6 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 7 |
+
from functools import partial
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.optim import AdamW
|
| 10 |
+
from torch.amp import GradScaler, autocast
|
| 11 |
+
from typing import *
|
| 12 |
+
from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 15 |
+
|
| 16 |
+
# --- Updated Imports based on VAE script ---
|
| 17 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 18 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 19 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 20 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 21 |
+
|
| 22 |
+
from trellis.models.structured_latent_flow import SLatFlowModel
|
| 23 |
+
from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
|
| 24 |
+
from safetensors.torch import load_file
|
| 25 |
+
import torch.multiprocessing as mp
|
| 26 |
+
import open3d as o3d
|
| 27 |
+
from PIL import Image
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
|
| 30 |
+
from triposf.modules.utils import DiagonalGaussianDistribution
|
| 31 |
+
import torchvision.transforms as transforms
|
| 32 |
+
import re
|
| 33 |
+
|
| 34 |
+
# --- Distributed Setup Functions ---
|
| 35 |
+
def setup_distributed(backend="nccl"):
|
| 36 |
+
"""Initializes the distributed environment."""
|
| 37 |
+
if not dist.is_initialized():
|
| 38 |
+
rank = int(os.environ["RANK"])
|
| 39 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 40 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 41 |
+
|
| 42 |
+
torch.cuda.set_device(local_rank)
|
| 43 |
+
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
| 44 |
+
|
| 45 |
+
return int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
|
| 46 |
+
|
| 47 |
+
def cleanup_distributed():
|
| 48 |
+
dist.destroy_process_group()
|
| 49 |
+
|
| 50 |
+
# --- Modified Trainer Class ---
|
| 51 |
+
class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
|
| 52 |
+
def __init__(self, *args, rank: int, local_rank: int, world_size: int, **kwargs):
|
| 53 |
+
super().__init__(*args, **kwargs)
|
| 54 |
+
self.cfg = kwargs.pop('cfg', None)
|
| 55 |
+
if self.cfg is None:
|
| 56 |
+
raise ValueError("Configuration dictionary 'cfg' must be provided.")
|
| 57 |
+
|
| 58 |
+
# --- Distributed-related attributes ---
|
| 59 |
+
self.rank = rank
|
| 60 |
+
self.local_rank = local_rank
|
| 61 |
+
self.world_size = world_size
|
| 62 |
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
| 63 |
+
self.is_master = (self.rank == 0)
|
| 64 |
+
|
| 65 |
+
self.i_save = self.cfg['training']['save_every']
|
| 66 |
+
self.save_dir = self.cfg['training']['output_dir']
|
| 67 |
+
|
| 68 |
+
self.resolution = 128
|
| 69 |
+
self.condition_type = 'image'
|
| 70 |
+
self.is_cond = False
|
| 71 |
+
self.img_res = 518
|
| 72 |
+
|
| 73 |
+
if self.is_master:
|
| 74 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 75 |
+
print(f"Checkpoints and logs will be saved to: {self.save_dir}")
|
| 76 |
+
|
| 77 |
+
# Initialize components and set up for DDP
|
| 78 |
+
self._init_components(
|
| 79 |
+
clip_model_path=self.cfg['training'].get('clip_model_path', None),
|
| 80 |
+
dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None),
|
| 81 |
+
vae_path=self.cfg['training']['vae_path'],
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
self._setup_ddp()
|
| 85 |
+
|
| 86 |
+
self.denoiser_checkpoint_path = self.cfg['training'].get('denoiser_checkpoint_path', None)
|
| 87 |
+
|
| 88 |
+
trainable_params = list(self.denoiser.parameters())
|
| 89 |
+
self.optimizer = AdamW(trainable_params, lr=self.cfg['training'].get('lr', 0.0001), weight_decay=0.0)
|
| 90 |
+
|
| 91 |
+
self.scaler = GradScaler()
|
| 92 |
+
|
| 93 |
+
if self.is_master:
|
| 94 |
+
print("Using Automatic Mixed Precision (AMP) with GradScaler.")
|
| 95 |
+
|
| 96 |
+
def _init_components(self,
|
| 97 |
+
clip_model_path=None,
|
| 98 |
+
dinov2_model_path=None,
|
| 99 |
+
vae_path=None,
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
Initializes VAE, VoxelEncoder (PointNet), and condition models.
|
| 103 |
+
"""
|
| 104 |
+
def load_file_func(path, device='cpu'):
|
| 105 |
+
return torch.load(path, map_location=device)
|
| 106 |
+
|
| 107 |
+
def _load_and_broadcast(model, load_fn=None, path=None, strict=True):
|
| 108 |
+
if self.is_master:
|
| 109 |
+
try:
|
| 110 |
+
state = load_fn(path) if load_fn else model.state_dict()
|
| 111 |
+
except Exception as e:
|
| 112 |
+
raise RuntimeError(f"Failed to load weights from {path}: {e}")
|
| 113 |
+
else:
|
| 114 |
+
state = None
|
| 115 |
+
|
| 116 |
+
dist.barrier()
|
| 117 |
+
state_b = [state] if self.is_master else [None]
|
| 118 |
+
dist.broadcast_object_list(state_b, src=0)
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
# Handle potential key mismatches (e.g. 'module.' prefix)
|
| 122 |
+
model.load_state_dict(state_b[0], strict=strict)
|
| 123 |
+
except Exception as e:
|
| 124 |
+
if self.is_master: print(f"Strict loading failed for {model.__class__.__name__}, trying non-strict: {e}")
|
| 125 |
+
model.load_state_dict(state_b[0], strict=False)
|
| 126 |
+
|
| 127 |
+
# ------------------------- Voxel Encoder (PointNet) -------------------------
|
| 128 |
+
# Matching the VAE script configuration
|
| 129 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 130 |
+
in_channels=15,
|
| 131 |
+
hidden_dim=256,
|
| 132 |
+
out_channels=1024,
|
| 133 |
+
scatter_type='mean',
|
| 134 |
+
n_blocks=5,
|
| 135 |
+
resolution=128,
|
| 136 |
+
add_label=False,
|
| 137 |
+
).to(self.device)
|
| 138 |
+
|
| 139 |
+
# ------------------------- VAE -------------------------
|
| 140 |
+
self.vae = VoxelVAE(
|
| 141 |
+
in_channels=self.cfg['model']['in_channels'],
|
| 142 |
+
latent_dim=self.cfg['model']['latent_dim'],
|
| 143 |
+
encoder_blocks=self.cfg['model']['encoder_blocks'],
|
| 144 |
+
decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
|
| 145 |
+
decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
|
| 146 |
+
num_heads=8,
|
| 147 |
+
num_head_channels=64,
|
| 148 |
+
mlp_ratio=4.0,
|
| 149 |
+
attn_mode="swin",
|
| 150 |
+
window_size=8,
|
| 151 |
+
pe_mode="ape",
|
| 152 |
+
use_fp16=False,
|
| 153 |
+
use_checkpoint=True,
|
| 154 |
+
qk_rms_norm=False,
|
| 155 |
+
using_subdivide=True,
|
| 156 |
+
using_attn=self.cfg['model']['using_attn'],
|
| 157 |
+
attn_first=self.cfg['model'].get('attn_first', True),
|
| 158 |
+
pred_direction=self.cfg['model'].get('pred_direction', False),
|
| 159 |
+
).to(self.device)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ------------------------- Conditioning -------------------------
|
| 163 |
+
if self.condition_type == 'text':
|
| 164 |
+
self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
|
| 165 |
+
if self.is_master:
|
| 166 |
+
self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
|
| 167 |
+
else:
|
| 168 |
+
config = CLIPTextConfig.from_pretrained(clip_model_path)
|
| 169 |
+
self.condition_model = CLIPTextModel(config)
|
| 170 |
+
_load_and_broadcast(self.condition_model)
|
| 171 |
+
|
| 172 |
+
elif self.condition_type == 'image':
|
| 173 |
+
if self.is_master:
|
| 174 |
+
print("Initializing for IMAGE conditioning (DINOv2).")
|
| 175 |
+
|
| 176 |
+
# Update paths as per your environment
|
| 177 |
+
local_repo_path = "/root/Trisf/dinov2_resources/dinov2-main"
|
| 178 |
+
weights_path = "/root/Trisf/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"
|
| 179 |
+
|
| 180 |
+
dinov2_model = torch.hub.load(
|
| 181 |
+
repo_or_dir=local_repo_path,
|
| 182 |
+
model='dinov2_vitl14_reg',
|
| 183 |
+
source='local',
|
| 184 |
+
pretrained=False
|
| 185 |
+
)
|
| 186 |
+
self.condition_model = dinov2_model
|
| 187 |
+
|
| 188 |
+
_load_and_broadcast(self.condition_model, load_fn=torch.load, path=weights_path)
|
| 189 |
+
|
| 190 |
+
self.image_cond_model_transform = transforms.Compose([
|
| 191 |
+
transforms.ToTensor(),
|
| 192 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 193 |
+
])
|
| 194 |
+
else:
|
| 195 |
+
raise ValueError(f"Unsupported condition type: {self.condition_type}")
|
| 196 |
+
|
| 197 |
+
self.condition_model.to(self.device).eval()
|
| 198 |
+
for p in self.condition_model.parameters(): p.requires_grad = False
|
| 199 |
+
|
| 200 |
+
# ------------------------- Load VAE/Encoder Weights -------------------------
|
| 201 |
+
# Load weights corresponding to the logic in VAE script's `load_pretrained_woself`
|
| 202 |
+
# Assuming checkpoint contains 'vae' and 'voxel_encoder' keys
|
| 203 |
+
_load_and_broadcast(self.vae,
|
| 204 |
+
load_fn=lambda p: load_file_func(p)['vae'],
|
| 205 |
+
path=vae_path)
|
| 206 |
+
|
| 207 |
+
_load_and_broadcast(self.voxel_encoder,
|
| 208 |
+
load_fn=lambda p: load_file_func(p)['voxel_encoder'],
|
| 209 |
+
path=vae_path)
|
| 210 |
+
|
| 211 |
+
self.vae.eval()
|
| 212 |
+
self.voxel_encoder.eval()
|
| 213 |
+
for p in self.vae.parameters(): p.requires_grad = False
|
| 214 |
+
for p in self.voxel_encoder.parameters(): p.requires_grad = False
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _load_denoiser(self):
|
| 218 |
+
"""Loads a checkpoint for the denoiser."""
|
| 219 |
+
path = self.denoiser_checkpoint_path
|
| 220 |
+
if not path or not os.path.isfile(path):
|
| 221 |
+
if self.is_master:
|
| 222 |
+
print("No valid checkpoint path provided for denoiser. Starting from scratch.")
|
| 223 |
+
return
|
| 224 |
+
|
| 225 |
+
if self.is_master:
|
| 226 |
+
print(f"Loading denoiser checkpoint from: {path}")
|
| 227 |
+
checkpoint = torch.load(path, map_location=self.device)
|
| 228 |
+
else:
|
| 229 |
+
checkpoint = None
|
| 230 |
+
|
| 231 |
+
dist.barrier()
|
| 232 |
+
dist_list = [checkpoint] if self.is_master else [None]
|
| 233 |
+
dist.broadcast_object_list(dist_list, src=0)
|
| 234 |
+
checkpoint = dist_list[0]
|
| 235 |
+
|
| 236 |
+
try:
|
| 237 |
+
self.denoiser.module.load_state_dict(checkpoint['denoiser'])
|
| 238 |
+
if self.is_master: print("Denoiser weights loaded successfully.")
|
| 239 |
+
except Exception as e:
|
| 240 |
+
if self.is_master: print(f"[ERROR] Failed to load denoiser state_dict: {e}")
|
| 241 |
+
|
| 242 |
+
if 'step' in checkpoint and self.is_master:
|
| 243 |
+
print(f"Checkpoint from step {checkpoint['step']}.")
|
| 244 |
+
|
| 245 |
+
dist.barrier()
|
| 246 |
+
|
| 247 |
+
def _setup_ddp(self):
|
| 248 |
+
"""Sets up DDP and DataLoaders."""
|
| 249 |
+
self.denoiser = self.denoiser.to(self.device)
|
| 250 |
+
self.denoiser = DDP(self.denoiser, device_ids=[self.local_rank])
|
| 251 |
+
|
| 252 |
+
for param in self.denoiser.parameters():
|
| 253 |
+
param.requires_grad = True
|
| 254 |
+
|
| 255 |
+
# Use the Dataset from the VAE script
|
| 256 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 257 |
+
root_dir=self.cfg['dataset']['path'],
|
| 258 |
+
base_resolution=self.cfg['dataset']['base_resolution'],
|
| 259 |
+
min_resolution=self.cfg['dataset']['min_resolution'],
|
| 260 |
+
cache_dir=self.cfg['dataset']['cache_dir'],
|
| 261 |
+
renders_dir=self.cfg['dataset']['renders_dir'],
|
| 262 |
+
|
| 263 |
+
filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
|
| 264 |
+
cache_filter_path=self.cfg['dataset']['cache_filter_path'],
|
| 265 |
+
|
| 266 |
+
active_voxel_res=128,
|
| 267 |
+
pc_sample_number=819200,
|
| 268 |
+
|
| 269 |
+
sample_type=self.cfg['dataset']['sample_type'],
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
self.sampler = DistributedSampler(
|
| 273 |
+
self.dataset,
|
| 274 |
+
num_replicas=self.world_size,
|
| 275 |
+
rank=self.rank,
|
| 276 |
+
shuffle=True
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# Use collate_fn_pointnet
|
| 280 |
+
self.dataloader = DataLoader(
|
| 281 |
+
self.dataset,
|
| 282 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 283 |
+
shuffle=False,
|
| 284 |
+
collate_fn=partial(collate_fn_pointnet,),
|
| 285 |
+
num_workers=self.cfg['training']['num_workers'],
|
| 286 |
+
pin_memory=True,
|
| 287 |
+
sampler=self.sampler,
|
| 288 |
+
prefetch_factor=4,
|
| 289 |
+
persistent_workers=True,
|
| 290 |
+
drop_last=True,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
@torch.no_grad()
|
| 294 |
+
def encode_image(self, images) -> torch.Tensor:
|
| 295 |
+
if isinstance(images, torch.Tensor):
|
| 296 |
+
batch_tensor = images.to(self.device)
|
| 297 |
+
elif isinstance(images, list):
|
| 298 |
+
assert all(isinstance(i, Image.Image) for i in images), "Image list should be list of PIL images"
|
| 299 |
+
image = [i.resize((518, 518), Image.LANCZOS) for i in images]
|
| 300 |
+
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
|
| 301 |
+
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
|
| 302 |
+
batch_tensor = torch.stack(image).to(self.device)
|
| 303 |
+
else:
|
| 304 |
+
raise ValueError(f"Unsupported type of image: {type(image)}")
|
| 305 |
+
|
| 306 |
+
if batch_tensor.shape[-2:] != (518, 518):
|
| 307 |
+
batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False)
|
| 308 |
+
|
| 309 |
+
features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
|
| 310 |
+
patchtokens = F.layer_norm(features, features.shape[-1:])
|
| 311 |
+
return patchtokens
|
| 312 |
+
|
| 313 |
+
def process_batch(self, batch):
|
| 314 |
+
preprocessed_images = batch['image']
|
| 315 |
+
cond_ = self.encode_image(preprocessed_images)
|
| 316 |
+
return cond_
|
| 317 |
+
|
| 318 |
+
def train(self, num_epochs=10000):
|
| 319 |
+
# Unconditional handling for text/image
|
| 320 |
+
if self.is_cond == False:
|
| 321 |
+
if self.condition_type == 'text':
|
| 322 |
+
txt = ['']
|
| 323 |
+
encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
|
| 324 |
+
tokens = encoding['input_ids'].to(self.device)
|
| 325 |
+
with torch.no_grad():
|
| 326 |
+
cond_ = self.condition_model(input_ids=tokens).last_hidden_state
|
| 327 |
+
else:
|
| 328 |
+
blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
|
| 329 |
+
with torch.no_grad():
|
| 330 |
+
dummy_cond = self.encode_image([blank_img])
|
| 331 |
+
cond_ = torch.zeros_like(dummy_cond)
|
| 332 |
+
if self.is_master: print(f"Generated unconditional image prompt with shape: {cond_.shape}")
|
| 333 |
+
|
| 334 |
+
self._load_denoiser()
|
| 335 |
+
self.denoiser.train()
|
| 336 |
+
|
| 337 |
+
# Step tracking
|
| 338 |
+
step = 0
|
| 339 |
+
if self.denoiser_checkpoint_path:
|
| 340 |
+
match = re.search(r'step(\d+)', self.denoiser_checkpoint_path)
|
| 341 |
+
if match:
|
| 342 |
+
step = int(match.group(1))
|
| 343 |
+
step = 0
|
| 344 |
+
|
| 345 |
+
for epoch in range(num_epochs):
|
| 346 |
+
self.sampler.set_epoch(epoch)
|
| 347 |
+
epoch_losses = []
|
| 348 |
+
|
| 349 |
+
for i, batch in enumerate(self.dataloader):
|
| 350 |
+
self.optimizer.zero_grad()
|
| 351 |
+
|
| 352 |
+
# --- Conditioning ---
|
| 353 |
+
if self.is_cond and self.condition_type == 'image':
|
| 354 |
+
cond_ = self.process_batch(batch)
|
| 355 |
+
|
| 356 |
+
# Retrieve Data from collate_fn_pointnet
|
| 357 |
+
point_cloud = batch['point_cloud_128'].to(self.device)
|
| 358 |
+
active_coords = batch['active_voxels_128'].to(self.device)
|
| 359 |
+
|
| 360 |
+
# Handle Batch Size for Conditioning
|
| 361 |
+
batch_size = int(active_coords[:, 0].max().item() + 1)
|
| 362 |
+
if cond_.shape[0] != batch_size:
|
| 363 |
+
cond_ = cond_.expand(batch_size, -1, -1).contiguous().to(self.device)
|
| 364 |
+
else:
|
| 365 |
+
cond_ = cond_.to(self.device)
|
| 366 |
+
|
| 367 |
+
with autocast(device_type='cuda', dtype=torch.bfloat16):
|
| 368 |
+
with torch.no_grad():
|
| 369 |
+
# 1. Encode Point Cloud to Features on Active Voxels
|
| 370 |
+
# The encoder processes point clouds and scatters features to `active_coords`
|
| 371 |
+
active_voxel_feats = self.voxel_encoder(
|
| 372 |
+
p=point_cloud,
|
| 373 |
+
sparse_coords=active_coords,
|
| 374 |
+
res=128,
|
| 375 |
+
bbox_size=(-0.5, 0.5),
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# 2. Prepare Sparse Input for VAE
|
| 379 |
+
sparse_input = SparseTensor(
|
| 380 |
+
feats=active_voxel_feats,
|
| 381 |
+
coords=active_coords.int()
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
# 3. Get Latent Distribution from VAE
|
| 385 |
+
# We use the encode method of VoxelVAE to get posterior
|
| 386 |
+
latent_128, posterior = self.vae.encode(sparse_input)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# 5. Calculate Diffusion Loss
|
| 390 |
+
terms, _ = self.training_losses(x_0=latent_128, cond=cond_)
|
| 391 |
+
loss = terms['loss']
|
| 392 |
+
|
| 393 |
+
self.scaler.scale(loss).backward()
|
| 394 |
+
self.scaler.step(self.optimizer)
|
| 395 |
+
self.scaler.update()
|
| 396 |
+
|
| 397 |
+
with torch.no_grad():
|
| 398 |
+
avg_loss = loss.detach()
|
| 399 |
+
dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
|
| 400 |
+
|
| 401 |
+
step += 1
|
| 402 |
+
|
| 403 |
+
# --- Logging and Saving ---
|
| 404 |
+
if self.is_master:
|
| 405 |
+
epoch_losses.append(avg_loss.item())
|
| 406 |
+
if step % 10 == 0:
|
| 407 |
+
print(f"Epoch {epoch+1} Step {step}: "
|
| 408 |
+
f"Rank0_Loss = {loss.item():.4f}, "
|
| 409 |
+
f"Global_Avg_Loss = {avg_loss.item():.4f}, "
|
| 410 |
+
f"Epoch_Mean = {np.mean(epoch_losses):.4f}")
|
| 411 |
+
|
| 412 |
+
if step % self.i_save == 0 or step == 1:
|
| 413 |
+
checkpoint = {
|
| 414 |
+
'denoiser': self.denoiser.module.state_dict(),
|
| 415 |
+
'step': step
|
| 416 |
+
}
|
| 417 |
+
loss_val_str = f"{loss.item():.6f}".replace('.', '_')
|
| 418 |
+
save_path = os.path.join(self.save_dir, f"checkpoint_step{step}_loss{loss_val_str}.pt")
|
| 419 |
+
torch.save(checkpoint, save_path)
|
| 420 |
+
print(f"Saved checkpoint to {save_path}")
|
| 421 |
+
|
| 422 |
+
if self.is_master:
|
| 423 |
+
avg_loss = np.mean(epoch_losses) if epoch_losses else 0
|
| 424 |
+
log_path = os.path.join(self.save_dir, "loss_log.txt")
|
| 425 |
+
with open(log_path, "a") as f:
|
| 426 |
+
f.write(f"Epoch {epoch+1}, Step {step}, AvgLoss {avg_loss:.6f}\n")
|
| 427 |
+
|
| 428 |
+
dist.barrier()
|
| 429 |
+
# torch.cuda.empty_cache()
|
| 430 |
+
# gc.collect()
|
| 431 |
+
|
| 432 |
+
if self.is_master:
|
| 433 |
+
print("Training complete.")
|
| 434 |
+
|
| 435 |
+
def main():
|
| 436 |
+
if mp.get_start_method(allow_none=True) != 'spawn':
|
| 437 |
+
mp.set_start_method('spawn', force=True)
|
| 438 |
+
|
| 439 |
+
rank, local_rank, world_size = setup_distributed()
|
| 440 |
+
torch.manual_seed(42)
|
| 441 |
+
np.random.seed(42)
|
| 442 |
+
|
| 443 |
+
# Path to your config
|
| 444 |
+
config_path = "/root/Trisf/config_slat_flow_128to1024_pointnet_head.yaml"
|
| 445 |
+
with open(config_path) as f:
|
| 446 |
+
cfg = yaml.safe_load(f)
|
| 447 |
+
|
| 448 |
+
# Initialize Flow Model (on CPU first)
|
| 449 |
+
diffusion_model = SLatFlowModel(
|
| 450 |
+
resolution=cfg['flow']['resolution'],
|
| 451 |
+
in_channels=cfg['flow']['in_channels'],
|
| 452 |
+
out_channels=cfg['flow']['out_channels'],
|
| 453 |
+
model_channels=cfg['flow']['model_channels'],
|
| 454 |
+
cond_channels=cfg['flow']['cond_channels'],
|
| 455 |
+
num_blocks=cfg['flow']['num_blocks'],
|
| 456 |
+
num_heads=cfg['flow']['num_heads'],
|
| 457 |
+
mlp_ratio=cfg['flow']['mlp_ratio'],
|
| 458 |
+
patch_size=cfg['flow']['patch_size'],
|
| 459 |
+
num_io_res_blocks=cfg['flow']['num_io_res_blocks'],
|
| 460 |
+
io_block_channels=cfg['flow']['io_block_channels'],
|
| 461 |
+
pe_mode=cfg['flow']['pe_mode'],
|
| 462 |
+
qk_rms_norm=cfg['flow']['qk_rms_norm'],
|
| 463 |
+
qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'],
|
| 464 |
+
use_fp16=cfg['flow'].get('use_fp16', False),
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
torch.manual_seed(42 + rank)
|
| 468 |
+
np.random.seed(42 + rank)
|
| 469 |
+
|
| 470 |
+
trainer = SLatFlowMatchingTrainer(
|
| 471 |
+
denoiser=diffusion_model,
|
| 472 |
+
t_schedule=cfg['t_schedule'],
|
| 473 |
+
sigma_min=cfg['sigma_min'],
|
| 474 |
+
cfg=cfg,
|
| 475 |
+
rank=rank,
|
| 476 |
+
local_rank=local_rank,
|
| 477 |
+
world_size=world_size,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
trainer.train()
|
| 481 |
+
cleanup_distributed()
|
| 482 |
+
|
| 483 |
+
if __name__ == '__main__':
|
| 484 |
+
main()
|
train_slat_vae_512_128to1024_pointnet.py
ADDED
|
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import os
|
| 3 |
+
# os.environ['ATTN_BACKEND'] = 'xformers'
|
| 4 |
+
import yaml
|
| 5 |
+
import torch
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from functools import partial
|
| 10 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.optim import AdamW
|
| 13 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 14 |
+
|
| 15 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder import VoxelVAE
|
| 16 |
+
|
| 17 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
from dataset_triposf import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 21 |
+
|
| 22 |
+
from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss
|
| 23 |
+
|
| 24 |
+
import torch.distributed as dist
|
| 25 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 26 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 27 |
+
|
| 28 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 29 |
+
import math
|
| 30 |
+
|
| 31 |
+
def flatten_coords_4d(coords_4d: torch.Tensor):
|
| 32 |
+
coords_4d_long = coords_4d.long()
|
| 33 |
+
|
| 34 |
+
base_x = 1024
|
| 35 |
+
base_y = 1024 * 1024
|
| 36 |
+
base_z = 1024 * 1024 * 1024
|
| 37 |
+
|
| 38 |
+
flat_coords = coords_4d_long[:, 0] * base_z + \
|
| 39 |
+
coords_4d_long[:, 1] * base_y + \
|
| 40 |
+
coords_4d_long[:, 2] * base_x + \
|
| 41 |
+
coords_4d_long[:, 3]
|
| 42 |
+
return flat_coords
|
| 43 |
+
|
| 44 |
+
def downsample_voxels(
|
| 45 |
+
voxels: torch.Tensor,
|
| 46 |
+
input_resolution: int,
|
| 47 |
+
output_resolution: int
|
| 48 |
+
) -> torch.Tensor:
|
| 49 |
+
if input_resolution % output_resolution != 0:
|
| 50 |
+
raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
|
| 51 |
+
f"by output_resolution ({output_resolution}).")
|
| 52 |
+
|
| 53 |
+
factor = input_resolution // output_resolution
|
| 54 |
+
|
| 55 |
+
downsampled_voxels = voxels.clone().to(torch.long)
|
| 56 |
+
|
| 57 |
+
downsampled_voxels[:, 1:] //= factor
|
| 58 |
+
|
| 59 |
+
unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
|
| 60 |
+
return unique_downsampled_voxels
|
| 61 |
+
|
| 62 |
+
class Trainer:
|
| 63 |
+
def __init__(self, config_path, rank, world_size, local_rank):
|
| 64 |
+
self.rank = rank
|
| 65 |
+
self.world_size = world_size
|
| 66 |
+
self.local_rank = local_rank
|
| 67 |
+
self.is_master = self.rank == 0
|
| 68 |
+
|
| 69 |
+
self.load_config(config_path)
|
| 70 |
+
self.accum_steps = max(1, 8 // self.cfg['training']['batch_size'])
|
| 71 |
+
|
| 72 |
+
self.config_hash = self.save_config_with_hash()
|
| 73 |
+
self.init_device()
|
| 74 |
+
self.init_dirs()
|
| 75 |
+
self.init_components()
|
| 76 |
+
self.init_training()
|
| 77 |
+
|
| 78 |
+
self.train_loss_history = []
|
| 79 |
+
self.eval_loss_history = []
|
| 80 |
+
self.best_eval_loss = float('inf')
|
| 81 |
+
|
| 82 |
+
def save_config_with_hash(self):
|
| 83 |
+
import hashlib
|
| 84 |
+
|
| 85 |
+
# Serialize config to hash
|
| 86 |
+
config_str = yaml.dump(self.cfg)
|
| 87 |
+
config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
|
| 88 |
+
|
| 89 |
+
# Prepare all flags as string for formatting
|
| 90 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 91 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 92 |
+
|
| 93 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 94 |
+
|
| 95 |
+
# Format save_dir with all placeholders
|
| 96 |
+
self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
|
| 97 |
+
dataset_name=dataset_name,
|
| 98 |
+
config_hash=config_hash,
|
| 99 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 100 |
+
multires=self.cfg['model']['multires'],
|
| 101 |
+
add_block_embed=add_block_embed_flag,
|
| 102 |
+
using_attn=using_attn_flag,
|
| 103 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
if self.is_master:
|
| 107 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 108 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 109 |
+
with open(config_path, 'w') as f:
|
| 110 |
+
yaml.dump(self.cfg, f)
|
| 111 |
+
|
| 112 |
+
dist.barrier()
|
| 113 |
+
return config_hash
|
| 114 |
+
|
| 115 |
+
def save_checkpoint(self, epoch, avg_loss, batch_idx):
|
| 116 |
+
if not self.is_master:
|
| 117 |
+
return
|
| 118 |
+
checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
|
| 119 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 120 |
+
|
| 121 |
+
torch.save({
|
| 122 |
+
'voxel_encoder': self.voxel_encoder.module.state_dict(),
|
| 123 |
+
'vae': self.vae.module.state_dict(),
|
| 124 |
+
'epoch': epoch,
|
| 125 |
+
'loss': avg_loss,
|
| 126 |
+
'config': self.cfg
|
| 127 |
+
}, checkpoint_path)
|
| 128 |
+
|
| 129 |
+
def quoted_presenter(dumper, data):
|
| 130 |
+
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
|
| 131 |
+
|
| 132 |
+
yaml.add_representer(str, quoted_presenter)
|
| 133 |
+
|
| 134 |
+
with open(config_path, 'w') as f:
|
| 135 |
+
yaml.dump(self.cfg, f)
|
| 136 |
+
|
| 137 |
+
def load_config(self, config_path):
|
| 138 |
+
with open(config_path) as f:
|
| 139 |
+
self.cfg = yaml.safe_load(f)
|
| 140 |
+
|
| 141 |
+
# Extract and convert flags for formatting
|
| 142 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 143 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 144 |
+
|
| 145 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 146 |
+
|
| 147 |
+
self.save_dir = self.cfg['experiment']['save_dir'].format(
|
| 148 |
+
dataset_name=dataset_name,
|
| 149 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 150 |
+
multires=self.cfg['model']['multires'],
|
| 151 |
+
add_block_embed=add_block_embed_flag,
|
| 152 |
+
using_attn=using_attn_flag,
|
| 153 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
if self.is_master:
|
| 157 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 158 |
+
dist.barrier()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def init_device(self):
|
| 162 |
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
| 163 |
+
|
| 164 |
+
def init_dirs(self):
|
| 165 |
+
self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
|
| 166 |
+
if self.is_master:
|
| 167 |
+
with open(self.log_file, "a") as f:
|
| 168 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 169 |
+
f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
|
| 170 |
+
|
| 171 |
+
def init_components(self):
|
| 172 |
+
# self.dataset = VoxelVertexDataset_edge_shapenet(
|
| 173 |
+
# root_dir=self.cfg['dataset']['path'],
|
| 174 |
+
# file_list_path=self.cfg['dataset']['file_list_path'],
|
| 175 |
+
# map_file_path=self.cfg['dataset']['map_file_path'],
|
| 176 |
+
|
| 177 |
+
# base_resolution=self.cfg['dataset']['base_resolution'],
|
| 178 |
+
# min_resolution=self.cfg['dataset']['min_resolution'],
|
| 179 |
+
# cache_dir=self.cfg['dataset']['cache_dir'],
|
| 180 |
+
# renders_dir=self.cfg['dataset']['renders_dir'],
|
| 181 |
+
|
| 182 |
+
# filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
|
| 183 |
+
# cache_filter_path=self.cfg['dataset']['cache_filter_path'],
|
| 184 |
+
# )
|
| 185 |
+
|
| 186 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 187 |
+
root_dir=self.cfg['dataset']['path'],
|
| 188 |
+
base_resolution=self.cfg['dataset']['base_resolution'],
|
| 189 |
+
min_resolution=self.cfg['dataset']['min_resolution'],
|
| 190 |
+
cache_dir=self.cfg['dataset']['cache_dir'],
|
| 191 |
+
renders_dir=self.cfg['dataset']['renders_dir'],
|
| 192 |
+
|
| 193 |
+
filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
|
| 194 |
+
cache_filter_path=self.cfg['dataset']['cache_filter_path'],
|
| 195 |
+
|
| 196 |
+
active_voxel_res=128,
|
| 197 |
+
pc_sample_number=819200,
|
| 198 |
+
|
| 199 |
+
sample_type=self.cfg['dataset']['sample_type'],
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
self.sampler = DistributedSampler(
|
| 203 |
+
self.dataset,
|
| 204 |
+
num_replicas=self.world_size,
|
| 205 |
+
rank=self.rank,
|
| 206 |
+
shuffle=True,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
self.dataloader = DataLoader(
|
| 210 |
+
self.dataset,
|
| 211 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 212 |
+
shuffle=False,
|
| 213 |
+
collate_fn=partial(collate_fn_pointnet,),
|
| 214 |
+
num_workers=self.cfg['training']['num_workers'],
|
| 215 |
+
pin_memory=True,
|
| 216 |
+
sampler=self.sampler,
|
| 217 |
+
# prefetch_factor=4,
|
| 218 |
+
persistent_workers=True,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 222 |
+
in_channels=15,
|
| 223 |
+
hidden_dim=256,
|
| 224 |
+
out_channels=1024,
|
| 225 |
+
scatter_type='mean',
|
| 226 |
+
n_blocks=5,
|
| 227 |
+
resolution=128,
|
| 228 |
+
add_label=False,
|
| 229 |
+
).to(self.device)
|
| 230 |
+
|
| 231 |
+
# ablation 3: voxelvae_1volume, have tested
|
| 232 |
+
self.vae = VoxelVAE(
|
| 233 |
+
in_channels=self.cfg['model']['in_channels'],
|
| 234 |
+
latent_dim=self.cfg['model']['latent_dim'],
|
| 235 |
+
encoder_blocks=self.cfg['model']['encoder_blocks'],
|
| 236 |
+
decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
|
| 237 |
+
decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
|
| 238 |
+
num_heads=8,
|
| 239 |
+
num_head_channels=64,
|
| 240 |
+
mlp_ratio=4.0,
|
| 241 |
+
attn_mode="swin",
|
| 242 |
+
window_size=8,
|
| 243 |
+
pe_mode="ape",
|
| 244 |
+
use_fp16=False,
|
| 245 |
+
use_checkpoint=True,
|
| 246 |
+
qk_rms_norm=False,
|
| 247 |
+
using_subdivide=True,
|
| 248 |
+
using_attn=self.cfg['model']['using_attn'],
|
| 249 |
+
attn_first=self.cfg['model'].get('attn_first', True),
|
| 250 |
+
pred_direction=self.cfg['model'].get('pred_direction', False),
|
| 251 |
+
).to(self.device)
|
| 252 |
+
|
| 253 |
+
if self.cfg['training']['from_pretrained']:
|
| 254 |
+
load_pretrained_woself(
|
| 255 |
+
checkpoint_path=self.cfg['training']['checkpoint_path'],
|
| 256 |
+
voxel_encoder=self.voxel_encoder,
|
| 257 |
+
vae=self.vae,
|
| 258 |
+
optimizer=None,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 262 |
+
self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 263 |
+
|
| 264 |
+
def init_training(self):
|
| 265 |
+
self.optimizer = AdamW(
|
| 266 |
+
list(self.vae.module.parameters()) +
|
| 267 |
+
list(self.voxel_encoder.module.parameters()),
|
| 268 |
+
lr=self.cfg['training']['lr'],
|
| 269 |
+
weight_decay=0.01,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
|
| 273 |
+
# print('num_update_steps_per_epoch', num_update_steps_per_epoch) # 1305
|
| 274 |
+
max_epochs = self.cfg['training']['max_epochs']
|
| 275 |
+
num_training_steps = max_epochs * num_update_steps_per_epoch
|
| 276 |
+
|
| 277 |
+
num_warmup_steps = 1000
|
| 278 |
+
|
| 279 |
+
# self.scheduler = torch.optim.lr_scheduler.LambdaLR(
|
| 280 |
+
# self.optimizer,
|
| 281 |
+
# lr_lambda=lambda epoch: self.cfg['training']['gamma'] ** (epoch // self.cfg['training']['step_size'])
|
| 282 |
+
# )
|
| 283 |
+
|
| 284 |
+
self.scheduler = get_cosine_schedule_with_warmup(
|
| 285 |
+
self.optimizer,
|
| 286 |
+
num_warmup_steps=num_warmup_steps,
|
| 287 |
+
num_training_steps=num_training_steps
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=10.0).to(self.device)
|
| 291 |
+
self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
|
| 292 |
+
self.asyloss = AsymmetricFocalLoss(
|
| 293 |
+
gamma_pos=0.0,
|
| 294 |
+
gamma_neg=4.0,
|
| 295 |
+
clip=0.05,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
self.bce_loss = torch.nn.BCEWithLogitsLoss()
|
| 299 |
+
|
| 300 |
+
self.dice_loss = DiceLoss()
|
| 301 |
+
self.scaler = GradScaler()
|
| 302 |
+
|
| 303 |
+
def train_step(self, batch):
|
| 304 |
+
"""Modified training step that handles vertex and edge voxels separately after initial prediction."""
|
| 305 |
+
# 1. Retrieve data from batch
|
| 306 |
+
combined_voxels_1024 = batch['combined_voxels_1024'].to(self.device)
|
| 307 |
+
combined_voxel_labels_1024 = batch['combined_voxel_labels_1024'].to(self.device)
|
| 308 |
+
gt_vertex_voxels_1024 = batch['gt_vertex_voxels_1024'].to(self.device)
|
| 309 |
+
|
| 310 |
+
gt_edge_voxels_1024_ = batch['gt_edge_voxels_1024'].to(self.device)
|
| 311 |
+
gt_combined_endpoints_1024 = batch['gt_combined_endpoints_1024'].to(self.device)
|
| 312 |
+
gt_combined_errors_1024 = batch['gt_combined_errors_1024'].to(self.device)
|
| 313 |
+
|
| 314 |
+
edge_mask = (combined_voxel_labels_1024 == 1)
|
| 315 |
+
|
| 316 |
+
gt_edge_endpoints_1024 = gt_combined_endpoints_1024[edge_mask]
|
| 317 |
+
gt_edge_errors_1024 = gt_combined_errors_1024[edge_mask]
|
| 318 |
+
gt_edge_voxels_1024 = combined_voxels_1024[edge_mask].to(self.device)
|
| 319 |
+
|
| 320 |
+
# print('gt_edge_voxels_1024_-gt_edge_voxels_1024.sum()', (gt_edge_voxels_1024_-gt_edge_voxels_1024).sum())
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
p1 = gt_edge_endpoints_1024[:, 1:4].float()
|
| 324 |
+
p2 = gt_edge_endpoints_1024[:, 4:7].float()
|
| 325 |
+
|
| 326 |
+
mask = ( (p1[:,0] < p2[:,0]) |
|
| 327 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
|
| 328 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
|
| 329 |
+
|
| 330 |
+
pA = torch.where(mask[:, None], p1, p2) # smaller one
|
| 331 |
+
pB = torch.where(mask[:, None], p2, p1) # larger one
|
| 332 |
+
|
| 333 |
+
d = pB - pA
|
| 334 |
+
dir_gt = F.normalize(d, dim=-1, eps=1e-6)
|
| 335 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 336 |
+
vtx_256 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=256)
|
| 337 |
+
vtx_512 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=512)
|
| 338 |
+
vtx_1024 = gt_vertex_voxels_1024
|
| 339 |
+
|
| 340 |
+
edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 341 |
+
edge_256 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=256)
|
| 342 |
+
edge_512 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=512)
|
| 343 |
+
edge_1024 = combined_voxels_1024
|
| 344 |
+
|
| 345 |
+
active_coords = batch['active_voxels_128'].to(self.device)
|
| 346 |
+
point_cloud = batch['point_cloud_128'].to(self.device)
|
| 347 |
+
|
| 348 |
+
active_voxel_feats = self.voxel_encoder(
|
| 349 |
+
p=point_cloud,
|
| 350 |
+
sparse_coords=active_coords,
|
| 351 |
+
res=128,
|
| 352 |
+
bbox_size=(-0.5, 0.5),
|
| 353 |
+
|
| 354 |
+
# voxel_label=active_labels,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
sparse_input = SparseTensor(
|
| 358 |
+
feats=active_voxel_feats,
|
| 359 |
+
coords=active_coords.int()
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
gt_edge_voxels_list = [
|
| 363 |
+
edge_128,
|
| 364 |
+
edge_256,
|
| 365 |
+
edge_512,
|
| 366 |
+
edge_1024,
|
| 367 |
+
]
|
| 368 |
+
|
| 369 |
+
gt_vertex_voxels_list = [
|
| 370 |
+
vtx_128,
|
| 371 |
+
vtx_256,
|
| 372 |
+
vtx_512,
|
| 373 |
+
vtx_1024,
|
| 374 |
+
]
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
results, posterior, latent_128 = self.vae(
|
| 378 |
+
sparse_input,
|
| 379 |
+
gt_vertex_voxels_list=gt_vertex_voxels_list,
|
| 380 |
+
gt_edge_voxels_list=gt_edge_voxels_list,
|
| 381 |
+
training=True,
|
| 382 |
+
sample_ratio=0.,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# print("results[-1]['edge']['coords_4d'][1827:1830]", results[-1]['edge']['coords_4d'][1827:1830])
|
| 386 |
+
total_loss = 0.
|
| 387 |
+
prune_loss_total = 0.
|
| 388 |
+
vertex_loss_total = 0.
|
| 389 |
+
edge_loss_total=0.
|
| 390 |
+
|
| 391 |
+
with autocast(dtype=torch.bfloat16):
|
| 392 |
+
initial_result = results[0]
|
| 393 |
+
vertex_mask = initial_result['vertex_mask']
|
| 394 |
+
vtx_logits = initial_result['vtx_feats']
|
| 395 |
+
vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
|
| 396 |
+
|
| 397 |
+
edge_mask = initial_result['edge_mask']
|
| 398 |
+
edge_logits = initial_result['edge_feats']
|
| 399 |
+
edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
|
| 400 |
+
|
| 401 |
+
vertex_loss_total += vertex_loss
|
| 402 |
+
edge_loss_total += edge_loss
|
| 403 |
+
|
| 404 |
+
total_loss += vertex_loss
|
| 405 |
+
total_loss += edge_loss
|
| 406 |
+
|
| 407 |
+
# Process each level's results
|
| 408 |
+
for idx, res_dict in enumerate(results[1:], start=1):
|
| 409 |
+
# Vertex branch losses
|
| 410 |
+
vertex_pred_coords = res_dict['vertex']['occ_coords']
|
| 411 |
+
vertex_occ_probs = res_dict['vertex']['occ_probs']
|
| 412 |
+
vertex_gt_coords = res_dict['vertex']['coords']
|
| 413 |
+
|
| 414 |
+
vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=1024).float()
|
| 415 |
+
# print('vertex_labels.sum()', vertex_labels.sum(), idx)
|
| 416 |
+
vertex_logits = vertex_occ_probs.squeeze()
|
| 417 |
+
|
| 418 |
+
# if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
|
| 419 |
+
vertex_prune_loss = self.focal_loss(vertex_logits, vertex_labels)
|
| 420 |
+
# vertex_prune_loss = self.dice_loss(vertex_logits, vertex_labels)
|
| 421 |
+
|
| 422 |
+
# dilation 1: bce loss
|
| 423 |
+
# vertex_prune_loss = self.bce_loss(vertex_logits, vertex_labels,)
|
| 424 |
+
|
| 425 |
+
prune_loss_total += vertex_prune_loss
|
| 426 |
+
total_loss += vertex_prune_loss
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
# Edge branch losses
|
| 430 |
+
edge_pred_coords = res_dict['edge']['occ_coords']
|
| 431 |
+
edge_occ_probs = res_dict['edge']['occ_probs']
|
| 432 |
+
edge_gt_coords = res_dict['edge']['coords']
|
| 433 |
+
|
| 434 |
+
edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=1024).float()
|
| 435 |
+
# print('edge_labels.sum()', edge_labels.sum(), idx)
|
| 436 |
+
edge_logits = edge_occ_probs.squeeze()
|
| 437 |
+
# if edge_labels.sum() > 0 and edge_labels.sum() < len(edge_labels):
|
| 438 |
+
edge_prune_loss = self.focal_loss(edge_logits, edge_labels)
|
| 439 |
+
|
| 440 |
+
# dilation 1: bce loss
|
| 441 |
+
# edge_prune_loss = self.bce_loss(edge_logits, edge_labels,)
|
| 442 |
+
|
| 443 |
+
prune_loss_total += edge_prune_loss
|
| 444 |
+
total_loss += edge_prune_loss
|
| 445 |
+
|
| 446 |
+
if idx == 3:
|
| 447 |
+
pred_coords = res_dict['edge']['coords_4d'] # [N,4] (b,x,y,z)
|
| 448 |
+
pred_feats = res_dict['edge']['predicted_offset_feats'] # [N,C]
|
| 449 |
+
|
| 450 |
+
gt_coords = gt_edge_voxels_1024.to(pred_coords.device) # [M,4]
|
| 451 |
+
gt_feats = gt_edge_errors_1024[:, 1:].to(pred_coords.device) # [M,C]
|
| 452 |
+
|
| 453 |
+
pred_keys = flatten_coords_4d(pred_coords)
|
| 454 |
+
gt_keys = flatten_coords_4d(gt_coords)
|
| 455 |
+
|
| 456 |
+
sorted_pred_keys, pred_order = torch.sort(pred_keys)
|
| 457 |
+
pred_coords_sorted = pred_coords[pred_order]
|
| 458 |
+
pred_feats_sorted = pred_feats[pred_order]
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
sorted_gt_keys, gt_order = torch.sort(gt_keys)
|
| 462 |
+
gt_coords_sorted = gt_coords[gt_order]
|
| 463 |
+
gt_feats_sorted = gt_feats[gt_order]
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
# pos = torch.searchsorted(sorted_gt_keys, sorted_pred_keys)
|
| 467 |
+
# valid_pos = pos < len(sorted_gt_keys)
|
| 468 |
+
# matched = torch.zeros_like(valid_pos, dtype=torch.bool)
|
| 469 |
+
# matched[valid_pos] = (sorted_gt_keys[pos[valid_pos]] == sorted_pred_keys[valid_pos])
|
| 470 |
+
# valid_mask = valid_pos & matched
|
| 471 |
+
|
| 472 |
+
pos = torch.searchsorted(sorted_gt_keys, sorted_pred_keys)
|
| 473 |
+
valid_mask = (pos < len(sorted_gt_keys)) & (sorted_gt_keys[pos] == sorted_pred_keys)
|
| 474 |
+
|
| 475 |
+
if valid_mask.any():
|
| 476 |
+
matched_pred_feats = pred_feats_sorted[valid_mask]
|
| 477 |
+
matched_gt_feats = gt_feats_sorted[pos[valid_mask]]
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
mse_loss_feats = self.mse_loss(matched_pred_feats, matched_gt_feats * 2)
|
| 481 |
+
total_loss += mse_loss_feats
|
| 482 |
+
|
| 483 |
+
if self.cfg['model'].get('pred_direction', False):
|
| 484 |
+
pred_dirs = res_dict['edge']['predicted_direction_feats']
|
| 485 |
+
dir_gt_device = dir_gt.to(pred_coords.device)
|
| 486 |
+
|
| 487 |
+
pred_dirs_sorted = pred_dirs[pred_order]
|
| 488 |
+
dir_gt_sorted = dir_gt_device[gt_order]
|
| 489 |
+
|
| 490 |
+
matched_pred_dirs = pred_dirs_sorted[valid_mask]
|
| 491 |
+
matched_gt_dirs = dir_gt_sorted[pos[valid_mask]]
|
| 492 |
+
|
| 493 |
+
mse_loss_dirs = self.mse_loss(matched_pred_dirs, matched_gt_dirs)
|
| 494 |
+
total_loss += mse_loss_dirs
|
| 495 |
+
else:
|
| 496 |
+
mse_loss_feats = torch.tensor(0., device=pred_coords.device)
|
| 497 |
+
|
| 498 |
+
if self.cfg['model'].get('pred_direction', False):
|
| 499 |
+
mse_loss_dirs = torch.tensor(0., device=pred_coords.device)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# KL loss
|
| 503 |
+
kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
|
| 504 |
+
total_loss += kl_loss
|
| 505 |
+
|
| 506 |
+
# Backpropagation
|
| 507 |
+
scaled_total_loss = total_loss / self.accum_steps
|
| 508 |
+
self.scaler.scale(scaled_total_loss).backward()
|
| 509 |
+
|
| 510 |
+
return {
|
| 511 |
+
'total_loss': total_loss.item(),
|
| 512 |
+
'kl_loss': kl_loss.item(),
|
| 513 |
+
'prune_loss': prune_loss_total.item(),
|
| 514 |
+
'vertex_loss': vertex_loss_total.item(),
|
| 515 |
+
'edge_loss': edge_loss_total.item(),
|
| 516 |
+
'offset_loss': mse_loss_feats.item(),
|
| 517 |
+
'direction_loss': mse_loss_dirs.item(),
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def train(self):
|
| 522 |
+
accum_steps = self.accum_steps
|
| 523 |
+
for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
|
| 524 |
+
self.dataloader.sampler.set_epoch(epoch)
|
| 525 |
+
# Initialize metrics
|
| 526 |
+
metrics = {
|
| 527 |
+
'total_loss': 0.0,
|
| 528 |
+
'kl_loss': 0.0,
|
| 529 |
+
'prune_loss': 0.0,
|
| 530 |
+
'vertex_loss': 0.0,
|
| 531 |
+
'edge_loss': 0.0,
|
| 532 |
+
'offset_loss': 0.0,
|
| 533 |
+
'direction_loss': 0.0,
|
| 534 |
+
}
|
| 535 |
+
num_batches = 0
|
| 536 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 537 |
+
|
| 538 |
+
for i, batch in enumerate(self.dataloader):
|
| 539 |
+
# Get all losses from train_step
|
| 540 |
+
if batch is None:
|
| 541 |
+
continue
|
| 542 |
+
step_losses = self.train_step(batch)
|
| 543 |
+
|
| 544 |
+
# Accumulate losses
|
| 545 |
+
for key in metrics:
|
| 546 |
+
metrics[key] += step_losses[key]
|
| 547 |
+
num_batches += 1
|
| 548 |
+
|
| 549 |
+
if (i + 1) % accum_steps == 0:
|
| 550 |
+
self.scaler.unscale_(self.optimizer)
|
| 551 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
|
| 552 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
|
| 553 |
+
|
| 554 |
+
self.scaler.step(self.optimizer)
|
| 555 |
+
self.scaler.update()
|
| 556 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 557 |
+
|
| 558 |
+
self.scheduler.step()
|
| 559 |
+
|
| 560 |
+
# Print batch-level metrics
|
| 561 |
+
if self.is_master:
|
| 562 |
+
avg_metric = {key: value / num_batches for key, value in metrics.items()}
|
| 563 |
+
print(
|
| 564 |
+
f"[Epoch {epoch}] Batch:{num_batches} "
|
| 565 |
+
f"AvgL:{avg_metric['total_loss']:.4f} "
|
| 566 |
+
f"Loss: {step_losses['total_loss']:.4f}, "
|
| 567 |
+
f"KLL: {step_losses['kl_loss']:.4f}, "
|
| 568 |
+
f"PruneL: {step_losses['prune_loss']:.4f}, "
|
| 569 |
+
f"VertexL: {step_losses['vertex_loss']:.4f}, "
|
| 570 |
+
f"EdgeL: {step_losses['edge_loss']:.4f}, "
|
| 571 |
+
f"OffsetL: {step_losses['offset_loss']:.4f}, "
|
| 572 |
+
f"DireL: {step_losses['direction_loss']:.4f}, "
|
| 573 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
if i % 2000 == 0 and i != 0:
|
| 577 |
+
self.save_checkpoint(epoch, avg_metric['total_loss'], i)
|
| 578 |
+
with open(self.log_file, "a") as f:
|
| 579 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 580 |
+
log_line = (
|
| 581 |
+
f"Epoch {epoch:05d} | "
|
| 582 |
+
f"Batch {i:05d} | "
|
| 583 |
+
f"Loss: {avg_metric['total_loss']:.6f} "
|
| 584 |
+
f"Avg KLL: {avg_metric['kl_loss']:.4f} "
|
| 585 |
+
f"Avg PruneL: {avg_metric['prune_loss']:.4f} "
|
| 586 |
+
f"Avg VertexL: {avg_metric['vertex_loss']:.4f} "
|
| 587 |
+
f"Avg EdgeL: {avg_metric['edge_loss']:.4f} "
|
| 588 |
+
f"Avg OffsetL: {avg_metric['offset_loss']:.4f} "
|
| 589 |
+
f"Avg DireL: {avg_metric['direction_loss']:.4f} "
|
| 590 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 591 |
+
f"[{current_time}]\n"
|
| 592 |
+
)
|
| 593 |
+
f.write(log_line)
|
| 594 |
+
|
| 595 |
+
if num_batches % accum_steps != 0:
|
| 596 |
+
self.scaler.unscale_(self.optimizer)
|
| 597 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
|
| 598 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
|
| 599 |
+
|
| 600 |
+
self.scaler.step(self.optimizer)
|
| 601 |
+
self.scaler.update()
|
| 602 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 603 |
+
|
| 604 |
+
self.scheduler.step()
|
| 605 |
+
|
| 606 |
+
# Calculate epoch averages
|
| 607 |
+
avg_metrics = {key: value / num_batches for key, value in metrics.items()}
|
| 608 |
+
self.train_loss_history.append(avg_metrics['total_loss'])
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Log to file
|
| 612 |
+
if self.is_master:
|
| 613 |
+
with open(self.log_file, "a") as f:
|
| 614 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 615 |
+
log_line = (
|
| 616 |
+
f"Epoch {epoch:05d} | "
|
| 617 |
+
f"Loss: {avg_metrics['total_loss']:.6f} "
|
| 618 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 619 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 620 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 621 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 622 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 623 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 624 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 625 |
+
f"[{current_time}]\n"
|
| 626 |
+
)
|
| 627 |
+
f.write(log_line)
|
| 628 |
+
|
| 629 |
+
# Print epoch summary
|
| 630 |
+
print(
|
| 631 |
+
f"[Epoch {epoch}] "
|
| 632 |
+
f"Avg Loss: {avg_metrics['total_loss']:.4f} "
|
| 633 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 634 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 635 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 636 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 637 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 638 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 639 |
+
f"[{current_time}]\n"
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
# Save checkpoint
|
| 643 |
+
if epoch % self.cfg['training']['save_every'] == 0:
|
| 644 |
+
self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
|
| 645 |
+
|
| 646 |
+
# Update learning rate
|
| 647 |
+
if self.is_master:
|
| 648 |
+
current_lr = self.optimizer.param_groups[0]['lr']
|
| 649 |
+
print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
|
| 650 |
+
|
| 651 |
+
dist.barrier()
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def main():
|
| 655 |
+
# Initialize the process group
|
| 656 |
+
dist.init_process_group(backend='nccl')
|
| 657 |
+
|
| 658 |
+
# Get rank and world size from environment variables set by the launcher
|
| 659 |
+
rank = int(os.environ['RANK'])
|
| 660 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 661 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 662 |
+
|
| 663 |
+
# Set the device for the current process. This is crucial.
|
| 664 |
+
torch.cuda.set_device(local_rank)
|
| 665 |
+
torch.manual_seed(42+rank)
|
| 666 |
+
|
| 667 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 668 |
+
# Pass the distributed info to the Trainer
|
| 669 |
+
trainer = Trainer(
|
| 670 |
+
config_path="/gemini/user/private/zhaotianhao/Triposf/config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_addhead.yaml",
|
| 671 |
+
rank=rank,
|
| 672 |
+
world_size=world_size,
|
| 673 |
+
local_rank=local_rank
|
| 674 |
+
)
|
| 675 |
+
trainer.train()
|
| 676 |
+
|
| 677 |
+
# Clean up the process group
|
| 678 |
+
dist.destroy_process_group()
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
if __name__ == '__main__':
|
| 682 |
+
main()
|
train_slat_vae_512_128to1024_pointnet_addhead.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import os
|
| 3 |
+
# os.environ['ATTN_BACKEND'] = 'xformers'
|
| 4 |
+
import yaml
|
| 5 |
+
import torch
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from functools import partial
|
| 10 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.optim import AdamW
|
| 13 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 14 |
+
|
| 15 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_addhead import VoxelVAE
|
| 16 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 17 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 18 |
+
|
| 19 |
+
from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss, FocalLoss
|
| 20 |
+
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 23 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 24 |
+
|
| 25 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 26 |
+
import math
|
| 27 |
+
|
| 28 |
+
from torchvision.ops import sigmoid_focal_loss
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import open3d as o3d
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def flatten_coords_4d(coords_4d: torch.Tensor):
|
| 35 |
+
coords_4d_long = coords_4d.long()
|
| 36 |
+
|
| 37 |
+
base_x = 1024
|
| 38 |
+
base_y = 1024 * 1024
|
| 39 |
+
base_z = 1024 * 1024 * 1024
|
| 40 |
+
|
| 41 |
+
flat_coords = coords_4d_long[:, 0] * base_z + \
|
| 42 |
+
coords_4d_long[:, 1] * base_y + \
|
| 43 |
+
coords_4d_long[:, 2] * base_x + \
|
| 44 |
+
coords_4d_long[:, 3]
|
| 45 |
+
return flat_coords
|
| 46 |
+
|
| 47 |
+
def downsample_voxels(
|
| 48 |
+
voxels: torch.Tensor,
|
| 49 |
+
input_resolution: int,
|
| 50 |
+
output_resolution: int
|
| 51 |
+
) -> torch.Tensor:
|
| 52 |
+
if input_resolution % output_resolution != 0:
|
| 53 |
+
raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
|
| 54 |
+
f"by output_resolution ({output_resolution}).")
|
| 55 |
+
|
| 56 |
+
factor = input_resolution // output_resolution
|
| 57 |
+
|
| 58 |
+
downsampled_voxels = voxels.clone().to(torch.long)
|
| 59 |
+
|
| 60 |
+
downsampled_voxels[:, 1:] //= factor
|
| 61 |
+
|
| 62 |
+
unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
|
| 63 |
+
return unique_downsampled_voxels
|
| 64 |
+
|
| 65 |
+
class Trainer:
|
| 66 |
+
def __init__(self, config_path, rank, world_size, local_rank):
|
| 67 |
+
self.rank = rank
|
| 68 |
+
self.world_size = world_size
|
| 69 |
+
self.local_rank = local_rank
|
| 70 |
+
self.is_master = self.rank == 0
|
| 71 |
+
|
| 72 |
+
self.load_config(config_path)
|
| 73 |
+
self.accum_steps = max(1, 8 // self.cfg['training']['batch_size'])
|
| 74 |
+
|
| 75 |
+
self.config_hash = self.save_config_with_hash()
|
| 76 |
+
self.init_device()
|
| 77 |
+
self.init_dirs()
|
| 78 |
+
self.init_components()
|
| 79 |
+
self.init_training()
|
| 80 |
+
|
| 81 |
+
self.train_loss_history = []
|
| 82 |
+
self.eval_loss_history = []
|
| 83 |
+
self.best_eval_loss = float('inf')
|
| 84 |
+
|
| 85 |
+
def save_config_with_hash(self):
|
| 86 |
+
import hashlib
|
| 87 |
+
|
| 88 |
+
# Serialize config to hash
|
| 89 |
+
config_str = yaml.dump(self.cfg)
|
| 90 |
+
config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
|
| 91 |
+
|
| 92 |
+
# Prepare all flags as string for formatting
|
| 93 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 94 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 95 |
+
|
| 96 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 97 |
+
|
| 98 |
+
# Format save_dir with all placeholders
|
| 99 |
+
self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
|
| 100 |
+
dataset_name=dataset_name,
|
| 101 |
+
config_hash=config_hash,
|
| 102 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 103 |
+
multires=self.cfg['model']['multires'],
|
| 104 |
+
add_block_embed=add_block_embed_flag,
|
| 105 |
+
using_attn=using_attn_flag,
|
| 106 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
if self.is_master:
|
| 110 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 111 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 112 |
+
with open(config_path, 'w') as f:
|
| 113 |
+
yaml.dump(self.cfg, f)
|
| 114 |
+
|
| 115 |
+
dist.barrier()
|
| 116 |
+
return config_hash
|
| 117 |
+
|
| 118 |
+
def save_checkpoint(self, epoch, avg_loss, batch_idx):
|
| 119 |
+
if not self.is_master:
|
| 120 |
+
return
|
| 121 |
+
checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
|
| 122 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 123 |
+
|
| 124 |
+
torch.save({
|
| 125 |
+
'voxel_encoder': self.voxel_encoder.module.state_dict(),
|
| 126 |
+
'vae': self.vae.module.state_dict(),
|
| 127 |
+
'connection_head': self.connection_head.module.state_dict(),
|
| 128 |
+
'epoch': epoch,
|
| 129 |
+
'loss': avg_loss,
|
| 130 |
+
'config': self.cfg
|
| 131 |
+
}, checkpoint_path)
|
| 132 |
+
|
| 133 |
+
def quoted_presenter(dumper, data):
|
| 134 |
+
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
|
| 135 |
+
|
| 136 |
+
yaml.add_representer(str, quoted_presenter)
|
| 137 |
+
|
| 138 |
+
with open(config_path, 'w') as f:
|
| 139 |
+
yaml.dump(self.cfg, f)
|
| 140 |
+
|
| 141 |
+
def load_config(self, config_path):
|
| 142 |
+
with open(config_path) as f:
|
| 143 |
+
self.cfg = yaml.safe_load(f)
|
| 144 |
+
|
| 145 |
+
# Extract and convert flags for formatting
|
| 146 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 147 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 148 |
+
|
| 149 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 150 |
+
|
| 151 |
+
self.save_dir = self.cfg['experiment']['save_dir'].format(
|
| 152 |
+
dataset_name=dataset_name,
|
| 153 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 154 |
+
multires=self.cfg['model']['multires'],
|
| 155 |
+
add_block_embed=add_block_embed_flag,
|
| 156 |
+
using_attn=using_attn_flag,
|
| 157 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
if self.is_master:
|
| 161 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 162 |
+
dist.barrier()
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def init_device(self):
|
| 166 |
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
| 167 |
+
|
| 168 |
+
def init_dirs(self):
|
| 169 |
+
self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
|
| 170 |
+
if self.is_master:
|
| 171 |
+
with open(self.log_file, "a") as f:
|
| 172 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 173 |
+
f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
|
| 174 |
+
|
| 175 |
+
def init_components(self):
|
| 176 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 177 |
+
root_dir=self.cfg['dataset']['path'],
|
| 178 |
+
base_resolution=self.cfg['dataset']['base_resolution'],
|
| 179 |
+
min_resolution=self.cfg['dataset']['min_resolution'],
|
| 180 |
+
cache_dir=self.cfg['dataset']['cache_dir'],
|
| 181 |
+
renders_dir=self.cfg['dataset']['renders_dir'],
|
| 182 |
+
|
| 183 |
+
filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
|
| 184 |
+
cache_filter_path=self.cfg['dataset']['cache_filter_path'],
|
| 185 |
+
|
| 186 |
+
active_voxel_res=128,
|
| 187 |
+
pc_sample_number=819200,
|
| 188 |
+
|
| 189 |
+
sample_type=self.cfg['dataset']['sample_type'],
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
self.sampler = DistributedSampler(
|
| 193 |
+
self.dataset,
|
| 194 |
+
num_replicas=self.world_size,
|
| 195 |
+
rank=self.rank,
|
| 196 |
+
shuffle=True,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
self.dataloader = DataLoader(
|
| 200 |
+
self.dataset,
|
| 201 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 202 |
+
shuffle=False,
|
| 203 |
+
collate_fn=partial(collate_fn_pointnet,),
|
| 204 |
+
num_workers=self.cfg['training']['num_workers'],
|
| 205 |
+
pin_memory=True,
|
| 206 |
+
sampler=self.sampler,
|
| 207 |
+
prefetch_factor=4,
|
| 208 |
+
persistent_workers=True,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 212 |
+
in_channels=15,
|
| 213 |
+
hidden_dim=256,
|
| 214 |
+
out_channels=1024,
|
| 215 |
+
scatter_type='mean',
|
| 216 |
+
n_blocks=5,
|
| 217 |
+
resolution=128,
|
| 218 |
+
add_label=False,
|
| 219 |
+
).to(self.device)
|
| 220 |
+
|
| 221 |
+
self.connection_head = ConnectionHead(
|
| 222 |
+
channels=32 * 2,
|
| 223 |
+
out_channels=1,
|
| 224 |
+
mlp_ratio=16,
|
| 225 |
+
).to(self.device)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# ablation 3: voxelvae_1volume, have tested
|
| 229 |
+
self.vae = VoxelVAE(
|
| 230 |
+
in_channels=self.cfg['model']['in_channels'],
|
| 231 |
+
latent_dim=self.cfg['model']['latent_dim'],
|
| 232 |
+
encoder_blocks=self.cfg['model']['encoder_blocks'],
|
| 233 |
+
decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
|
| 234 |
+
decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
|
| 235 |
+
num_heads=8,
|
| 236 |
+
num_head_channels=64,
|
| 237 |
+
mlp_ratio=4.0,
|
| 238 |
+
attn_mode="swin",
|
| 239 |
+
window_size=8,
|
| 240 |
+
pe_mode="ape",
|
| 241 |
+
use_fp16=False,
|
| 242 |
+
use_checkpoint=True,
|
| 243 |
+
qk_rms_norm=False,
|
| 244 |
+
using_subdivide=True,
|
| 245 |
+
using_attn=self.cfg['model']['using_attn'],
|
| 246 |
+
attn_first=self.cfg['model'].get('attn_first', True),
|
| 247 |
+
pred_direction=self.cfg['model'].get('pred_direction', False),
|
| 248 |
+
).to(self.device)
|
| 249 |
+
|
| 250 |
+
if self.cfg['training']['from_pretrained']:
|
| 251 |
+
load_pretrained_woself(
|
| 252 |
+
checkpoint_path=self.cfg['training']['checkpoint_path'],
|
| 253 |
+
voxel_encoder=self.voxel_encoder,
|
| 254 |
+
vae=self.vae,
|
| 255 |
+
connection_head=self.connection_head,
|
| 256 |
+
optimizer=None,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 260 |
+
self.connection_head = DDP(self.connection_head, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 261 |
+
self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=True)
|
| 262 |
+
|
| 263 |
+
def init_training(self):
|
| 264 |
+
self.optimizer = AdamW(
|
| 265 |
+
list(self.vae.module.parameters()) +
|
| 266 |
+
list(self.voxel_encoder.module.parameters()) +
|
| 267 |
+
list(self.connection_head.module.parameters()),
|
| 268 |
+
lr=self.cfg['training']['lr'],
|
| 269 |
+
weight_decay=0.01,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
|
| 273 |
+
max_epochs = self.cfg['training']['max_epochs']
|
| 274 |
+
num_training_steps = max_epochs * num_update_steps_per_epoch
|
| 275 |
+
|
| 276 |
+
num_warmup_steps = 400
|
| 277 |
+
|
| 278 |
+
self.scheduler = get_cosine_schedule_with_warmup(
|
| 279 |
+
self.optimizer,
|
| 280 |
+
num_warmup_steps=num_warmup_steps,
|
| 281 |
+
num_training_steps=num_training_steps
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=10.0).to(self.device)
|
| 285 |
+
self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
|
| 286 |
+
self.asyloss = AsymmetricFocalLoss(
|
| 287 |
+
gamma_pos=0.0,
|
| 288 |
+
gamma_neg=4.0,
|
| 289 |
+
clip=0.05,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
self.bce_loss = torch.nn.BCEWithLogitsLoss()
|
| 293 |
+
|
| 294 |
+
self.dice_loss = DiceLoss()
|
| 295 |
+
self.scaler = GradScaler()
|
| 296 |
+
|
| 297 |
+
def train_step(self, batch, b_idx):
|
| 298 |
+
"""Modified training step that handles vertex and edge voxels separately after initial prediction."""
|
| 299 |
+
# 1. Retrieve data from batch
|
| 300 |
+
combined_voxels_1024 = batch['combined_voxels_1024'].to(self.device)
|
| 301 |
+
combined_voxel_labels_1024 = batch['combined_voxel_labels_1024'].to(self.device)
|
| 302 |
+
gt_vertex_voxels_1024 = batch['gt_vertex_voxels_1024'].to(self.device)
|
| 303 |
+
|
| 304 |
+
# gt_edge_voxels_1024 = batch['gt_edge_voxels_1024'].to(self.device)
|
| 305 |
+
# gt_combined_endpoints_1024 = batch['gt_combined_endpoints_1024'].to(self.device)
|
| 306 |
+
# gt_combined_errors_1024 = batch['gt_combined_errors_1024'].to(self.device)
|
| 307 |
+
|
| 308 |
+
# gt_edges = batch['gt_vertex_edge_indices_256'].to(self.device)
|
| 309 |
+
gt_edges = batch['gt_vertex_edge_indices_1024'].to(self.device)
|
| 310 |
+
|
| 311 |
+
edge_mask = (combined_voxel_labels_1024 == 1)
|
| 312 |
+
|
| 313 |
+
# gt_edge_endpoints_1024 = gt_combined_endpoints_1024[edge_mask]
|
| 314 |
+
# gt_edge_errors_1024 = gt_combined_errors_1024[edge_mask]
|
| 315 |
+
|
| 316 |
+
# p1 = gt_edge_endpoints_1024[:, 1:4].float()
|
| 317 |
+
# p2 = gt_edge_endpoints_1024[:, 4:7].float()
|
| 318 |
+
|
| 319 |
+
# mask = ( (p1[:,0] < p2[:,0]) |
|
| 320 |
+
# ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
|
| 321 |
+
# ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
|
| 322 |
+
|
| 323 |
+
# pA = torch.where(mask[:, None], p1, p2) # smaller one
|
| 324 |
+
# pB = torch.where(mask[:, None], p2, p1) # larger one
|
| 325 |
+
|
| 326 |
+
# d = pB - pA
|
| 327 |
+
# dir_gt = F.normalize(d, dim=-1, eps=1e-6)
|
| 328 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 329 |
+
vtx_256 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=256)
|
| 330 |
+
vtx_512 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=512)
|
| 331 |
+
vtx_1024 = gt_vertex_voxels_1024
|
| 332 |
+
|
| 333 |
+
edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 334 |
+
edge_256 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=256)
|
| 335 |
+
edge_512 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=512)
|
| 336 |
+
edge_1024 = combined_voxels_1024
|
| 337 |
+
|
| 338 |
+
active_coords = batch['active_voxels_128'].to(self.device)
|
| 339 |
+
point_cloud = batch['point_cloud_128'].to(self.device)
|
| 340 |
+
|
| 341 |
+
with autocast(dtype=torch.bfloat16):
|
| 342 |
+
active_voxel_feats = self.voxel_encoder(
|
| 343 |
+
p=point_cloud,
|
| 344 |
+
sparse_coords=active_coords,
|
| 345 |
+
res=128,
|
| 346 |
+
bbox_size=(-0.5, 0.5),
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
sparse_input = SparseTensor(
|
| 350 |
+
feats=active_voxel_feats,
|
| 351 |
+
coords=active_coords.int()
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
gt_edge_voxels_list = [
|
| 355 |
+
edge_128,
|
| 356 |
+
edge_256,
|
| 357 |
+
edge_512,
|
| 358 |
+
edge_1024,
|
| 359 |
+
]
|
| 360 |
+
|
| 361 |
+
gt_vertex_voxels_list = [
|
| 362 |
+
vtx_128,
|
| 363 |
+
vtx_256,
|
| 364 |
+
vtx_512,
|
| 365 |
+
vtx_1024,
|
| 366 |
+
]
|
| 367 |
+
|
| 368 |
+
results, posterior, latent_128 = self.vae(
|
| 369 |
+
sparse_input,
|
| 370 |
+
gt_vertex_voxels_list=gt_vertex_voxels_list,
|
| 371 |
+
gt_edge_voxels_list=gt_edge_voxels_list,
|
| 372 |
+
training=True,
|
| 373 |
+
sample_ratio=0.,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
total_loss = 0.
|
| 377 |
+
prune_loss_total = 0.
|
| 378 |
+
vertex_loss_total = 0.
|
| 379 |
+
edge_loss_total=0.
|
| 380 |
+
|
| 381 |
+
initial_result = results[0]
|
| 382 |
+
vertex_mask = initial_result['vertex_mask']
|
| 383 |
+
vtx_logits = initial_result['vtx_feats']
|
| 384 |
+
vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
|
| 385 |
+
|
| 386 |
+
edge_mask = initial_result['edge_mask']
|
| 387 |
+
edge_logits = initial_result['edge_feats']
|
| 388 |
+
edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
|
| 389 |
+
# edge_loss = self.bce_loss(edge_logits.squeeze(-1), edge_mask.float())
|
| 390 |
+
|
| 391 |
+
vertex_loss_total += vertex_loss
|
| 392 |
+
edge_loss_total += edge_loss
|
| 393 |
+
|
| 394 |
+
total_loss += vertex_loss
|
| 395 |
+
total_loss += edge_loss
|
| 396 |
+
|
| 397 |
+
# Process each level's results
|
| 398 |
+
for idx, res_dict in enumerate(results[1:], start=1):
|
| 399 |
+
# Vertex branch losses
|
| 400 |
+
vertex_pred_coords = res_dict['vertex']['occ_coords']
|
| 401 |
+
vertex_occ_probs = res_dict['vertex']['occ_probs']
|
| 402 |
+
vertex_gt_coords = res_dict['vertex']['coords']
|
| 403 |
+
|
| 404 |
+
vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=1024).float()
|
| 405 |
+
# print('vertex_labels.sum()', vertex_labels.sum(), idx)
|
| 406 |
+
vertex_logits = vertex_occ_probs.squeeze()
|
| 407 |
+
|
| 408 |
+
# if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
|
| 409 |
+
vertex_prune_loss = self.focal_loss(vertex_logits, vertex_labels)
|
| 410 |
+
|
| 411 |
+
prune_loss_total += vertex_prune_loss
|
| 412 |
+
total_loss += vertex_prune_loss
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
# Edge branch losses
|
| 416 |
+
edge_pred_coords = res_dict['edge']['occ_coords']
|
| 417 |
+
edge_occ_probs = res_dict['edge']['occ_probs']
|
| 418 |
+
edge_gt_coords = res_dict['edge']['coords']
|
| 419 |
+
|
| 420 |
+
edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=1024).float()
|
| 421 |
+
edge_logits = edge_occ_probs.squeeze()
|
| 422 |
+
edge_prune_loss = self.focal_loss(edge_logits, edge_labels)
|
| 423 |
+
|
| 424 |
+
prune_loss_total += edge_prune_loss
|
| 425 |
+
total_loss += edge_prune_loss
|
| 426 |
+
|
| 427 |
+
if idx == 3:
|
| 428 |
+
mse_loss_feats = torch.tensor(0., device=self.device)
|
| 429 |
+
mse_loss_dirs = torch.tensor(0., device=self.device)
|
| 430 |
+
# connection_loss = torch.tensor(0., device=self.device)
|
| 431 |
+
|
| 432 |
+
# --- Vertex Branch (Connection Loss 核心) ---
|
| 433 |
+
vtx_pred_coords = res_dict['vertex']['coords_4d'] # [N, 4]
|
| 434 |
+
vtx_pred_feats = res_dict['vertex']['feats'] # [N, C]
|
| 435 |
+
|
| 436 |
+
# 1.1 排序 (既用于匹配 GT,也用于快速寻找空间邻居)
|
| 437 |
+
vtx_pred_keys = flatten_coords_4d(vtx_pred_coords)
|
| 438 |
+
vtx_pred_keys_sorted, vtx_pred_order = torch.sort(vtx_pred_keys)
|
| 439 |
+
|
| 440 |
+
# 1.2 匹配 GT
|
| 441 |
+
vtx_gt_keys = flatten_coords_4d(gt_vertex_voxels_1024.to(self.device))
|
| 442 |
+
vtx_pos = torch.searchsorted(vtx_pred_keys_sorted, vtx_gt_keys)
|
| 443 |
+
vtx_pos = vtx_pos.clamp(max=len(vtx_pred_keys_sorted) - 1)
|
| 444 |
+
vtx_match_mask = (vtx_pred_keys_sorted[vtx_pos] == vtx_gt_keys)
|
| 445 |
+
|
| 446 |
+
gt_to_pred_mapping = torch.full((len(vtx_gt_keys),), -1, device=self.device, dtype=torch.long)
|
| 447 |
+
matched_pred_indices = vtx_pred_order[vtx_pos[vtx_match_mask]]
|
| 448 |
+
gt_to_pred_mapping[vtx_match_mask] = matched_pred_indices
|
| 449 |
+
|
| 450 |
+
# ====================================================
|
| 451 |
+
# 2. 构建核心数据:正样本 Hash 集合
|
| 452 |
+
# ====================================================
|
| 453 |
+
# 这里的 pos_u/pos_v 仅用于构建 "什么是真连接" 的查询表
|
| 454 |
+
u_gt, v_gt = gt_edges[:, 0], gt_edges[:, 1]
|
| 455 |
+
u_pred = gt_to_pred_mapping[u_gt]
|
| 456 |
+
v_pred = gt_to_pred_mapping[v_gt]
|
| 457 |
+
|
| 458 |
+
valid_edge_mask = (u_pred != -1) & (v_pred != -1)
|
| 459 |
+
real_pos_u = u_pred[valid_edge_mask]
|
| 460 |
+
real_pos_v = v_pred[valid_edge_mask]
|
| 461 |
+
|
| 462 |
+
num_real_pos = real_pos_u.shape[0]
|
| 463 |
+
num_total_nodes = vtx_pred_coords.shape[0]
|
| 464 |
+
|
| 465 |
+
if num_real_pos > 0:
|
| 466 |
+
# 2. 构建候选样本 (Candidate Generation)
|
| 467 |
+
# ====================================================
|
| 468 |
+
cand_u_list = []
|
| 469 |
+
cand_v_list = []
|
| 470 |
+
|
| 471 |
+
batch_ids = vtx_pred_coords[:, 0]
|
| 472 |
+
unique_batches = torch.unique(batch_ids)
|
| 473 |
+
|
| 474 |
+
RADIUS = 64
|
| 475 |
+
MAX_PTS_FOR_DIST = 12000
|
| 476 |
+
K_RANDOM = 32
|
| 477 |
+
|
| 478 |
+
for b_id in unique_batches:
|
| 479 |
+
mask_b = (batch_ids == b_id)
|
| 480 |
+
indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
|
| 481 |
+
coords_b = vtx_pred_coords[mask_b, 1:4].float() # (x,y,z)
|
| 482 |
+
num_b = coords_b.shape[0]
|
| 483 |
+
|
| 484 |
+
if num_b < 2: continue
|
| 485 |
+
|
| 486 |
+
# --- A. Radius Graph (Hard Negatives) ---
|
| 487 |
+
if num_b <= MAX_PTS_FOR_DIST:
|
| 488 |
+
# 计算距离矩阵 [M, M]
|
| 489 |
+
# 注意:autocast 下 float16 的 cdist 可能精度不够,建议转 float32
|
| 490 |
+
dist_mat = torch.cdist(coords_b.float(), coords_b.float())
|
| 491 |
+
|
| 492 |
+
# 找到距离小于 Radius 的点对 (排除自环)
|
| 493 |
+
adj_mat = (dist_mat < RADIUS) & (dist_mat > 1e-6)
|
| 494 |
+
|
| 495 |
+
# 提取索引 (local indices in batch)
|
| 496 |
+
src_local, dst_local = torch.nonzero(adj_mat, as_tuple=True)
|
| 497 |
+
|
| 498 |
+
# 映射回全局索引
|
| 499 |
+
cand_u_list.append(indices_b[src_local])
|
| 500 |
+
cand_v_list.append(indices_b[dst_local])
|
| 501 |
+
else:
|
| 502 |
+
print('num_b is big!')
|
| 503 |
+
pass
|
| 504 |
+
|
| 505 |
+
# --- B. Random Sampling (Easy Negatives) ---
|
| 506 |
+
# 随机生成 num_b * K 对
|
| 507 |
+
n_rand = num_b * K_RANDOM
|
| 508 |
+
rand_src_local = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 509 |
+
rand_dst_local = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 510 |
+
|
| 511 |
+
# 映射回全局索引
|
| 512 |
+
cand_u_list.append(indices_b[rand_src_local])
|
| 513 |
+
cand_v_list.append(indices_b[rand_dst_local])
|
| 514 |
+
|
| 515 |
+
# 合并所有来源 (GT + Radius + Random)
|
| 516 |
+
# 注意:我们把 real_pos 也加进来,确保正样本一定在列表里
|
| 517 |
+
all_u = torch.cat([real_pos_u] + cand_u_list)
|
| 518 |
+
all_v = torch.cat([real_pos_v] + cand_v_list)
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
# 3. 去重与 Labeling (Deduplication & Labeling)
|
| 523 |
+
# ====================================================
|
| 524 |
+
# 构造无向边 Hash: min * N + max
|
| 525 |
+
# 确保 MAX_NODES 足够大,比如 1000000 或 num_total_nodes
|
| 526 |
+
HASH_BASE = num_total_nodes + 100
|
| 527 |
+
|
| 528 |
+
p_min = torch.min(all_u, all_v)
|
| 529 |
+
p_max = torch.max(all_u, all_v)
|
| 530 |
+
|
| 531 |
+
# 过滤掉自环 (u==v)
|
| 532 |
+
valid_pair = (p_min != p_max)
|
| 533 |
+
p_min = p_min[valid_pair]
|
| 534 |
+
p_max = p_max[valid_pair]
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
all_hashes = p_min.long() * HASH_BASE + p_max.long()
|
| 538 |
+
|
| 539 |
+
# --- 核心:去重 ---
|
| 540 |
+
unique_hashes = torch.unique(all_hashes)
|
| 541 |
+
|
| 542 |
+
# 解码回 u, v
|
| 543 |
+
final_u = unique_hashes // HASH_BASE
|
| 544 |
+
final_v = unique_hashes % HASH_BASE
|
| 545 |
+
|
| 546 |
+
# --- Labeling ---
|
| 547 |
+
# 构建 GT 的 Hash 表用于查询
|
| 548 |
+
gt_min = torch.min(real_pos_u, real_pos_v)
|
| 549 |
+
gt_max = torch.max(real_pos_u, real_pos_v)
|
| 550 |
+
gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
|
| 551 |
+
gt_hashes = torch.unique(gt_hashes) # GT 也去重一下保险
|
| 552 |
+
gt_hashes_sorted, _ = torch.sort(gt_hashes)
|
| 553 |
+
|
| 554 |
+
# 查询 unique_hashes 是否在 gt_hashes 中
|
| 555 |
+
# 使用 searchsorted
|
| 556 |
+
idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
|
| 557 |
+
idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
|
| 558 |
+
is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
|
| 559 |
+
|
| 560 |
+
targets = is_connected.float().unsqueeze(-1) # [N_pairs, 1]
|
| 561 |
+
|
| 562 |
+
# ============================================================
|
| 563 |
+
# 4. 前向传播与 Loss
|
| 564 |
+
# ====================================================
|
| 565 |
+
feat_u = vtx_pred_feats[final_u]
|
| 566 |
+
feat_v = vtx_pred_feats[final_v]
|
| 567 |
+
|
| 568 |
+
# 对称特征融合
|
| 569 |
+
feat_uv = torch.cat([feat_u, feat_v], dim=-1)
|
| 570 |
+
feat_vu = torch.cat([feat_v, feat_u], dim=-1)
|
| 571 |
+
|
| 572 |
+
logits_uv = self.connection_head(feat_uv)
|
| 573 |
+
logits_vu = self.connection_head(feat_vu)
|
| 574 |
+
logits = (logits_uv + logits_vu) / 2.
|
| 575 |
+
|
| 576 |
+
connection_loss = self.asyloss(logits, targets)
|
| 577 |
+
total_loss += connection_loss
|
| 578 |
+
else:
|
| 579 |
+
connection_loss = torch.tensor(0., device=self.device)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
# KL loss
|
| 583 |
+
kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
|
| 584 |
+
total_loss += kl_loss
|
| 585 |
+
|
| 586 |
+
# Backpropagation
|
| 587 |
+
scaled_total_loss = total_loss / self.accum_steps
|
| 588 |
+
# self.scaler.scale(scaled_total_loss).backward()
|
| 589 |
+
scaled_total_loss.backward()
|
| 590 |
+
|
| 591 |
+
return {
|
| 592 |
+
'total_loss': total_loss.item(),
|
| 593 |
+
'kl_loss': kl_loss.item(),
|
| 594 |
+
'prune_loss': prune_loss_total.item(),
|
| 595 |
+
'vertex_loss': vertex_loss_total.item(),
|
| 596 |
+
'edge_loss': edge_loss_total.item(),
|
| 597 |
+
'offset_loss': mse_loss_feats.item(),
|
| 598 |
+
'direction_loss': mse_loss_dirs.item(),
|
| 599 |
+
'connection_loss': connection_loss.item(),
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def train(self):
|
| 604 |
+
accum_steps = self.accum_steps
|
| 605 |
+
for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
|
| 606 |
+
self.dataloader.sampler.set_epoch(epoch)
|
| 607 |
+
# Initialize metrics
|
| 608 |
+
metrics = {
|
| 609 |
+
'total_loss': 0.0,
|
| 610 |
+
'kl_loss': 0.0,
|
| 611 |
+
'prune_loss': 0.0,
|
| 612 |
+
'vertex_loss': 0.0,
|
| 613 |
+
'edge_loss': 0.0,
|
| 614 |
+
'offset_loss': 0.0,
|
| 615 |
+
'direction_loss': 0.0,
|
| 616 |
+
'connection_loss': 0.0,
|
| 617 |
+
}
|
| 618 |
+
num_batches = 0
|
| 619 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 620 |
+
|
| 621 |
+
for i, batch in enumerate(self.dataloader):
|
| 622 |
+
# Get all losses from train_step
|
| 623 |
+
if batch is None:
|
| 624 |
+
continue
|
| 625 |
+
step_losses = self.train_step(batch, i)
|
| 626 |
+
|
| 627 |
+
# Accumulate losses
|
| 628 |
+
for key in metrics:
|
| 629 |
+
metrics[key] += step_losses[key]
|
| 630 |
+
|
| 631 |
+
num_batches += 1
|
| 632 |
+
|
| 633 |
+
if (i + 1) % accum_steps == 0:
|
| 634 |
+
# self.scaler.unscale_(self.optimizer)
|
| 635 |
+
# torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
|
| 636 |
+
# torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
|
| 637 |
+
# torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
|
| 638 |
+
|
| 639 |
+
# self.scaler.step(self.optimizer)
|
| 640 |
+
# self.scaler.update()
|
| 641 |
+
# self.optimizer.zero_grad(set_to_none=True)
|
| 642 |
+
|
| 643 |
+
# self.scheduler.step()
|
| 644 |
+
|
| 645 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
|
| 646 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
|
| 647 |
+
torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
|
| 648 |
+
self.optimizer.step()
|
| 649 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 650 |
+
self.scheduler.step()
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
# Print batch-level metrics
|
| 654 |
+
if self.is_master:
|
| 655 |
+
avg_metric = {key: value / num_batches for key, value in metrics.items()}
|
| 656 |
+
print(
|
| 657 |
+
f"[Epoch {epoch}] Batch:{num_batches} "
|
| 658 |
+
f"AvgL:{avg_metric['total_loss']:.4f} "
|
| 659 |
+
f"Loss: {step_losses['total_loss']:.4f}, "
|
| 660 |
+
f"KLL: {step_losses['kl_loss']:.4f}, "
|
| 661 |
+
f"PruneL: {step_losses['prune_loss']:.4f}, "
|
| 662 |
+
f"VertexL: {step_losses['vertex_loss']:.4f}, "
|
| 663 |
+
f"EdgeL: {step_losses['edge_loss']:.4f}, "
|
| 664 |
+
f"OffsetL: {step_losses['offset_loss']:.4f}, "
|
| 665 |
+
f"DireL: {step_losses['direction_loss']:.4f}, "
|
| 666 |
+
f"ConL: {step_losses['connection_loss']:.4f}, "
|
| 667 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
if i % 2000 == 0 and i != 0:
|
| 671 |
+
self.save_checkpoint(epoch, avg_metric['total_loss'], i)
|
| 672 |
+
with open(self.log_file, "a") as f:
|
| 673 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 674 |
+
log_line = (
|
| 675 |
+
f"Epoch {epoch:05d} | "
|
| 676 |
+
f"Batch {i:05d} | "
|
| 677 |
+
f"Loss: {avg_metric['total_loss']:.6f} "
|
| 678 |
+
f"Avg KLL: {avg_metric['kl_loss']:.4f} "
|
| 679 |
+
f"Avg PruneL: {avg_metric['prune_loss']:.4f} "
|
| 680 |
+
f"Avg VertexL: {avg_metric['vertex_loss']:.4f} "
|
| 681 |
+
f"Avg EdgeL: {avg_metric['edge_loss']:.4f} "
|
| 682 |
+
f"Avg OffsetL: {avg_metric['offset_loss']:.4f} "
|
| 683 |
+
f"Avg DireL: {avg_metric['direction_loss']:.4f} "
|
| 684 |
+
f"Avg ConL: {avg_metric['connection_loss']:.4f} "
|
| 685 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 686 |
+
f"[{current_time}]\n"
|
| 687 |
+
)
|
| 688 |
+
f.write(log_line)
|
| 689 |
+
|
| 690 |
+
if num_batches % accum_steps != 0:
|
| 691 |
+
# self.scaler.unscale_(self.optimizer)
|
| 692 |
+
# torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
|
| 693 |
+
# torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
|
| 694 |
+
# torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
|
| 695 |
+
|
| 696 |
+
# self.scaler.step(self.optimizer)
|
| 697 |
+
# self.scaler.update()
|
| 698 |
+
# self.optimizer.zero_grad(set_to_none=True)
|
| 699 |
+
|
| 700 |
+
# self.scheduler.step()
|
| 701 |
+
|
| 702 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
|
| 703 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
|
| 704 |
+
torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
|
| 705 |
+
|
| 706 |
+
self.optimizer.step()
|
| 707 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 708 |
+
self.scheduler.step()
|
| 709 |
+
|
| 710 |
+
# Calculate epoch averages
|
| 711 |
+
avg_metrics = {key: value / num_batches for key, value in metrics.items()}
|
| 712 |
+
self.train_loss_history.append(avg_metrics['total_loss'])
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
# Log to file
|
| 716 |
+
if self.is_master:
|
| 717 |
+
with open(self.log_file, "a") as f:
|
| 718 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 719 |
+
log_line = (
|
| 720 |
+
f"Epoch {epoch:05d} | "
|
| 721 |
+
f"Loss: {avg_metrics['total_loss']:.6f} "
|
| 722 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 723 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 724 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 725 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 726 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 727 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 728 |
+
f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
|
| 729 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 730 |
+
f"[{current_time}]\n"
|
| 731 |
+
)
|
| 732 |
+
f.write(log_line)
|
| 733 |
+
|
| 734 |
+
# Print epoch summary
|
| 735 |
+
print(
|
| 736 |
+
f"[Epoch {epoch}] "
|
| 737 |
+
f"Avg Loss: {avg_metrics['total_loss']:.4f} "
|
| 738 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 739 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 740 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 741 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 742 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 743 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 744 |
+
f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
|
| 745 |
+
f"[{current_time}]\n"
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
# Save checkpoint
|
| 749 |
+
if epoch % self.cfg['training']['save_every'] == 0:
|
| 750 |
+
self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
|
| 751 |
+
|
| 752 |
+
# Update learning rate
|
| 753 |
+
if self.is_master:
|
| 754 |
+
current_lr = self.optimizer.param_groups[0]['lr']
|
| 755 |
+
print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
|
| 756 |
+
|
| 757 |
+
dist.barrier()
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def main():
|
| 761 |
+
# Initialize the process group
|
| 762 |
+
dist.init_process_group(backend='nccl')
|
| 763 |
+
|
| 764 |
+
# Get rank and world size from environment variables set by the launcher
|
| 765 |
+
rank = int(os.environ['RANK'])
|
| 766 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 767 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 768 |
+
|
| 769 |
+
# Set the device for the current process. This is crucial.
|
| 770 |
+
torch.cuda.set_device(local_rank)
|
| 771 |
+
torch.manual_seed(42+rank)
|
| 772 |
+
|
| 773 |
+
# with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 774 |
+
# Pass the distributed info to the Trainer
|
| 775 |
+
trainer = Trainer(
|
| 776 |
+
config_path="/gemini/user/private/zhaotianhao/Triposf/config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_addhead.yaml",
|
| 777 |
+
rank=rank,
|
| 778 |
+
world_size=world_size,
|
| 779 |
+
local_rank=local_rank
|
| 780 |
+
)
|
| 781 |
+
trainer.train()
|
| 782 |
+
|
| 783 |
+
# Clean up the process group
|
| 784 |
+
dist.destroy_process_group()
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
if __name__ == '__main__':
|
| 788 |
+
main()
|
train_slat_vae_512_128to1024_pointnet_head.py
ADDED
|
@@ -0,0 +1,930 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import os
|
| 3 |
+
# os.environ['ATTN_BACKEND'] = 'xformers'
|
| 4 |
+
import yaml
|
| 5 |
+
import torch
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from functools import partial
|
| 10 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.optim import AdamW
|
| 13 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 14 |
+
|
| 15 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 16 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 17 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 18 |
+
|
| 19 |
+
from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss, FocalLoss
|
| 20 |
+
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 23 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 24 |
+
|
| 25 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 26 |
+
import math
|
| 27 |
+
|
| 28 |
+
from torchvision.ops import sigmoid_focal_loss
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import open3d as o3d
|
| 32 |
+
|
| 33 |
+
def export_sampled_edges(coords, u, v, labels, step_idx, save_dir="debug_viz", batch_idx_to_viz=0):
|
| 34 |
+
"""
|
| 35 |
+
导出采样边为 PLY 文件。
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
coords: [N, 4] Tensor (batch_idx, x, y, z)
|
| 39 |
+
u: [E] Tensor, 起点索引 (global index)
|
| 40 |
+
v: [E] Tensor, 终点索引 (global index)
|
| 41 |
+
labels: [E, 1] Tensor, 1.0 为正样本, 0.0 为负样本
|
| 42 |
+
step_idx: 当前步数或 epoch,用于文件名
|
| 43 |
+
save_dir: 保存目录
|
| 44 |
+
batch_idx_to_viz: 只可视化哪个 batch 的数据 (防止多个 batch 叠加在一起看不清)
|
| 45 |
+
"""
|
| 46 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 47 |
+
|
| 48 |
+
# 1. 转为 CPU numpy
|
| 49 |
+
coords_np = coords.detach().cpu().numpy()
|
| 50 |
+
u_np = u.detach().cpu().numpy()
|
| 51 |
+
v_np = v.detach().cpu().numpy()
|
| 52 |
+
labels_np = labels.detach().cpu().numpy().reshape(-1)
|
| 53 |
+
|
| 54 |
+
# 2. 筛选特定 Batch (通常只看 Batch 0 比较清晰)
|
| 55 |
+
# coords 的第0列是 batch index
|
| 56 |
+
batch_mask = (coords_np[:, 0] == batch_idx_to_viz)
|
| 57 |
+
|
| 58 |
+
# 获取属于该 batch 的全局索引范围
|
| 59 |
+
# 注意:u 和 v 是针对所有 coords 的全局索引。
|
| 60 |
+
# 我们需要判断一条边的两个端点是否都在这个 batch 内。
|
| 61 |
+
|
| 62 |
+
# 快速检查端点是否在当前 batch
|
| 63 |
+
valid_u_in_batch = batch_mask[u_np]
|
| 64 |
+
valid_v_in_batch = batch_mask[v_np]
|
| 65 |
+
edge_batch_mask = valid_u_in_batch & valid_v_in_batch
|
| 66 |
+
|
| 67 |
+
if edge_batch_mask.sum() == 0:
|
| 68 |
+
print(f"Warning: No edges found for batch {batch_idx_to_viz}")
|
| 69 |
+
return
|
| 70 |
+
|
| 71 |
+
# 应用 Batch 筛选
|
| 72 |
+
u_b = u_np[edge_batch_mask]
|
| 73 |
+
v_b = v_np[edge_batch_mask]
|
| 74 |
+
labels_b = labels_np[edge_batch_mask]
|
| 75 |
+
|
| 76 |
+
# 3. 提取该 Batch 的顶点坐标 (去掉 batch_idx 维度)
|
| 77 |
+
# 此时我们需要重新映射 u, v 的索引,因为我们要只保存该 batch 的点
|
| 78 |
+
batch_indices_global = np.where(batch_mask)[0]
|
| 79 |
+
# 创建全局索引到局部索引的映射表
|
| 80 |
+
global_to_local = {gid: lid for lid, gid in enumerate(batch_indices_global)}
|
| 81 |
+
|
| 82 |
+
points_xyz = coords_np[batch_indices_global, 1:4] # [M, 3]
|
| 83 |
+
|
| 84 |
+
# 转换 u, v 为局部索引
|
| 85 |
+
try:
|
| 86 |
+
u_local = np.array([global_to_local[idx] for idx in u_b])
|
| 87 |
+
v_local = np.array([global_to_local[idx] for idx in v_b])
|
| 88 |
+
except KeyError:
|
| 89 |
+
print("Error in index mapping. Edge endpoints might cross batches.")
|
| 90 |
+
return
|
| 91 |
+
|
| 92 |
+
# 4. 分离正负样本
|
| 93 |
+
pos_mask = labels_b > 0.5
|
| 94 |
+
neg_mask = ~pos_mask
|
| 95 |
+
|
| 96 |
+
# 内部函数:写 PLY
|
| 97 |
+
def write_ply(filename, points, edges_u, edges_v, color_rgb):
|
| 98 |
+
num_verts = len(points)
|
| 99 |
+
num_edges = len(edges_u)
|
| 100 |
+
|
| 101 |
+
with open(filename, 'w') as f:
|
| 102 |
+
f.write("ply\n")
|
| 103 |
+
f.write("format ascii 1.0\n")
|
| 104 |
+
f.write(f"element vertex {num_verts}\n")
|
| 105 |
+
f.write("property float x\n")
|
| 106 |
+
f.write("property float y\n")
|
| 107 |
+
f.write("property float z\n")
|
| 108 |
+
f.write("property uchar red\n")
|
| 109 |
+
f.write("property uchar green\n")
|
| 110 |
+
f.write("property uchar blue\n")
|
| 111 |
+
f.write(f"element edge {num_edges}\n")
|
| 112 |
+
f.write("property int vertex1\n")
|
| 113 |
+
f.write("property int vertex2\n")
|
| 114 |
+
f.write("end_header\n")
|
| 115 |
+
|
| 116 |
+
# Write Vertices with Color
|
| 117 |
+
# 为了让可视化更清楚,我们将所有点染成指定颜色
|
| 118 |
+
for i in range(num_verts):
|
| 119 |
+
x, y, z = points[i]
|
| 120 |
+
f.write(f"{x:.4f} {y:.4f} {z:.4f} {color_rgb[0]} {color_rgb[1]} {color_rgb[2]}\n")
|
| 121 |
+
|
| 122 |
+
# Write Edges
|
| 123 |
+
for i in range(num_edges):
|
| 124 |
+
f.write(f"{edges_u[i]} {edges_v[i]}\n")
|
| 125 |
+
|
| 126 |
+
print(f"Saved: {filename} (Edges: {num_edges})")
|
| 127 |
+
|
| 128 |
+
# 5. 保存正样本 (绿色)
|
| 129 |
+
if pos_mask.sum() > 0:
|
| 130 |
+
write_ply(
|
| 131 |
+
os.path.join(save_dir, f"step_{step_idx}_pos_edges.ply"),
|
| 132 |
+
points_xyz, # 使用所有点,或者优化为只使用涉及的点(这里为了坐标统一简单起见使用所有点)
|
| 133 |
+
u_local[pos_mask],
|
| 134 |
+
v_local[pos_mask],
|
| 135 |
+
color_rgb=(0, 255, 0) # Green
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# 6. 保存负样本 (红色)
|
| 139 |
+
if neg_mask.sum() > 0:
|
| 140 |
+
# 为了避免文件太大或太乱,如果负样本特别多,可以考虑随机采样一部分保存
|
| 141 |
+
# 这里默认全部保存
|
| 142 |
+
write_ply(
|
| 143 |
+
os.path.join(save_dir, f"step_{step_idx}_neg_edges.ply"),
|
| 144 |
+
points_xyz,
|
| 145 |
+
u_local[neg_mask],
|
| 146 |
+
v_local[neg_mask],
|
| 147 |
+
color_rgb=(255, 0, 0) # Red
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def flatten_coords_4d(coords_4d: torch.Tensor):
|
| 152 |
+
coords_4d_long = coords_4d.long()
|
| 153 |
+
|
| 154 |
+
base_x = 1024
|
| 155 |
+
base_y = 1024 * 1024
|
| 156 |
+
base_z = 1024 * 1024 * 1024
|
| 157 |
+
|
| 158 |
+
flat_coords = coords_4d_long[:, 0] * base_z + \
|
| 159 |
+
coords_4d_long[:, 1] * base_y + \
|
| 160 |
+
coords_4d_long[:, 2] * base_x + \
|
| 161 |
+
coords_4d_long[:, 3]
|
| 162 |
+
return flat_coords
|
| 163 |
+
|
| 164 |
+
def downsample_voxels(
|
| 165 |
+
voxels: torch.Tensor,
|
| 166 |
+
input_resolution: int,
|
| 167 |
+
output_resolution: int
|
| 168 |
+
) -> torch.Tensor:
|
| 169 |
+
if input_resolution % output_resolution != 0:
|
| 170 |
+
raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
|
| 171 |
+
f"by output_resolution ({output_resolution}).")
|
| 172 |
+
|
| 173 |
+
factor = input_resolution // output_resolution
|
| 174 |
+
|
| 175 |
+
downsampled_voxels = voxels.clone().to(torch.long)
|
| 176 |
+
|
| 177 |
+
downsampled_voxels[:, 1:] //= factor
|
| 178 |
+
|
| 179 |
+
unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
|
| 180 |
+
return unique_downsampled_voxels
|
| 181 |
+
|
| 182 |
+
class Trainer:
|
| 183 |
+
def __init__(self, config_path, rank, world_size, local_rank):
|
| 184 |
+
self.rank = rank
|
| 185 |
+
self.world_size = world_size
|
| 186 |
+
self.local_rank = local_rank
|
| 187 |
+
self.is_master = self.rank == 0
|
| 188 |
+
|
| 189 |
+
self.load_config(config_path)
|
| 190 |
+
self.accum_steps = max(1, 4 // self.cfg['training']['batch_size'])
|
| 191 |
+
|
| 192 |
+
self.config_hash = self.save_config_with_hash()
|
| 193 |
+
self.init_device()
|
| 194 |
+
self.init_dirs()
|
| 195 |
+
self.init_components()
|
| 196 |
+
self.init_training()
|
| 197 |
+
|
| 198 |
+
self.train_loss_history = []
|
| 199 |
+
self.eval_loss_history = []
|
| 200 |
+
self.best_eval_loss = float('inf')
|
| 201 |
+
|
| 202 |
+
def save_config_with_hash(self):
|
| 203 |
+
import hashlib
|
| 204 |
+
|
| 205 |
+
# Serialize config to hash
|
| 206 |
+
config_str = yaml.dump(self.cfg)
|
| 207 |
+
config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
|
| 208 |
+
|
| 209 |
+
# Prepare all flags as string for formatting
|
| 210 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 211 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 212 |
+
|
| 213 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 214 |
+
|
| 215 |
+
# Format save_dir with all placeholders
|
| 216 |
+
self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
|
| 217 |
+
dataset_name=dataset_name,
|
| 218 |
+
config_hash=config_hash,
|
| 219 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 220 |
+
multires=self.cfg['model']['multires'],
|
| 221 |
+
add_block_embed=add_block_embed_flag,
|
| 222 |
+
using_attn=using_attn_flag,
|
| 223 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
if self.is_master:
|
| 227 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 228 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 229 |
+
with open(config_path, 'w') as f:
|
| 230 |
+
yaml.dump(self.cfg, f)
|
| 231 |
+
|
| 232 |
+
dist.barrier()
|
| 233 |
+
return config_hash
|
| 234 |
+
|
| 235 |
+
def save_checkpoint(self, epoch, avg_loss, batch_idx):
|
| 236 |
+
if not self.is_master:
|
| 237 |
+
return
|
| 238 |
+
checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
|
| 239 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 240 |
+
|
| 241 |
+
torch.save({
|
| 242 |
+
'voxel_encoder': self.voxel_encoder.module.state_dict(),
|
| 243 |
+
'vae': self.vae.module.state_dict(),
|
| 244 |
+
'connection_head': self.connection_head.module.state_dict(),
|
| 245 |
+
'epoch': epoch,
|
| 246 |
+
'loss': avg_loss,
|
| 247 |
+
'config': self.cfg
|
| 248 |
+
}, checkpoint_path)
|
| 249 |
+
|
| 250 |
+
def quoted_presenter(dumper, data):
|
| 251 |
+
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
|
| 252 |
+
|
| 253 |
+
yaml.add_representer(str, quoted_presenter)
|
| 254 |
+
|
| 255 |
+
with open(config_path, 'w') as f:
|
| 256 |
+
yaml.dump(self.cfg, f)
|
| 257 |
+
|
| 258 |
+
def load_config(self, config_path):
|
| 259 |
+
with open(config_path) as f:
|
| 260 |
+
self.cfg = yaml.safe_load(f)
|
| 261 |
+
|
| 262 |
+
# Extract and convert flags for formatting
|
| 263 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 264 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 265 |
+
|
| 266 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 267 |
+
|
| 268 |
+
self.save_dir = self.cfg['experiment']['save_dir'].format(
|
| 269 |
+
dataset_name=dataset_name,
|
| 270 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 271 |
+
multires=self.cfg['model']['multires'],
|
| 272 |
+
add_block_embed=add_block_embed_flag,
|
| 273 |
+
using_attn=using_attn_flag,
|
| 274 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
if self.is_master:
|
| 278 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 279 |
+
dist.barrier()
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def init_device(self):
|
| 283 |
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
| 284 |
+
|
| 285 |
+
def init_dirs(self):
|
| 286 |
+
self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
|
| 287 |
+
if self.is_master:
|
| 288 |
+
with open(self.log_file, "a") as f:
|
| 289 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 290 |
+
f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
|
| 291 |
+
|
| 292 |
+
def init_components(self):
|
| 293 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 294 |
+
root_dir=self.cfg['dataset']['path'],
|
| 295 |
+
base_resolution=self.cfg['dataset']['base_resolution'],
|
| 296 |
+
min_resolution=self.cfg['dataset']['min_resolution'],
|
| 297 |
+
cache_dir=self.cfg['dataset']['cache_dir'],
|
| 298 |
+
renders_dir=self.cfg['dataset']['renders_dir'],
|
| 299 |
+
|
| 300 |
+
filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
|
| 301 |
+
cache_filter_path=self.cfg['dataset']['cache_filter_path'],
|
| 302 |
+
|
| 303 |
+
active_voxel_res=128,
|
| 304 |
+
pc_sample_number=819200,
|
| 305 |
+
|
| 306 |
+
sample_type=self.cfg['dataset']['sample_type'],
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
self.sampler = DistributedSampler(
|
| 310 |
+
self.dataset,
|
| 311 |
+
num_replicas=self.world_size,
|
| 312 |
+
rank=self.rank,
|
| 313 |
+
shuffle=True,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
self.dataloader = DataLoader(
|
| 317 |
+
self.dataset,
|
| 318 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 319 |
+
shuffle=False,
|
| 320 |
+
collate_fn=partial(collate_fn_pointnet,),
|
| 321 |
+
num_workers=self.cfg['training']['num_workers'],
|
| 322 |
+
pin_memory=True,
|
| 323 |
+
sampler=self.sampler,
|
| 324 |
+
prefetch_factor=4,
|
| 325 |
+
persistent_workers=True,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 329 |
+
in_channels=15,
|
| 330 |
+
hidden_dim=256,
|
| 331 |
+
out_channels=1024,
|
| 332 |
+
scatter_type='mean',
|
| 333 |
+
n_blocks=5,
|
| 334 |
+
resolution=128,
|
| 335 |
+
add_label=False,
|
| 336 |
+
).to(self.device)
|
| 337 |
+
|
| 338 |
+
self.connection_head = ConnectionHead(
|
| 339 |
+
channels=32 * 2,
|
| 340 |
+
out_channels=1,
|
| 341 |
+
mlp_ratio=16,
|
| 342 |
+
).to(self.device)
|
| 343 |
+
|
| 344 |
+
# self.connection_head = ConnectionHead(
|
| 345 |
+
# channels=64 * 2,
|
| 346 |
+
# out_channels=1,
|
| 347 |
+
# mlp_ratio=8,
|
| 348 |
+
# ).to(self.device)
|
| 349 |
+
|
| 350 |
+
# ablation 3: voxelvae_1volume, have tested
|
| 351 |
+
self.vae = VoxelVAE(
|
| 352 |
+
in_channels=self.cfg['model']['in_channels'],
|
| 353 |
+
latent_dim=self.cfg['model']['latent_dim'],
|
| 354 |
+
encoder_blocks=self.cfg['model']['encoder_blocks'],
|
| 355 |
+
decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
|
| 356 |
+
decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
|
| 357 |
+
num_heads=8,
|
| 358 |
+
num_head_channels=64,
|
| 359 |
+
mlp_ratio=4.0,
|
| 360 |
+
attn_mode="swin",
|
| 361 |
+
window_size=8,
|
| 362 |
+
pe_mode="ape",
|
| 363 |
+
use_fp16=False,
|
| 364 |
+
use_checkpoint=True,
|
| 365 |
+
qk_rms_norm=False,
|
| 366 |
+
using_subdivide=True,
|
| 367 |
+
using_attn=self.cfg['model']['using_attn'],
|
| 368 |
+
attn_first=self.cfg['model'].get('attn_first', True),
|
| 369 |
+
pred_direction=self.cfg['model'].get('pred_direction', False),
|
| 370 |
+
).to(self.device)
|
| 371 |
+
|
| 372 |
+
if self.cfg['training']['from_pretrained']:
|
| 373 |
+
load_pretrained_woself(
|
| 374 |
+
checkpoint_path=self.cfg['training']['checkpoint_path'],
|
| 375 |
+
voxel_encoder=self.voxel_encoder,
|
| 376 |
+
vae=self.vae,
|
| 377 |
+
connection_head=self.connection_head,
|
| 378 |
+
optimizer=None,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 382 |
+
self.connection_head = DDP(self.connection_head, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 383 |
+
self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 384 |
+
|
| 385 |
+
def init_training(self):
|
| 386 |
+
self.optimizer = AdamW(
|
| 387 |
+
list(self.vae.module.parameters()) +
|
| 388 |
+
list(self.voxel_encoder.module.parameters()) +
|
| 389 |
+
list(self.connection_head.module.parameters()),
|
| 390 |
+
lr=self.cfg['training']['lr'],
|
| 391 |
+
weight_decay=0.01,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
|
| 395 |
+
max_epochs = self.cfg['training']['max_epochs']
|
| 396 |
+
num_training_steps = max_epochs * num_update_steps_per_epoch
|
| 397 |
+
|
| 398 |
+
num_warmup_steps = 200
|
| 399 |
+
|
| 400 |
+
self.scheduler = get_cosine_schedule_with_warmup(
|
| 401 |
+
self.optimizer,
|
| 402 |
+
num_warmup_steps=num_warmup_steps,
|
| 403 |
+
num_training_steps=num_training_steps
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=10.0).to(self.device)
|
| 407 |
+
# self.focal_loss = FocalLoss(gamma=2, alpha=0.6)
|
| 408 |
+
self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
|
| 409 |
+
self.asyloss = AsymmetricFocalLoss(
|
| 410 |
+
gamma_pos=0.0,
|
| 411 |
+
gamma_neg=4.0,
|
| 412 |
+
clip=0.05,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
self.bce_loss = torch.nn.BCEWithLogitsLoss()
|
| 416 |
+
|
| 417 |
+
self.dice_loss = DiceLoss()
|
| 418 |
+
self.scaler = GradScaler()
|
| 419 |
+
|
| 420 |
+
def train_step(self, batch, b_idx):
|
| 421 |
+
"""Modified training step that handles vertex and edge voxels separately after initial prediction."""
|
| 422 |
+
# 1. Retrieve data from batch
|
| 423 |
+
combined_voxels_1024 = batch['combined_voxels_1024'].to(self.device)
|
| 424 |
+
combined_voxel_labels_1024 = batch['combined_voxel_labels_1024'].to(self.device)
|
| 425 |
+
gt_vertex_voxels_1024 = batch['gt_vertex_voxels_1024'].to(self.device)
|
| 426 |
+
|
| 427 |
+
# gt_edge_voxels_1024 = batch['gt_edge_voxels_1024'].to(self.device)
|
| 428 |
+
# gt_combined_endpoints_1024 = batch['gt_combined_endpoints_1024'].to(self.device)
|
| 429 |
+
# gt_combined_errors_1024 = batch['gt_combined_errors_1024'].to(self.device)
|
| 430 |
+
|
| 431 |
+
# gt_edges = batch['gt_vertex_edge_indices_256'].to(self.device)
|
| 432 |
+
gt_edges = batch['gt_vertex_edge_indices_1024'].to(self.device)
|
| 433 |
+
|
| 434 |
+
edge_mask = (combined_voxel_labels_1024 == 1)
|
| 435 |
+
|
| 436 |
+
# gt_edge_endpoints_1024 = gt_combined_endpoints_1024[edge_mask]
|
| 437 |
+
# gt_edge_errors_1024 = gt_combined_errors_1024[edge_mask]
|
| 438 |
+
|
| 439 |
+
# p1 = gt_edge_endpoints_1024[:, 1:4].float()
|
| 440 |
+
# p2 = gt_edge_endpoints_1024[:, 4:7].float()
|
| 441 |
+
|
| 442 |
+
# mask = ( (p1[:,0] < p2[:,0]) |
|
| 443 |
+
# ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
|
| 444 |
+
# ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
|
| 445 |
+
|
| 446 |
+
# pA = torch.where(mask[:, None], p1, p2) # smaller one
|
| 447 |
+
# pB = torch.where(mask[:, None], p2, p1) # larger one
|
| 448 |
+
|
| 449 |
+
# d = pB - pA
|
| 450 |
+
# dir_gt = F.normalize(d, dim=-1, eps=1e-6)
|
| 451 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 452 |
+
vtx_256 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=256)
|
| 453 |
+
vtx_512 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=512)
|
| 454 |
+
vtx_1024 = gt_vertex_voxels_1024
|
| 455 |
+
|
| 456 |
+
edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
|
| 457 |
+
edge_256 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=256)
|
| 458 |
+
edge_512 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=512)
|
| 459 |
+
edge_1024 = combined_voxels_1024
|
| 460 |
+
|
| 461 |
+
active_coords = batch['active_voxels_128'].to(self.device)
|
| 462 |
+
point_cloud = batch['point_cloud_128'].to(self.device)
|
| 463 |
+
|
| 464 |
+
with autocast(dtype=torch.bfloat16):
|
| 465 |
+
active_voxel_feats = self.voxel_encoder(
|
| 466 |
+
p=point_cloud,
|
| 467 |
+
sparse_coords=active_coords,
|
| 468 |
+
res=128,
|
| 469 |
+
bbox_size=(-0.5, 0.5),
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
sparse_input = SparseTensor(
|
| 473 |
+
feats=active_voxel_feats,
|
| 474 |
+
coords=active_coords.int()
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
gt_edge_voxels_list = [
|
| 478 |
+
edge_128,
|
| 479 |
+
edge_256,
|
| 480 |
+
edge_512,
|
| 481 |
+
edge_1024,
|
| 482 |
+
]
|
| 483 |
+
|
| 484 |
+
gt_vertex_voxels_list = [
|
| 485 |
+
vtx_128,
|
| 486 |
+
vtx_256,
|
| 487 |
+
vtx_512,
|
| 488 |
+
vtx_1024,
|
| 489 |
+
]
|
| 490 |
+
|
| 491 |
+
results, posterior, latent_128 = self.vae(
|
| 492 |
+
sparse_input,
|
| 493 |
+
gt_vertex_voxels_list=gt_vertex_voxels_list,
|
| 494 |
+
gt_edge_voxels_list=gt_edge_voxels_list,
|
| 495 |
+
training=True,
|
| 496 |
+
sample_ratio=0.,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# print("results[-1]['edge']['coords_4d'][1827:1830]", results[-1]['edge']['coords_4d'][1827:1830])
|
| 500 |
+
total_loss = 0.
|
| 501 |
+
prune_loss_total = 0.
|
| 502 |
+
vertex_loss_total = 0.
|
| 503 |
+
edge_loss_total=0.
|
| 504 |
+
|
| 505 |
+
initial_result = results[0]
|
| 506 |
+
vertex_mask = initial_result['vertex_mask']
|
| 507 |
+
vtx_logits = initial_result['vtx_feats']
|
| 508 |
+
vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
|
| 509 |
+
|
| 510 |
+
edge_mask = initial_result['edge_mask']
|
| 511 |
+
edge_logits = initial_result['edge_feats']
|
| 512 |
+
edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
|
| 513 |
+
# edge_loss = self.bce_loss(edge_logits.squeeze(-1), edge_mask.float())
|
| 514 |
+
|
| 515 |
+
vertex_loss_total += vertex_loss
|
| 516 |
+
edge_loss_total += edge_loss
|
| 517 |
+
|
| 518 |
+
total_loss += vertex_loss
|
| 519 |
+
total_loss += edge_loss
|
| 520 |
+
|
| 521 |
+
# Process each level's results
|
| 522 |
+
for idx, res_dict in enumerate(results[1:], start=1):
|
| 523 |
+
# Vertex branch losses
|
| 524 |
+
vertex_pred_coords = res_dict['vertex']['occ_coords']
|
| 525 |
+
vertex_occ_probs = res_dict['vertex']['occ_probs']
|
| 526 |
+
vertex_gt_coords = res_dict['vertex']['coords']
|
| 527 |
+
|
| 528 |
+
vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=1024).float()
|
| 529 |
+
# print('vertex_labels.sum()', vertex_labels.sum(), idx)
|
| 530 |
+
vertex_logits = vertex_occ_probs.squeeze()
|
| 531 |
+
|
| 532 |
+
# if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
|
| 533 |
+
vertex_prune_loss = self.asyloss(vertex_logits, vertex_labels)
|
| 534 |
+
|
| 535 |
+
prune_loss_total += vertex_prune_loss
|
| 536 |
+
total_loss += vertex_prune_loss
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
# Edge branch losses
|
| 540 |
+
edge_pred_coords = res_dict['edge']['occ_coords']
|
| 541 |
+
edge_occ_probs = res_dict['edge']['occ_probs']
|
| 542 |
+
edge_gt_coords = res_dict['edge']['coords']
|
| 543 |
+
|
| 544 |
+
edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=1024).float()
|
| 545 |
+
edge_logits = edge_occ_probs.squeeze()
|
| 546 |
+
edge_prune_loss = self.asyloss(edge_logits, edge_labels)
|
| 547 |
+
|
| 548 |
+
prune_loss_total += edge_prune_loss
|
| 549 |
+
total_loss += edge_prune_loss
|
| 550 |
+
|
| 551 |
+
if idx == 3:
|
| 552 |
+
mse_loss_feats = torch.tensor(0., device=self.device)
|
| 553 |
+
mse_loss_dirs = torch.tensor(0., device=self.device)
|
| 554 |
+
# connection_loss = torch.tensor(0., device=self.device)
|
| 555 |
+
|
| 556 |
+
# --- Vertex Branch (Connection Loss 核心) ---
|
| 557 |
+
vtx_pred_coords = res_dict['vertex']['coords_4d'] # [N, 4]
|
| 558 |
+
vtx_pred_feats = res_dict['vertex']['feats'] # [N, C]
|
| 559 |
+
|
| 560 |
+
# 1.1 排序 (既用于匹配 GT,也用于快速寻找空间邻居)
|
| 561 |
+
vtx_pred_keys = flatten_coords_4d(vtx_pred_coords)
|
| 562 |
+
vtx_pred_keys_sorted, vtx_pred_order = torch.sort(vtx_pred_keys)
|
| 563 |
+
|
| 564 |
+
# 1.2 匹配 GT
|
| 565 |
+
vtx_gt_keys = flatten_coords_4d(gt_vertex_voxels_1024.to(self.device))
|
| 566 |
+
vtx_pos = torch.searchsorted(vtx_pred_keys_sorted, vtx_gt_keys)
|
| 567 |
+
vtx_pos = vtx_pos.clamp(max=len(vtx_pred_keys_sorted) - 1)
|
| 568 |
+
vtx_match_mask = (vtx_pred_keys_sorted[vtx_pos] == vtx_gt_keys)
|
| 569 |
+
|
| 570 |
+
gt_to_pred_mapping = torch.full((len(vtx_gt_keys),), -1, device=self.device, dtype=torch.long)
|
| 571 |
+
matched_pred_indices = vtx_pred_order[vtx_pos[vtx_match_mask]]
|
| 572 |
+
gt_to_pred_mapping[vtx_match_mask] = matched_pred_indices
|
| 573 |
+
|
| 574 |
+
# ====================================================
|
| 575 |
+
# 2. 构建核心数据:正样本 Hash 集合
|
| 576 |
+
# ====================================================
|
| 577 |
+
# 这里的 pos_u/pos_v 仅用于构建 "什么是真连接" 的查询表
|
| 578 |
+
u_gt, v_gt = gt_edges[:, 0], gt_edges[:, 1]
|
| 579 |
+
u_pred = gt_to_pred_mapping[u_gt]
|
| 580 |
+
v_pred = gt_to_pred_mapping[v_gt]
|
| 581 |
+
|
| 582 |
+
valid_edge_mask = (u_pred != -1) & (v_pred != -1)
|
| 583 |
+
real_pos_u = u_pred[valid_edge_mask]
|
| 584 |
+
real_pos_v = v_pred[valid_edge_mask]
|
| 585 |
+
|
| 586 |
+
num_real_pos = real_pos_u.shape[0]
|
| 587 |
+
num_total_nodes = vtx_pred_coords.shape[0]
|
| 588 |
+
|
| 589 |
+
if num_real_pos > 0:
|
| 590 |
+
# 2. 构建候选样本 (Candidate Generation)
|
| 591 |
+
# ====================================================
|
| 592 |
+
cand_u_list = []
|
| 593 |
+
cand_v_list = []
|
| 594 |
+
|
| 595 |
+
batch_ids = vtx_pred_coords[:, 0]
|
| 596 |
+
unique_batches = torch.unique(batch_ids)
|
| 597 |
+
|
| 598 |
+
RADIUS = 64
|
| 599 |
+
MAX_PTS_FOR_DIST = 12000
|
| 600 |
+
K_RANDOM = 32
|
| 601 |
+
|
| 602 |
+
for b_id in unique_batches:
|
| 603 |
+
mask_b = (batch_ids == b_id)
|
| 604 |
+
indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
|
| 605 |
+
coords_b = vtx_pred_coords[mask_b, 1:4].float() # (x,y,z)
|
| 606 |
+
num_b = coords_b.shape[0]
|
| 607 |
+
|
| 608 |
+
if num_b < 2: continue
|
| 609 |
+
|
| 610 |
+
# --- A. Radius Graph (Hard Negatives) ---
|
| 611 |
+
if num_b <= MAX_PTS_FOR_DIST:
|
| 612 |
+
# 计算距离矩阵 [M, M]
|
| 613 |
+
# 注意:autocast 下 float16 ��� cdist 可能精度不够,建议转 float32
|
| 614 |
+
dist_mat = torch.cdist(coords_b.float(), coords_b.float())
|
| 615 |
+
|
| 616 |
+
# 找到距离小于 Radius 的点对 (排除自环)
|
| 617 |
+
adj_mat = (dist_mat < RADIUS) & (dist_mat > 1e-6)
|
| 618 |
+
|
| 619 |
+
# 提取索引 (local indices in batch)
|
| 620 |
+
src_local, dst_local = torch.nonzero(adj_mat, as_tuple=True)
|
| 621 |
+
|
| 622 |
+
# 映射回全局索引
|
| 623 |
+
cand_u_list.append(indices_b[src_local])
|
| 624 |
+
cand_v_list.append(indices_b[dst_local])
|
| 625 |
+
else:
|
| 626 |
+
print('num_b is big!')
|
| 627 |
+
pass
|
| 628 |
+
|
| 629 |
+
# --- B. Random Sampling (Easy Negatives) ---
|
| 630 |
+
# 随机生成 num_b * K 对
|
| 631 |
+
n_rand = num_b * K_RANDOM
|
| 632 |
+
rand_src_local = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 633 |
+
rand_dst_local = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 634 |
+
|
| 635 |
+
# 映射回全局索引
|
| 636 |
+
cand_u_list.append(indices_b[rand_src_local])
|
| 637 |
+
cand_v_list.append(indices_b[rand_dst_local])
|
| 638 |
+
|
| 639 |
+
# 合并所有来源 (GT + Radius + Random)
|
| 640 |
+
# 注意:我们把 real_pos 也加进来,确保正样本一定在列表里
|
| 641 |
+
all_u = torch.cat([real_pos_u] + cand_u_list)
|
| 642 |
+
all_v = torch.cat([real_pos_v] + cand_v_list)
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
# 3. 去重与 Labeling (Deduplication & Labeling)
|
| 647 |
+
# ====================================================
|
| 648 |
+
# 构造无向边 Hash: min * N + max
|
| 649 |
+
# 确保 MAX_NODES 足够大,比如 1000000 或 num_total_nodes
|
| 650 |
+
HASH_BASE = num_total_nodes + 100
|
| 651 |
+
|
| 652 |
+
p_min = torch.min(all_u, all_v)
|
| 653 |
+
p_max = torch.max(all_u, all_v)
|
| 654 |
+
|
| 655 |
+
# 过滤掉自环 (u==v)
|
| 656 |
+
valid_pair = (p_min != p_max)
|
| 657 |
+
p_min = p_min[valid_pair]
|
| 658 |
+
p_max = p_max[valid_pair]
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
all_hashes = p_min.long() * HASH_BASE + p_max.long()
|
| 662 |
+
|
| 663 |
+
# --- 核心:去重 ---
|
| 664 |
+
unique_hashes = torch.unique(all_hashes)
|
| 665 |
+
|
| 666 |
+
# 解码回 u, v
|
| 667 |
+
final_u = unique_hashes // HASH_BASE
|
| 668 |
+
final_v = unique_hashes % HASH_BASE
|
| 669 |
+
|
| 670 |
+
# --- Labeling ---
|
| 671 |
+
# 构建 GT 的 Hash 表用于查询
|
| 672 |
+
gt_min = torch.min(real_pos_u, real_pos_v)
|
| 673 |
+
gt_max = torch.max(real_pos_u, real_pos_v)
|
| 674 |
+
gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
|
| 675 |
+
gt_hashes = torch.unique(gt_hashes) # GT 也去重一下保险
|
| 676 |
+
gt_hashes_sorted, _ = torch.sort(gt_hashes)
|
| 677 |
+
|
| 678 |
+
# 查询 unique_hashes 是否在 gt_hashes 中
|
| 679 |
+
# 使用 searchsorted
|
| 680 |
+
idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
|
| 681 |
+
idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
|
| 682 |
+
is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
|
| 683 |
+
|
| 684 |
+
targets = is_connected.float().unsqueeze(-1) # [N_pairs, 1]
|
| 685 |
+
|
| 686 |
+
# ============================================================
|
| 687 |
+
# 4. 前向传播与 Loss
|
| 688 |
+
# ====================================================
|
| 689 |
+
feat_u = vtx_pred_feats[final_u]
|
| 690 |
+
feat_v = vtx_pred_feats[final_v]
|
| 691 |
+
|
| 692 |
+
# 对称特征融合
|
| 693 |
+
feat_uv = torch.cat([feat_u, feat_v], dim=-1)
|
| 694 |
+
feat_vu = torch.cat([feat_v, feat_u], dim=-1)
|
| 695 |
+
|
| 696 |
+
logits_uv = self.connection_head(feat_uv)
|
| 697 |
+
logits_vu = self.connection_head(feat_vu)
|
| 698 |
+
logits = (logits_uv + logits_vu) / 2.
|
| 699 |
+
|
| 700 |
+
# export_sampled_edges(
|
| 701 |
+
# coords=vtx_pred_coords, # [N, 4]
|
| 702 |
+
# u=final_u, # [E]
|
| 703 |
+
# v=final_v, # [E]
|
| 704 |
+
# labels=targets, # [E, 1]
|
| 705 |
+
# step_idx=b_idx,
|
| 706 |
+
# )
|
| 707 |
+
# export_sampled_edges(
|
| 708 |
+
# coords=vtx_pred_coords, # [N, 4]
|
| 709 |
+
# u=final_u, # [E]
|
| 710 |
+
# v=final_v, # [E]
|
| 711 |
+
# labels=targets, # [E, 1]
|
| 712 |
+
# step_idx=b_idx,
|
| 713 |
+
# batch_idx_to_viz=1,
|
| 714 |
+
# save_dir="debug_viz2"
|
| 715 |
+
# )
|
| 716 |
+
# exit()
|
| 717 |
+
|
| 718 |
+
connection_loss = self.asyloss(logits, targets)
|
| 719 |
+
total_loss += connection_loss
|
| 720 |
+
else:
|
| 721 |
+
connection_loss = torch.tensor(0., device=self.device)
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
# KL loss
|
| 725 |
+
kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
|
| 726 |
+
total_loss += kl_loss
|
| 727 |
+
|
| 728 |
+
# Backpropagation
|
| 729 |
+
scaled_total_loss = total_loss / self.accum_steps
|
| 730 |
+
# self.scaler.scale(scaled_total_loss).backward()
|
| 731 |
+
scaled_total_loss.backward()
|
| 732 |
+
|
| 733 |
+
return {
|
| 734 |
+
'total_loss': total_loss.item(),
|
| 735 |
+
'kl_loss': kl_loss.item(),
|
| 736 |
+
'prune_loss': prune_loss_total.item(),
|
| 737 |
+
'vertex_loss': vertex_loss_total.item(),
|
| 738 |
+
'edge_loss': edge_loss_total.item(),
|
| 739 |
+
'offset_loss': mse_loss_feats.item(),
|
| 740 |
+
'direction_loss': mse_loss_dirs.item(),
|
| 741 |
+
'connection_loss': connection_loss.item(),
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
def train(self):
|
| 746 |
+
accum_steps = self.accum_steps
|
| 747 |
+
for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
|
| 748 |
+
self.dataloader.sampler.set_epoch(epoch)
|
| 749 |
+
# Initialize metrics
|
| 750 |
+
metrics = {
|
| 751 |
+
'total_loss': 0.0,
|
| 752 |
+
'kl_loss': 0.0,
|
| 753 |
+
'prune_loss': 0.0,
|
| 754 |
+
'vertex_loss': 0.0,
|
| 755 |
+
'edge_loss': 0.0,
|
| 756 |
+
'offset_loss': 0.0,
|
| 757 |
+
'direction_loss': 0.0,
|
| 758 |
+
'connection_loss': 0.0,
|
| 759 |
+
}
|
| 760 |
+
num_batches = 0
|
| 761 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 762 |
+
|
| 763 |
+
for i, batch in enumerate(self.dataloader):
|
| 764 |
+
# Get all losses from train_step
|
| 765 |
+
if batch is None:
|
| 766 |
+
continue
|
| 767 |
+
step_losses = self.train_step(batch, i)
|
| 768 |
+
|
| 769 |
+
# Accumulate losses
|
| 770 |
+
for key in metrics:
|
| 771 |
+
metrics[key] += step_losses[key]
|
| 772 |
+
|
| 773 |
+
num_batches += 1
|
| 774 |
+
|
| 775 |
+
if (i + 1) % accum_steps == 0:
|
| 776 |
+
# self.scaler.unscale_(self.optimizer)
|
| 777 |
+
# torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
|
| 778 |
+
# torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
|
| 779 |
+
# torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
|
| 780 |
+
|
| 781 |
+
# self.scaler.step(self.optimizer)
|
| 782 |
+
# self.scaler.update()
|
| 783 |
+
# self.optimizer.zero_grad(set_to_none=True)
|
| 784 |
+
|
| 785 |
+
# self.scheduler.step()
|
| 786 |
+
|
| 787 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
|
| 788 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
|
| 789 |
+
torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
|
| 790 |
+
self.optimizer.step()
|
| 791 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 792 |
+
self.scheduler.step()
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
# Print batch-level metrics
|
| 796 |
+
if self.is_master:
|
| 797 |
+
avg_metric = {key: value / num_batches for key, value in metrics.items()}
|
| 798 |
+
print(
|
| 799 |
+
f"[Epoch {epoch}] Batch:{num_batches} "
|
| 800 |
+
f"AvgL:{avg_metric['total_loss']:.4f} "
|
| 801 |
+
f"Loss: {step_losses['total_loss']:.4f}, "
|
| 802 |
+
f"KLL: {step_losses['kl_loss']:.4f}, "
|
| 803 |
+
f"PruneL: {step_losses['prune_loss']:.4f}, "
|
| 804 |
+
f"VertexL: {step_losses['vertex_loss']:.4f}, "
|
| 805 |
+
f"EdgeL: {step_losses['edge_loss']:.4f}, "
|
| 806 |
+
f"OffsetL: {step_losses['offset_loss']:.4f}, "
|
| 807 |
+
f"DireL: {step_losses['direction_loss']:.4f}, "
|
| 808 |
+
f"ConL: {step_losses['connection_loss']:.4f}, "
|
| 809 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
if i % 2000 == 0 and i != 0:
|
| 813 |
+
self.save_checkpoint(epoch, avg_metric['total_loss'], i)
|
| 814 |
+
with open(self.log_file, "a") as f:
|
| 815 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 816 |
+
log_line = (
|
| 817 |
+
f"Epoch {epoch:05d} | "
|
| 818 |
+
f"Batch {i:05d} | "
|
| 819 |
+
f"Loss: {avg_metric['total_loss']:.6f} "
|
| 820 |
+
f"Avg KLL: {avg_metric['kl_loss']:.4f} "
|
| 821 |
+
f"Avg PruneL: {avg_metric['prune_loss']:.4f} "
|
| 822 |
+
f"Avg VertexL: {avg_metric['vertex_loss']:.4f} "
|
| 823 |
+
f"Avg EdgeL: {avg_metric['edge_loss']:.4f} "
|
| 824 |
+
f"Avg OffsetL: {avg_metric['offset_loss']:.4f} "
|
| 825 |
+
f"Avg DireL: {avg_metric['direction_loss']:.4f} "
|
| 826 |
+
f"Avg ConL: {avg_metric['connection_loss']:.4f} "
|
| 827 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 828 |
+
f"[{current_time}]\n"
|
| 829 |
+
)
|
| 830 |
+
f.write(log_line)
|
| 831 |
+
|
| 832 |
+
if num_batches % accum_steps != 0:
|
| 833 |
+
# self.scaler.unscale_(self.optimizer)
|
| 834 |
+
# torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
|
| 835 |
+
# torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
|
| 836 |
+
# torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
|
| 837 |
+
|
| 838 |
+
# self.scaler.step(self.optimizer)
|
| 839 |
+
# self.scaler.update()
|
| 840 |
+
# self.optimizer.zero_grad(set_to_none=True)
|
| 841 |
+
|
| 842 |
+
# self.scheduler.step()
|
| 843 |
+
|
| 844 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
|
| 845 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
|
| 846 |
+
torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
|
| 847 |
+
|
| 848 |
+
self.optimizer.step()
|
| 849 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 850 |
+
self.scheduler.step()
|
| 851 |
+
|
| 852 |
+
# Calculate epoch averages
|
| 853 |
+
avg_metrics = {key: value / num_batches for key, value in metrics.items()}
|
| 854 |
+
self.train_loss_history.append(avg_metrics['total_loss'])
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
# Log to file
|
| 858 |
+
if self.is_master:
|
| 859 |
+
with open(self.log_file, "a") as f:
|
| 860 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 861 |
+
log_line = (
|
| 862 |
+
f"Epoch {epoch:05d} | "
|
| 863 |
+
f"Loss: {avg_metrics['total_loss']:.6f} "
|
| 864 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 865 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 866 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 867 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 868 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 869 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 870 |
+
f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
|
| 871 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 872 |
+
f"[{current_time}]\n"
|
| 873 |
+
)
|
| 874 |
+
f.write(log_line)
|
| 875 |
+
|
| 876 |
+
# Print epoch summary
|
| 877 |
+
print(
|
| 878 |
+
f"[Epoch {epoch}] "
|
| 879 |
+
f"Avg Loss: {avg_metrics['total_loss']:.4f} "
|
| 880 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 881 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 882 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 883 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 884 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 885 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 886 |
+
f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
|
| 887 |
+
f"[{current_time}]\n"
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
# Save checkpoint
|
| 891 |
+
if epoch % self.cfg['training']['save_every'] == 0:
|
| 892 |
+
self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
|
| 893 |
+
|
| 894 |
+
# Update learning rate
|
| 895 |
+
if self.is_master:
|
| 896 |
+
current_lr = self.optimizer.param_groups[0]['lr']
|
| 897 |
+
print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
|
| 898 |
+
|
| 899 |
+
dist.barrier()
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
def main():
|
| 903 |
+
# Initialize the process group
|
| 904 |
+
dist.init_process_group(backend='nccl')
|
| 905 |
+
|
| 906 |
+
# Get rank and world size from environment variables set by the launcher
|
| 907 |
+
rank = int(os.environ['RANK'])
|
| 908 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 909 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 910 |
+
|
| 911 |
+
# Set the device for the current process. This is crucial.
|
| 912 |
+
torch.cuda.set_device(local_rank)
|
| 913 |
+
torch.manual_seed(42)
|
| 914 |
+
|
| 915 |
+
# with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 916 |
+
# Pass the distributed info to the Trainer
|
| 917 |
+
trainer = Trainer(
|
| 918 |
+
config_path="/root/Trisf/config_edge_1024_error_8enc_8dec_woself_finetune_128to1024.yaml",
|
| 919 |
+
rank=rank,
|
| 920 |
+
world_size=world_size,
|
| 921 |
+
local_rank=local_rank
|
| 922 |
+
)
|
| 923 |
+
trainer.train()
|
| 924 |
+
|
| 925 |
+
# Clean up the process group
|
| 926 |
+
dist.destroy_process_group()
|
| 927 |
+
|
| 928 |
+
|
| 929 |
+
if __name__ == '__main__':
|
| 930 |
+
main()
|
train_slat_vae_512_128to256_pointnet_head.py
ADDED
|
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import os
|
| 3 |
+
# os.environ['ATTN_BACKEND'] = 'xformers'
|
| 4 |
+
import yaml
|
| 5 |
+
import torch
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from functools import partial
|
| 10 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.optim import AdamW
|
| 13 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 14 |
+
|
| 15 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 16 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 17 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 18 |
+
|
| 19 |
+
from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss
|
| 20 |
+
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 23 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 24 |
+
|
| 25 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 26 |
+
import math
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import open3d as o3d
|
| 30 |
+
|
| 31 |
+
def export_sampled_edges(coords, u, v, labels, step_idx, save_dir="debug_viz", batch_idx_to_viz=0):
|
| 32 |
+
"""
|
| 33 |
+
导出采样边为 PLY 文件。
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
coords: [N, 4] Tensor (batch_idx, x, y, z)
|
| 37 |
+
u: [E] Tensor, 起点索引 (global index)
|
| 38 |
+
v: [E] Tensor, 终点索引 (global index)
|
| 39 |
+
labels: [E, 1] Tensor, 1.0 为正样本, 0.0 为负样本
|
| 40 |
+
step_idx: 当前步数或 epoch,用于文件名
|
| 41 |
+
save_dir: 保存目录
|
| 42 |
+
batch_idx_to_viz: 只可视化哪个 batch 的数据 (防止多个 batch 叠加在一起看不清)
|
| 43 |
+
"""
|
| 44 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 45 |
+
|
| 46 |
+
# 1. 转为 CPU numpy
|
| 47 |
+
coords_np = coords.detach().cpu().numpy()
|
| 48 |
+
u_np = u.detach().cpu().numpy()
|
| 49 |
+
v_np = v.detach().cpu().numpy()
|
| 50 |
+
labels_np = labels.detach().cpu().numpy().reshape(-1)
|
| 51 |
+
|
| 52 |
+
# 2. 筛选特定 Batch (通常只看 Batch 0 比较清晰)
|
| 53 |
+
# coords 的第0列是 batch index
|
| 54 |
+
batch_mask = (coords_np[:, 0] == batch_idx_to_viz)
|
| 55 |
+
|
| 56 |
+
# 获取属于该 batch 的全局索引范围
|
| 57 |
+
# 注意:u 和 v 是针对所有 coords 的全局索引。
|
| 58 |
+
# 我们需要判断一条边的两个端点是否都在这个 batch 内。
|
| 59 |
+
|
| 60 |
+
# 快速检查端点是否在当前 batch
|
| 61 |
+
valid_u_in_batch = batch_mask[u_np]
|
| 62 |
+
valid_v_in_batch = batch_mask[v_np]
|
| 63 |
+
edge_batch_mask = valid_u_in_batch & valid_v_in_batch
|
| 64 |
+
|
| 65 |
+
if edge_batch_mask.sum() == 0:
|
| 66 |
+
print(f"Warning: No edges found for batch {batch_idx_to_viz}")
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
# 应用 Batch 筛选
|
| 70 |
+
u_b = u_np[edge_batch_mask]
|
| 71 |
+
v_b = v_np[edge_batch_mask]
|
| 72 |
+
labels_b = labels_np[edge_batch_mask]
|
| 73 |
+
|
| 74 |
+
# 3. 提取该 Batch 的顶点坐标 (去掉 batch_idx 维度)
|
| 75 |
+
# 此时我们需要重新映射 u, v 的索引,因为我们要只保存该 batch 的点
|
| 76 |
+
batch_indices_global = np.where(batch_mask)[0]
|
| 77 |
+
# 创建全局索引到局部索引的映射表
|
| 78 |
+
global_to_local = {gid: lid for lid, gid in enumerate(batch_indices_global)}
|
| 79 |
+
|
| 80 |
+
points_xyz = coords_np[batch_indices_global, 1:4] # [M, 3]
|
| 81 |
+
|
| 82 |
+
# 转换 u, v 为局部索引
|
| 83 |
+
try:
|
| 84 |
+
u_local = np.array([global_to_local[idx] for idx in u_b])
|
| 85 |
+
v_local = np.array([global_to_local[idx] for idx in v_b])
|
| 86 |
+
except KeyError:
|
| 87 |
+
print("Error in index mapping. Edge endpoints might cross batches.")
|
| 88 |
+
return
|
| 89 |
+
|
| 90 |
+
# 4. 分离正负样本
|
| 91 |
+
pos_mask = labels_b > 0.5
|
| 92 |
+
neg_mask = ~pos_mask
|
| 93 |
+
|
| 94 |
+
# 内部函数:写 PLY
|
| 95 |
+
def write_ply(filename, points, edges_u, edges_v, color_rgb):
|
| 96 |
+
num_verts = len(points)
|
| 97 |
+
num_edges = len(edges_u)
|
| 98 |
+
|
| 99 |
+
with open(filename, 'w') as f:
|
| 100 |
+
f.write("ply\n")
|
| 101 |
+
f.write("format ascii 1.0\n")
|
| 102 |
+
f.write(f"element vertex {num_verts}\n")
|
| 103 |
+
f.write("property float x\n")
|
| 104 |
+
f.write("property float y\n")
|
| 105 |
+
f.write("property float z\n")
|
| 106 |
+
f.write("property uchar red\n")
|
| 107 |
+
f.write("property uchar green\n")
|
| 108 |
+
f.write("property uchar blue\n")
|
| 109 |
+
f.write(f"element edge {num_edges}\n")
|
| 110 |
+
f.write("property int vertex1\n")
|
| 111 |
+
f.write("property int vertex2\n")
|
| 112 |
+
f.write("end_header\n")
|
| 113 |
+
|
| 114 |
+
# Write Vertices with Color
|
| 115 |
+
# 为了让可视化更清楚,我们将所有点染成指定颜色
|
| 116 |
+
for i in range(num_verts):
|
| 117 |
+
x, y, z = points[i]
|
| 118 |
+
f.write(f"{x:.4f} {y:.4f} {z:.4f} {color_rgb[0]} {color_rgb[1]} {color_rgb[2]}\n")
|
| 119 |
+
|
| 120 |
+
# Write Edges
|
| 121 |
+
for i in range(num_edges):
|
| 122 |
+
f.write(f"{edges_u[i]} {edges_v[i]}\n")
|
| 123 |
+
|
| 124 |
+
print(f"Saved: {filename} (Edges: {num_edges})")
|
| 125 |
+
|
| 126 |
+
# 5. 保存正样本 (绿色)
|
| 127 |
+
if pos_mask.sum() > 0:
|
| 128 |
+
write_ply(
|
| 129 |
+
os.path.join(save_dir, f"step_{step_idx}_pos_edges.ply"),
|
| 130 |
+
points_xyz, # 使用所有点,或者优化为只使用涉及的点(这里为了坐标统一���单起见使用所有点)
|
| 131 |
+
u_local[pos_mask],
|
| 132 |
+
v_local[pos_mask],
|
| 133 |
+
color_rgb=(0, 255, 0) # Green
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# 6. 保存负样本 (红色)
|
| 137 |
+
if neg_mask.sum() > 0:
|
| 138 |
+
# 为了避免文件太大或太乱,如果负样本特别多,可以考虑随机采样一部分保存
|
| 139 |
+
# 这里默认全部保存
|
| 140 |
+
write_ply(
|
| 141 |
+
os.path.join(save_dir, f"step_{step_idx}_neg_edges.ply"),
|
| 142 |
+
points_xyz,
|
| 143 |
+
u_local[neg_mask],
|
| 144 |
+
v_local[neg_mask],
|
| 145 |
+
color_rgb=(255, 0, 0) # Red
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def flatten_coords_4d(coords_4d: torch.Tensor):
|
| 149 |
+
coords_4d_long = coords_4d.long()
|
| 150 |
+
|
| 151 |
+
base_x = 256
|
| 152 |
+
base_y = 256 * 256
|
| 153 |
+
base_z = 256 * 256 * 256
|
| 154 |
+
|
| 155 |
+
flat_coords = coords_4d_long[:, 0] * base_z + \
|
| 156 |
+
coords_4d_long[:, 1] * base_y + \
|
| 157 |
+
coords_4d_long[:, 2] * base_x + \
|
| 158 |
+
coords_4d_long[:, 3]
|
| 159 |
+
return flat_coords
|
| 160 |
+
|
| 161 |
+
def downsample_voxels(
|
| 162 |
+
voxels: torch.Tensor,
|
| 163 |
+
input_resolution: int,
|
| 164 |
+
output_resolution: int
|
| 165 |
+
) -> torch.Tensor:
|
| 166 |
+
if input_resolution % output_resolution != 0:
|
| 167 |
+
raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
|
| 168 |
+
f"by output_resolution ({output_resolution}).")
|
| 169 |
+
|
| 170 |
+
factor = input_resolution // output_resolution
|
| 171 |
+
|
| 172 |
+
downsampled_voxels = voxels.clone().to(torch.long)
|
| 173 |
+
|
| 174 |
+
downsampled_voxels[:, 1:] //= factor
|
| 175 |
+
|
| 176 |
+
unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
|
| 177 |
+
return unique_downsampled_voxels
|
| 178 |
+
|
| 179 |
+
class Trainer:
|
| 180 |
+
def __init__(self, config_path, rank, world_size, local_rank):
|
| 181 |
+
self.rank = rank
|
| 182 |
+
self.world_size = world_size
|
| 183 |
+
self.local_rank = local_rank
|
| 184 |
+
self.is_master = self.rank == 0
|
| 185 |
+
|
| 186 |
+
self.load_config(config_path)
|
| 187 |
+
self.accum_steps = max(1, 4 // self.cfg['training']['batch_size'])
|
| 188 |
+
|
| 189 |
+
self.config_hash = self.save_config_with_hash()
|
| 190 |
+
self.init_device()
|
| 191 |
+
self.init_dirs()
|
| 192 |
+
self.init_components()
|
| 193 |
+
self.init_training()
|
| 194 |
+
|
| 195 |
+
self.train_loss_history = []
|
| 196 |
+
self.eval_loss_history = []
|
| 197 |
+
self.best_eval_loss = float('inf')
|
| 198 |
+
|
| 199 |
+
def save_config_with_hash(self):
|
| 200 |
+
import hashlib
|
| 201 |
+
|
| 202 |
+
# Serialize config to hash
|
| 203 |
+
config_str = yaml.dump(self.cfg)
|
| 204 |
+
config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
|
| 205 |
+
|
| 206 |
+
# Prepare all flags as string for formatting
|
| 207 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 208 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 209 |
+
|
| 210 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 211 |
+
|
| 212 |
+
# Format save_dir with all placeholders
|
| 213 |
+
self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
|
| 214 |
+
dataset_name=dataset_name,
|
| 215 |
+
config_hash=config_hash,
|
| 216 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 217 |
+
multires=self.cfg['model']['multires'],
|
| 218 |
+
add_block_embed=add_block_embed_flag,
|
| 219 |
+
using_attn=using_attn_flag,
|
| 220 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
if self.is_master:
|
| 224 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 225 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 226 |
+
with open(config_path, 'w') as f:
|
| 227 |
+
yaml.dump(self.cfg, f)
|
| 228 |
+
|
| 229 |
+
dist.barrier()
|
| 230 |
+
return config_hash
|
| 231 |
+
|
| 232 |
+
def save_checkpoint(self, epoch, avg_loss, batch_idx):
|
| 233 |
+
if not self.is_master:
|
| 234 |
+
return
|
| 235 |
+
checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
|
| 236 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 237 |
+
|
| 238 |
+
torch.save({
|
| 239 |
+
'voxel_encoder': self.voxel_encoder.module.state_dict(),
|
| 240 |
+
'vae': self.vae.module.state_dict(),
|
| 241 |
+
'connection_head': self.connection_head.module.state_dict(),
|
| 242 |
+
'epoch': epoch,
|
| 243 |
+
'loss': avg_loss,
|
| 244 |
+
'config': self.cfg
|
| 245 |
+
}, checkpoint_path)
|
| 246 |
+
|
| 247 |
+
def quoted_presenter(dumper, data):
|
| 248 |
+
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
|
| 249 |
+
|
| 250 |
+
yaml.add_representer(str, quoted_presenter)
|
| 251 |
+
|
| 252 |
+
with open(config_path, 'w') as f:
|
| 253 |
+
yaml.dump(self.cfg, f)
|
| 254 |
+
|
| 255 |
+
def load_config(self, config_path):
|
| 256 |
+
with open(config_path) as f:
|
| 257 |
+
self.cfg = yaml.safe_load(f)
|
| 258 |
+
|
| 259 |
+
# Extract and convert flags for formatting
|
| 260 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 261 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 262 |
+
|
| 263 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 264 |
+
|
| 265 |
+
self.save_dir = self.cfg['experiment']['save_dir'].format(
|
| 266 |
+
dataset_name=dataset_name,
|
| 267 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 268 |
+
multires=self.cfg['model']['multires'],
|
| 269 |
+
add_block_embed=add_block_embed_flag,
|
| 270 |
+
using_attn=using_attn_flag,
|
| 271 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
if self.is_master:
|
| 275 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 276 |
+
dist.barrier()
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def init_device(self):
|
| 280 |
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
| 281 |
+
|
| 282 |
+
def init_dirs(self):
|
| 283 |
+
self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
|
| 284 |
+
if self.is_master:
|
| 285 |
+
with open(self.log_file, "a") as f:
|
| 286 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 287 |
+
f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
|
| 288 |
+
|
| 289 |
+
def init_components(self):
|
| 290 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 291 |
+
root_dir=self.cfg['dataset']['path'],
|
| 292 |
+
base_resolution=self.cfg['dataset']['base_resolution'],
|
| 293 |
+
min_resolution=self.cfg['dataset']['min_resolution'],
|
| 294 |
+
cache_dir=self.cfg['dataset']['cache_dir'],
|
| 295 |
+
renders_dir=self.cfg['dataset']['renders_dir'],
|
| 296 |
+
|
| 297 |
+
filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
|
| 298 |
+
cache_filter_path=self.cfg['dataset']['cache_filter_path'],
|
| 299 |
+
|
| 300 |
+
active_voxel_res=128,
|
| 301 |
+
pc_sample_number=819200,
|
| 302 |
+
|
| 303 |
+
sample_type=self.cfg['dataset']['sample_type'],
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
self.sampler = DistributedSampler(
|
| 307 |
+
self.dataset,
|
| 308 |
+
num_replicas=self.world_size,
|
| 309 |
+
rank=self.rank,
|
| 310 |
+
shuffle=True,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
self.dataloader = DataLoader(
|
| 314 |
+
self.dataset,
|
| 315 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 316 |
+
shuffle=False,
|
| 317 |
+
collate_fn=partial(collate_fn_pointnet,),
|
| 318 |
+
num_workers=self.cfg['training']['num_workers'],
|
| 319 |
+
pin_memory=True,
|
| 320 |
+
sampler=self.sampler,
|
| 321 |
+
# prefetch_factor=4,
|
| 322 |
+
persistent_workers=True,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 326 |
+
in_channels=15,
|
| 327 |
+
hidden_dim=256,
|
| 328 |
+
out_channels=1024,
|
| 329 |
+
scatter_type='mean',
|
| 330 |
+
n_blocks=5,
|
| 331 |
+
resolution=128,
|
| 332 |
+
add_label=False,
|
| 333 |
+
).to(self.device)
|
| 334 |
+
|
| 335 |
+
self.connection_head = ConnectionHead(
|
| 336 |
+
channels=128 * 2,
|
| 337 |
+
out_channels=1,
|
| 338 |
+
mlp_ratio=4,
|
| 339 |
+
).to(self.device)
|
| 340 |
+
|
| 341 |
+
# ablation 3: voxelvae_1volume, have tested
|
| 342 |
+
self.vae = VoxelVAE(
|
| 343 |
+
in_channels=self.cfg['model']['in_channels'],
|
| 344 |
+
latent_dim=self.cfg['model']['latent_dim'],
|
| 345 |
+
encoder_blocks=self.cfg['model']['encoder_blocks'],
|
| 346 |
+
decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
|
| 347 |
+
decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
|
| 348 |
+
num_heads=8,
|
| 349 |
+
num_head_channels=64,
|
| 350 |
+
mlp_ratio=4.0,
|
| 351 |
+
attn_mode="swin",
|
| 352 |
+
window_size=8,
|
| 353 |
+
pe_mode="ape",
|
| 354 |
+
use_fp16=False,
|
| 355 |
+
use_checkpoint=True,
|
| 356 |
+
qk_rms_norm=False,
|
| 357 |
+
using_subdivide=True,
|
| 358 |
+
using_attn=self.cfg['model']['using_attn'],
|
| 359 |
+
attn_first=self.cfg['model'].get('attn_first', True),
|
| 360 |
+
pred_direction=self.cfg['model'].get('pred_direction', False),
|
| 361 |
+
).to(self.device)
|
| 362 |
+
|
| 363 |
+
if self.cfg['training']['from_pretrained']:
|
| 364 |
+
load_pretrained_woself(
|
| 365 |
+
checkpoint_path=self.cfg['training']['checkpoint_path'],
|
| 366 |
+
voxel_encoder=self.voxel_encoder,
|
| 367 |
+
vae=self.vae,
|
| 368 |
+
connection_head=self.connection_head,
|
| 369 |
+
optimizer=None,
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 373 |
+
self.connection_head = DDP(self.connection_head, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 374 |
+
self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 375 |
+
|
| 376 |
+
def init_training(self):
|
| 377 |
+
self.optimizer = AdamW(
|
| 378 |
+
list(self.vae.module.parameters()) +
|
| 379 |
+
list(self.voxel_encoder.module.parameters()) +
|
| 380 |
+
list(self.connection_head.module.parameters()),
|
| 381 |
+
lr=self.cfg['training']['lr'],
|
| 382 |
+
weight_decay=0.01,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
|
| 386 |
+
max_epochs = self.cfg['training']['max_epochs']
|
| 387 |
+
num_training_steps = max_epochs * num_update_steps_per_epoch
|
| 388 |
+
|
| 389 |
+
num_warmup_steps = 100
|
| 390 |
+
|
| 391 |
+
self.scheduler = get_cosine_schedule_with_warmup(
|
| 392 |
+
self.optimizer,
|
| 393 |
+
num_warmup_steps=num_warmup_steps,
|
| 394 |
+
num_training_steps=num_training_steps
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=100.0).to(self.device)
|
| 398 |
+
self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
|
| 399 |
+
self.asyloss = AsymmetricFocalLoss(
|
| 400 |
+
gamma_pos=0.0,
|
| 401 |
+
gamma_neg=4.0,
|
| 402 |
+
clip=0.05,
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
self.bce_loss = torch.nn.BCEWithLogitsLoss()
|
| 406 |
+
|
| 407 |
+
self.dice_loss = DiceLoss()
|
| 408 |
+
self.scaler = GradScaler()
|
| 409 |
+
|
| 410 |
+
def train_step(self, batch):
|
| 411 |
+
"""Modified training step that handles vertex and edge voxels separately after initial prediction."""
|
| 412 |
+
# 1. Retrieve data from batch
|
| 413 |
+
combined_voxels_256 = batch['combined_voxels_256'].to(self.device)
|
| 414 |
+
combined_voxel_labels_256 = batch['combined_voxel_labels_256'].to(self.device)
|
| 415 |
+
gt_vertex_voxels_256 = batch['gt_vertex_voxels_256'].to(self.device)
|
| 416 |
+
|
| 417 |
+
gt_edge_voxels_256 = batch['gt_edge_voxels_256'].to(self.device)
|
| 418 |
+
gt_combined_endpoints_256 = batch['gt_combined_endpoints_256'].to(self.device)
|
| 419 |
+
gt_combined_errors_256 = batch['gt_combined_errors_256'].to(self.device)
|
| 420 |
+
|
| 421 |
+
gt_edges = batch['gt_vertex_edge_indices_256'].to(self.device)
|
| 422 |
+
|
| 423 |
+
edge_mask = (combined_voxel_labels_256 == 1)
|
| 424 |
+
|
| 425 |
+
gt_edge_endpoints_256 = gt_combined_endpoints_256[edge_mask]
|
| 426 |
+
gt_edge_errors_256 = gt_combined_errors_256[edge_mask]
|
| 427 |
+
|
| 428 |
+
p1 = gt_edge_endpoints_256[:, 1:4].float()
|
| 429 |
+
p2 = gt_edge_endpoints_256[:, 4:7].float()
|
| 430 |
+
|
| 431 |
+
mask = ( (p1[:,0] < p2[:,0]) |
|
| 432 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
|
| 433 |
+
((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
|
| 434 |
+
|
| 435 |
+
pA = torch.where(mask[:, None], p1, p2) # smaller one
|
| 436 |
+
pB = torch.where(mask[:, None], p2, p1) # larger one
|
| 437 |
+
|
| 438 |
+
d = pB - pA
|
| 439 |
+
dir_gt = F.normalize(d, dim=-1, eps=1e-6)
|
| 440 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_256, input_resolution=256, output_resolution=128)
|
| 441 |
+
vtx_256 = gt_vertex_voxels_256
|
| 442 |
+
|
| 443 |
+
edge_128 = downsample_voxels(combined_voxels_256, input_resolution=256, output_resolution=128)
|
| 444 |
+
edge_256 = combined_voxels_256
|
| 445 |
+
|
| 446 |
+
active_coords = batch['active_voxels_128'].to(self.device)
|
| 447 |
+
point_cloud = batch['point_cloud_128'].to(self.device)
|
| 448 |
+
|
| 449 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 450 |
+
active_voxel_feats = self.voxel_encoder(
|
| 451 |
+
p=point_cloud,
|
| 452 |
+
sparse_coords=active_coords,
|
| 453 |
+
res=128,
|
| 454 |
+
bbox_size=(-0.5, 0.5),
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
sparse_input = SparseTensor(
|
| 458 |
+
feats=active_voxel_feats,
|
| 459 |
+
coords=active_coords.int()
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
gt_edge_voxels_list = [
|
| 463 |
+
edge_128,
|
| 464 |
+
edge_256,
|
| 465 |
+
]
|
| 466 |
+
|
| 467 |
+
gt_vertex_voxels_list = [
|
| 468 |
+
vtx_128,
|
| 469 |
+
vtx_256,
|
| 470 |
+
]
|
| 471 |
+
|
| 472 |
+
results, posterior, latent_128 = self.vae(
|
| 473 |
+
sparse_input,
|
| 474 |
+
gt_vertex_voxels_list=gt_vertex_voxels_list,
|
| 475 |
+
gt_edge_voxels_list=gt_edge_voxels_list,
|
| 476 |
+
training=True,
|
| 477 |
+
sample_ratio=0.,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
# print("results[-1]['edge']['coords_4d'][1827:1830]", results[-1]['edge']['coords_4d'][1827:1830])
|
| 481 |
+
total_loss = 0.
|
| 482 |
+
prune_loss_total = 0.
|
| 483 |
+
vertex_loss_total = 0.
|
| 484 |
+
edge_loss_total=0.
|
| 485 |
+
|
| 486 |
+
with autocast(dtype=torch.bfloat16):
|
| 487 |
+
initial_result = results[0]
|
| 488 |
+
vertex_mask = initial_result['vertex_mask']
|
| 489 |
+
vtx_logits = initial_result['vtx_feats']
|
| 490 |
+
vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
|
| 491 |
+
|
| 492 |
+
edge_mask = initial_result['edge_mask']
|
| 493 |
+
edge_logits = initial_result['edge_feats']
|
| 494 |
+
edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
|
| 495 |
+
|
| 496 |
+
vertex_loss_total += vertex_loss
|
| 497 |
+
edge_loss_total += edge_loss
|
| 498 |
+
|
| 499 |
+
total_loss += vertex_loss
|
| 500 |
+
total_loss += edge_loss
|
| 501 |
+
|
| 502 |
+
# Process each level's results
|
| 503 |
+
for idx, res_dict in enumerate(results[1:], start=1):
|
| 504 |
+
# Vertex branch losses
|
| 505 |
+
vertex_pred_coords = res_dict['vertex']['occ_coords']
|
| 506 |
+
vertex_occ_probs = res_dict['vertex']['occ_probs']
|
| 507 |
+
vertex_gt_coords = res_dict['vertex']['coords']
|
| 508 |
+
|
| 509 |
+
vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=256).float()
|
| 510 |
+
# print('vertex_labels.sum()', vertex_labels.sum(), idx)
|
| 511 |
+
vertex_logits = vertex_occ_probs.squeeze()
|
| 512 |
+
|
| 513 |
+
# if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
|
| 514 |
+
vertex_prune_loss = self.focal_loss(vertex_logits, vertex_labels)
|
| 515 |
+
# vertex_prune_loss = self.dice_loss(vertex_logits, vertex_labels)
|
| 516 |
+
|
| 517 |
+
# dilation 1: bce loss
|
| 518 |
+
# vertex_prune_loss = self.bce_loss(vertex_logits, vertex_labels,)
|
| 519 |
+
|
| 520 |
+
prune_loss_total += vertex_prune_loss
|
| 521 |
+
total_loss += vertex_prune_loss
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
# Edge branch losses
|
| 525 |
+
edge_pred_coords = res_dict['edge']['occ_coords']
|
| 526 |
+
edge_occ_probs = res_dict['edge']['occ_probs']
|
| 527 |
+
edge_gt_coords = res_dict['edge']['coords']
|
| 528 |
+
|
| 529 |
+
edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=256).float()
|
| 530 |
+
# print('edge_labels.sum()', edge_labels.sum(), idx)
|
| 531 |
+
edge_logits = edge_occ_probs.squeeze()
|
| 532 |
+
# if edge_labels.sum() > 0 and edge_labels.sum() < len(edge_labels):
|
| 533 |
+
edge_prune_loss = self.focal_loss(edge_logits, edge_labels)
|
| 534 |
+
|
| 535 |
+
# dilation 1: bce loss
|
| 536 |
+
# edge_prune_loss = self.bce_loss(edge_logits, edge_labels,)
|
| 537 |
+
|
| 538 |
+
prune_loss_total += edge_prune_loss
|
| 539 |
+
total_loss += edge_prune_loss
|
| 540 |
+
|
| 541 |
+
if idx == 1:
|
| 542 |
+
pred_coords = res_dict['edge']['coords_4d'] # [N,4] (b,x,y,z)
|
| 543 |
+
pred_feats = res_dict['edge']['predicted_offset_feats'] # [N,C]
|
| 544 |
+
|
| 545 |
+
gt_coords = gt_edge_voxels_256.to(pred_coords.device) # [M,4]
|
| 546 |
+
gt_feats = gt_edge_errors_256[:, 1:].to(pred_coords.device) # [M,C]
|
| 547 |
+
|
| 548 |
+
pred_keys = flatten_coords_4d(pred_coords)
|
| 549 |
+
gt_keys = flatten_coords_4d(gt_coords)
|
| 550 |
+
|
| 551 |
+
sorted_pred_keys, pred_order = torch.sort(pred_keys)
|
| 552 |
+
pred_coords_sorted = pred_coords[pred_order]
|
| 553 |
+
pred_feats_sorted = pred_feats[pred_order]
|
| 554 |
+
|
| 555 |
+
sorted_gt_keys, gt_order = torch.sort(gt_keys)
|
| 556 |
+
gt_coords_sorted = gt_coords[gt_order]
|
| 557 |
+
gt_feats_sorted = gt_feats[gt_order]
|
| 558 |
+
|
| 559 |
+
pos = torch.searchsorted(sorted_gt_keys, sorted_pred_keys)
|
| 560 |
+
valid_mask = (pos < len(sorted_gt_keys)) & (sorted_gt_keys[pos] == sorted_pred_keys)
|
| 561 |
+
|
| 562 |
+
if valid_mask.any():
|
| 563 |
+
# print('valid_mask.sum()', valid_mask.sum())
|
| 564 |
+
matched_pred_feats = pred_feats_sorted[valid_mask]
|
| 565 |
+
matched_gt_feats = gt_feats_sorted[pos[valid_mask]]
|
| 566 |
+
mse_loss_feats = self.mse_loss(matched_pred_feats, matched_gt_feats * 2)
|
| 567 |
+
total_loss += mse_loss_feats * 0.
|
| 568 |
+
|
| 569 |
+
if self.cfg['model'].get('pred_direction', False):
|
| 570 |
+
pred_dirs = res_dict['edge']['predicted_direction_feats']
|
| 571 |
+
dir_gt_device = dir_gt.to(pred_coords.device)
|
| 572 |
+
|
| 573 |
+
pred_dirs_sorted = pred_dirs[pred_order]
|
| 574 |
+
dir_gt_sorted = dir_gt_device[gt_order]
|
| 575 |
+
|
| 576 |
+
matched_pred_dirs = pred_dirs_sorted[valid_mask]
|
| 577 |
+
matched_gt_dirs = dir_gt_sorted[pos[valid_mask]]
|
| 578 |
+
|
| 579 |
+
mse_loss_dirs = self.mse_loss(matched_pred_dirs, matched_gt_dirs)
|
| 580 |
+
total_loss += mse_loss_dirs * 0.
|
| 581 |
+
else:
|
| 582 |
+
mse_loss_feats = torch.tensor(0., device=pred_coords.device)
|
| 583 |
+
if self.cfg['model'].get('pred_direction', False):
|
| 584 |
+
mse_loss_dirs = torch.tensor(0., device=pred_coords.device)
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# --- Vertex Branch (Connection Loss 核心) ---
|
| 588 |
+
vtx_pred_coords = res_dict['vertex']['coords_4d'] # [N, 4]
|
| 589 |
+
vtx_pred_feats = res_dict['vertex']['feats'] # [N, C]
|
| 590 |
+
|
| 591 |
+
# 1.1 排序 (既用于匹配 GT,也用于快速寻找空间邻居)
|
| 592 |
+
vtx_pred_keys = flatten_coords_4d(vtx_pred_coords)
|
| 593 |
+
vtx_pred_keys_sorted, vtx_pred_order = torch.sort(vtx_pred_keys)
|
| 594 |
+
|
| 595 |
+
# 1.2 匹配 GT
|
| 596 |
+
vtx_gt_keys = flatten_coords_4d(gt_vertex_voxels_256.to(self.device))
|
| 597 |
+
vtx_pos = torch.searchsorted(vtx_pred_keys_sorted, vtx_gt_keys)
|
| 598 |
+
vtx_pos = vtx_pos.clamp(max=len(vtx_pred_keys_sorted) - 1)
|
| 599 |
+
vtx_match_mask = (vtx_pred_keys_sorted[vtx_pos] == vtx_gt_keys)
|
| 600 |
+
|
| 601 |
+
gt_to_pred_mapping = torch.full((len(vtx_gt_keys),), -1, device=self.device, dtype=torch.long)
|
| 602 |
+
matched_pred_indices = vtx_pred_order[vtx_pos[vtx_match_mask]]
|
| 603 |
+
gt_to_pred_mapping[vtx_match_mask] = matched_pred_indices
|
| 604 |
+
|
| 605 |
+
# ====================================================
|
| 606 |
+
# 2. 构建核心数据:正样本 Hash 集合
|
| 607 |
+
# ====================================================
|
| 608 |
+
# 这里的 pos_u/pos_v 仅用于构建 "什么是真连接" 的查询表
|
| 609 |
+
u_gt, v_gt = gt_edges[:, 0], gt_edges[:, 1]
|
| 610 |
+
u_pred = gt_to_pred_mapping[u_gt]
|
| 611 |
+
v_pred = gt_to_pred_mapping[v_gt]
|
| 612 |
+
|
| 613 |
+
valid_edge_mask = (u_pred != -1) & (v_pred != -1)
|
| 614 |
+
real_pos_u = u_pred[valid_edge_mask]
|
| 615 |
+
real_pos_v = v_pred[valid_edge_mask]
|
| 616 |
+
|
| 617 |
+
num_real_pos = real_pos_u.shape[0]
|
| 618 |
+
num_total_nodes = vtx_pred_coords.shape[0]
|
| 619 |
+
|
| 620 |
+
if num_real_pos > 0:
|
| 621 |
+
# 2. 构建候选样本 (Candidate Generation)
|
| 622 |
+
# ====================================================
|
| 623 |
+
cand_u_list = []
|
| 624 |
+
cand_v_list = []
|
| 625 |
+
|
| 626 |
+
batch_ids = vtx_pred_coords[:, 0]
|
| 627 |
+
unique_batches = torch.unique(batch_ids)
|
| 628 |
+
|
| 629 |
+
RADIUS = 16
|
| 630 |
+
MAX_PTS_FOR_DIST = 12000
|
| 631 |
+
K_RANDOM = 32
|
| 632 |
+
|
| 633 |
+
for b_id in unique_batches:
|
| 634 |
+
mask_b = (batch_ids == b_id)
|
| 635 |
+
indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
|
| 636 |
+
coords_b = vtx_pred_coords[mask_b, 1:4].float() # (x,y,z)
|
| 637 |
+
num_b = coords_b.shape[0]
|
| 638 |
+
|
| 639 |
+
if num_b < 2: continue
|
| 640 |
+
|
| 641 |
+
# --- A. Radius Graph (Hard Negatives) ---
|
| 642 |
+
if num_b <= MAX_PTS_FOR_DIST:
|
| 643 |
+
# 计算距离矩阵 [M, M]
|
| 644 |
+
# 注意:autocast 下 float16 的 cdist 可能精度不够,建议转 float32
|
| 645 |
+
dist_mat = torch.cdist(coords_b.float(), coords_b.float())
|
| 646 |
+
|
| 647 |
+
# 找到距离小于 Radius 的点对 (排除自环)
|
| 648 |
+
adj_mat = (dist_mat < RADIUS) & (dist_mat > 1e-6)
|
| 649 |
+
|
| 650 |
+
# 提取索引 (local indices in batch)
|
| 651 |
+
src_local, dst_local = torch.nonzero(adj_mat, as_tuple=True)
|
| 652 |
+
|
| 653 |
+
# 映射回全局索引
|
| 654 |
+
cand_u_list.append(indices_b[src_local])
|
| 655 |
+
cand_v_list.append(indices_b[dst_local])
|
| 656 |
+
else:
|
| 657 |
+
# 如果点太多,显存不够,退化为随机局部采样或跳过
|
| 658 |
+
# 这里简单处理:跳过 Radius Graph,依赖 Random
|
| 659 |
+
pass
|
| 660 |
+
|
| 661 |
+
# --- B. Random Sampling (Easy Negatives) ---
|
| 662 |
+
# 随机生成 num_b * K 对
|
| 663 |
+
n_rand = num_b * K_RANDOM
|
| 664 |
+
rand_src_local = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 665 |
+
rand_dst_local = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 666 |
+
|
| 667 |
+
# 映射回全局索引
|
| 668 |
+
cand_u_list.append(indices_b[rand_src_local])
|
| 669 |
+
cand_v_list.append(indices_b[rand_dst_local])
|
| 670 |
+
|
| 671 |
+
# 合并所有来源 (GT + Radius + Random)
|
| 672 |
+
# 注意:我们把 real_pos 也加进来,确保正样本一定在列表里
|
| 673 |
+
all_u = torch.cat([real_pos_u] + cand_u_list)
|
| 674 |
+
all_v = torch.cat([real_pos_v] + cand_v_list)
|
| 675 |
+
|
| 676 |
+
# 3. 去重与 Labeling (Deduplication & Labeling)
|
| 677 |
+
# ====================================================
|
| 678 |
+
# 构造无向边 Hash: min * N + max
|
| 679 |
+
# 确保 MAX_NODES 足够大,比如 1000000 或 num_total_nodes
|
| 680 |
+
HASH_BASE = num_total_nodes + 100
|
| 681 |
+
|
| 682 |
+
p_min = torch.min(all_u, all_v)
|
| 683 |
+
p_max = torch.max(all_u, all_v)
|
| 684 |
+
|
| 685 |
+
# 过滤掉自环 (u==v)
|
| 686 |
+
valid_pair = (p_min != p_max)
|
| 687 |
+
p_min = p_min[valid_pair]
|
| 688 |
+
p_max = p_max[valid_pair]
|
| 689 |
+
|
| 690 |
+
all_hashes = p_min.long() * HASH_BASE + p_max.long()
|
| 691 |
+
|
| 692 |
+
# --- 核心:去重 ---
|
| 693 |
+
unique_hashes = torch.unique(all_hashes)
|
| 694 |
+
|
| 695 |
+
# 解码回 u, v
|
| 696 |
+
final_u = unique_hashes // HASH_BASE
|
| 697 |
+
final_v = unique_hashes % HASH_BASE
|
| 698 |
+
|
| 699 |
+
# --- Labeling ---
|
| 700 |
+
# 构建 GT 的 Hash 表用于查询
|
| 701 |
+
gt_min = torch.min(real_pos_u, real_pos_v)
|
| 702 |
+
gt_max = torch.max(real_pos_u, real_pos_v)
|
| 703 |
+
gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
|
| 704 |
+
gt_hashes = torch.unique(gt_hashes) # GT 也去重一下保险
|
| 705 |
+
gt_hashes_sorted, _ = torch.sort(gt_hashes)
|
| 706 |
+
|
| 707 |
+
# 查询 unique_hashes 是否在 gt_hashes 中
|
| 708 |
+
# 使用 searchsorted
|
| 709 |
+
idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
|
| 710 |
+
idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
|
| 711 |
+
is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
|
| 712 |
+
|
| 713 |
+
targets = is_connected.float().unsqueeze(-1) # [N_pairs, 1]
|
| 714 |
+
|
| 715 |
+
# 4. 前向传播与 Loss
|
| 716 |
+
# ====================================================
|
| 717 |
+
feat_u = vtx_pred_feats[final_u]
|
| 718 |
+
feat_v = vtx_pred_feats[final_v]
|
| 719 |
+
|
| 720 |
+
# 对称特征融合
|
| 721 |
+
feat_uv = torch.cat([feat_u, feat_v], dim=-1)
|
| 722 |
+
feat_vu = torch.cat([feat_v, feat_u], dim=-1)
|
| 723 |
+
|
| 724 |
+
logits_uv = self.connection_head(feat_uv)
|
| 725 |
+
logits_vu = self.connection_head(feat_vu)
|
| 726 |
+
logits = logits_uv + logits_vu
|
| 727 |
+
|
| 728 |
+
# print('targets.sum()', targets.sum())
|
| 729 |
+
# print('targets.shape', targets.shape)
|
| 730 |
+
|
| 731 |
+
# export_sampled_edges(
|
| 732 |
+
# coords=vtx_pred_coords, # [N, 4]
|
| 733 |
+
# u=final_u, # [E]
|
| 734 |
+
# v=final_v, # [E]
|
| 735 |
+
# labels=targets, # [E, 1]
|
| 736 |
+
# step_idx=0,
|
| 737 |
+
# )
|
| 738 |
+
# exit()
|
| 739 |
+
|
| 740 |
+
# Focal Loss
|
| 741 |
+
connection_loss = self.asyloss(logits, targets)
|
| 742 |
+
total_loss += connection_loss
|
| 743 |
+
|
| 744 |
+
else:
|
| 745 |
+
connection_loss = torch.tensor(0., device=self.device)
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
# KL loss
|
| 750 |
+
kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
|
| 751 |
+
total_loss += kl_loss
|
| 752 |
+
|
| 753 |
+
# Backpropagation
|
| 754 |
+
scaled_total_loss = total_loss / self.accum_steps
|
| 755 |
+
self.scaler.scale(scaled_total_loss).backward()
|
| 756 |
+
|
| 757 |
+
return {
|
| 758 |
+
'total_loss': total_loss.item(),
|
| 759 |
+
'kl_loss': kl_loss.item(),
|
| 760 |
+
'prune_loss': prune_loss_total.item(),
|
| 761 |
+
'vertex_loss': vertex_loss_total.item(),
|
| 762 |
+
'edge_loss': edge_loss_total.item(),
|
| 763 |
+
'offset_loss': mse_loss_feats.item(),
|
| 764 |
+
'direction_loss': mse_loss_dirs.item(),
|
| 765 |
+
'connection_loss': connection_loss.item(),
|
| 766 |
+
}
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
def train(self):
|
| 770 |
+
accum_steps = self.accum_steps
|
| 771 |
+
for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
|
| 772 |
+
self.dataloader.sampler.set_epoch(epoch)
|
| 773 |
+
# Initialize metrics
|
| 774 |
+
metrics = {
|
| 775 |
+
'total_loss': 0.0,
|
| 776 |
+
'kl_loss': 0.0,
|
| 777 |
+
'prune_loss': 0.0,
|
| 778 |
+
'vertex_loss': 0.0,
|
| 779 |
+
'edge_loss': 0.0,
|
| 780 |
+
'offset_loss': 0.0,
|
| 781 |
+
'direction_loss': 0.0,
|
| 782 |
+
'connection_loss': 0.0,
|
| 783 |
+
}
|
| 784 |
+
num_batches = 0
|
| 785 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 786 |
+
|
| 787 |
+
for i, batch in enumerate(self.dataloader):
|
| 788 |
+
# Get all losses from train_step
|
| 789 |
+
if batch is None:
|
| 790 |
+
continue
|
| 791 |
+
step_losses = self.train_step(batch)
|
| 792 |
+
|
| 793 |
+
# Accumulate losses
|
| 794 |
+
for key in metrics:
|
| 795 |
+
metrics[key] += step_losses[key]
|
| 796 |
+
num_batches += 1
|
| 797 |
+
|
| 798 |
+
if (i + 1) % accum_steps == 0:
|
| 799 |
+
self.scaler.unscale_(self.optimizer)
|
| 800 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
|
| 801 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
|
| 802 |
+
torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=1.0)
|
| 803 |
+
|
| 804 |
+
self.scaler.step(self.optimizer)
|
| 805 |
+
self.scaler.update()
|
| 806 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 807 |
+
|
| 808 |
+
self.scheduler.step()
|
| 809 |
+
|
| 810 |
+
# Print batch-level metrics
|
| 811 |
+
if self.is_master:
|
| 812 |
+
print(
|
| 813 |
+
f"[Epoch {epoch}] Batch:{num_batches} "
|
| 814 |
+
f"Loss: {step_losses['total_loss']:.4f}, "
|
| 815 |
+
f"KLL: {step_losses['kl_loss']:.4f}, "
|
| 816 |
+
f"PruneL: {step_losses['prune_loss']:.4f}, "
|
| 817 |
+
f"VertexL: {step_losses['vertex_loss']:.4f}, "
|
| 818 |
+
f"EdgeL: {step_losses['edge_loss']:.4f}, "
|
| 819 |
+
f"OffsetL: {step_losses['offset_loss']:.4f}, "
|
| 820 |
+
f"DireL: {step_losses['direction_loss']:.4f}, "
|
| 821 |
+
f"ConL: {step_losses['connection_loss']:.4f}, "
|
| 822 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
# if i % 2000 == 0 and i != 0:
|
| 826 |
+
# self.save_checkpoint(epoch, step_losses['total_loss'], i)
|
| 827 |
+
|
| 828 |
+
if num_batches % accum_steps != 0:
|
| 829 |
+
self.scaler.unscale_(self.optimizer)
|
| 830 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
|
| 831 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
|
| 832 |
+
torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=1.0)
|
| 833 |
+
|
| 834 |
+
self.scaler.step(self.optimizer)
|
| 835 |
+
self.scaler.update()
|
| 836 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 837 |
+
|
| 838 |
+
self.scheduler.step()
|
| 839 |
+
|
| 840 |
+
# Calculate epoch averages
|
| 841 |
+
avg_metrics = {key: value / num_batches for key, value in metrics.items()}
|
| 842 |
+
self.train_loss_history.append(avg_metrics['total_loss'])
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
# Log to file
|
| 846 |
+
if self.is_master:
|
| 847 |
+
with open(self.log_file, "a") as f:
|
| 848 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 849 |
+
log_line = (
|
| 850 |
+
f"Epoch {epoch:05d} | "
|
| 851 |
+
f"Loss: {avg_metrics['total_loss']:.6f} "
|
| 852 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 853 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 854 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 855 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 856 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 857 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 858 |
+
f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
|
| 859 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 860 |
+
f"[{current_time}]\n"
|
| 861 |
+
)
|
| 862 |
+
f.write(log_line)
|
| 863 |
+
|
| 864 |
+
# Print epoch summary
|
| 865 |
+
print(
|
| 866 |
+
f"[Epoch {epoch}] "
|
| 867 |
+
f"Avg Loss: {avg_metrics['total_loss']:.4f} "
|
| 868 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 869 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 870 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 871 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 872 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 873 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 874 |
+
f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
|
| 875 |
+
f"[{current_time}]\n"
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
# Save checkpoint
|
| 879 |
+
if epoch % self.cfg['training']['save_every'] == 0:
|
| 880 |
+
self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
|
| 881 |
+
|
| 882 |
+
# Update learning rate
|
| 883 |
+
if self.is_master:
|
| 884 |
+
current_lr = self.optimizer.param_groups[0]['lr']
|
| 885 |
+
print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
|
| 886 |
+
|
| 887 |
+
dist.barrier()
|
| 888 |
+
|
| 889 |
+
def main():
|
| 890 |
+
# Initialize the process group
|
| 891 |
+
dist.init_process_group(backend='nccl')
|
| 892 |
+
|
| 893 |
+
# Get rank and world size from environment variables set by the launcher
|
| 894 |
+
rank = int(os.environ['RANK'])
|
| 895 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 896 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 897 |
+
|
| 898 |
+
# Set the device for the current process. This is crucial.
|
| 899 |
+
torch.cuda.set_device(local_rank)
|
| 900 |
+
torch.manual_seed(42)
|
| 901 |
+
|
| 902 |
+
# with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 903 |
+
# Pass the distributed info to the Trainer
|
| 904 |
+
trainer = Trainer(
|
| 905 |
+
config_path="/root/Trisf/config_edge_1024_error_8enc_8dec_woself_finetune_128to256.yaml",
|
| 906 |
+
rank=rank,
|
| 907 |
+
world_size=world_size,
|
| 908 |
+
local_rank=local_rank
|
| 909 |
+
)
|
| 910 |
+
trainer.train()
|
| 911 |
+
|
| 912 |
+
# Clean up the process group
|
| 913 |
+
dist.destroy_process_group()
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
if __name__ == '__main__':
|
| 917 |
+
main()
|
train_slat_vae_512_128to512_pointnet_head.py
ADDED
|
@@ -0,0 +1,1090 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import os
|
| 3 |
+
# os.environ['ATTN_BACKEND'] = 'xformers'
|
| 4 |
+
import yaml
|
| 5 |
+
import torch
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from functools import partial
|
| 10 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.optim import AdamW
|
| 13 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 14 |
+
|
| 15 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 16 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 17 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 18 |
+
|
| 19 |
+
from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss
|
| 20 |
+
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 23 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 24 |
+
|
| 25 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 26 |
+
import math
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import open3d as o3d
|
| 30 |
+
|
| 31 |
+
def export_sampled_edges(coords, u, v, labels, edge_voxels=None, step_idx=0, save_dir="debug_viz"):
|
| 32 |
+
"""
|
| 33 |
+
导出顶点、采样的边以及背景边缘体素用于可视化 (PLY格式)
|
| 34 |
+
"""
|
| 35 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 36 |
+
|
| 37 |
+
# 转为 CPU numpy
|
| 38 |
+
coords_np = coords.detach().cpu().numpy() # [N, 4] (b, x, y, z)
|
| 39 |
+
u_np = u.detach().cpu().numpy()
|
| 40 |
+
v_np = v.detach().cpu().numpy()
|
| 41 |
+
labels_np = labels.detach().cpu().numpy().flatten()
|
| 42 |
+
|
| 43 |
+
edge_voxels_np = None
|
| 44 |
+
if edge_voxels is not None:
|
| 45 |
+
edge_voxels_np = edge_voxels.detach().cpu().numpy() # [M, 4]
|
| 46 |
+
|
| 47 |
+
# 按 Batch 处理 (因为 visualization 最好单个物体看)
|
| 48 |
+
batch_ids = coords_np[:, 0]
|
| 49 |
+
unique_batches = np.unique(batch_ids)
|
| 50 |
+
|
| 51 |
+
for b_id in unique_batches:
|
| 52 |
+
# 1. 筛选当前 Batch 的顶点
|
| 53 |
+
mask_b = (batch_ids == b_id)
|
| 54 |
+
# 获取局部坐标 (x,y,z)
|
| 55 |
+
curr_verts = coords_np[mask_b, 1:4]
|
| 56 |
+
|
| 57 |
+
# 建立 全局索引 -> 局部索引 的映射
|
| 58 |
+
# u 和 v 是基于整个 batch 的全局索引,需要转换
|
| 59 |
+
global_indices = np.where(mask_b)[0]
|
| 60 |
+
# 创建一个大的映射数组 (假设索引范围不超过 max_len)
|
| 61 |
+
max_idx = np.max(global_indices) + 1
|
| 62 |
+
global_to_local = np.full(max_idx, -1)
|
| 63 |
+
global_to_local[global_indices] = np.arange(len(global_indices))
|
| 64 |
+
|
| 65 |
+
# 2. 筛选当前 Batch 相关的边
|
| 66 |
+
# 边的两个端点都必须属于当前 batch
|
| 67 |
+
# 只要检查 u 是否在当前 batch 范围内即可 (假设跨 batch 不连线)
|
| 68 |
+
batch_edge_mask = np.isin(u_np, global_indices)
|
| 69 |
+
|
| 70 |
+
curr_u_global = u_np[batch_edge_mask]
|
| 71 |
+
curr_v_global = v_np[batch_edge_mask]
|
| 72 |
+
curr_labels = labels_np[batch_edge_mask]
|
| 73 |
+
|
| 74 |
+
# 转为局部索引 (用于构建 mesh)
|
| 75 |
+
curr_u = global_to_local[curr_u_global]
|
| 76 |
+
curr_v = global_to_local[curr_v_global]
|
| 77 |
+
|
| 78 |
+
# 3. 筛选当前 Batch 的 Edge Voxels
|
| 79 |
+
curr_edge_voxels = None
|
| 80 |
+
if edge_voxels_np is not None:
|
| 81 |
+
mask_ev = (edge_voxels_np[:, 0] == b_id)
|
| 82 |
+
curr_edge_voxels = edge_voxels_np[mask_ev, 1:4]
|
| 83 |
+
|
| 84 |
+
# 4. 写入 PLY 文件
|
| 85 |
+
filename = os.path.join(save_dir, f"step_{step_idx}_batch_{int(b_id)}.ply")
|
| 86 |
+
write_ply_with_edges_and_voxels(
|
| 87 |
+
filename,
|
| 88 |
+
curr_verts,
|
| 89 |
+
curr_u,
|
| 90 |
+
curr_v,
|
| 91 |
+
curr_labels,
|
| 92 |
+
curr_edge_voxels
|
| 93 |
+
)
|
| 94 |
+
print(f"Saved visualization to {filename}")
|
| 95 |
+
|
| 96 |
+
def write_ply_with_edges_and_voxels(filename, verts, u, v, labels, edge_voxels=None):
|
| 97 |
+
"""
|
| 98 |
+
手动写入 PLY 文件,包含顶点、边和背景体素点。
|
| 99 |
+
为了区分,我们将它们合并到一个文件中,但使用颜色区分。
|
| 100 |
+
"""
|
| 101 |
+
# ------------------------------------------------
|
| 102 |
+
# 准备数据
|
| 103 |
+
# ------------------------------------------------
|
| 104 |
+
|
| 105 |
+
# 1. 节点 (Vertices) -> 设为 青色
|
| 106 |
+
num_verts = len(verts)
|
| 107 |
+
vertex_colors = np.tile([0, 255, 255], (num_verts, 1))
|
| 108 |
+
|
| 109 |
+
# 2. 边缘体素 (Edge Voxels) -> 设为 灰色 (作为背景)
|
| 110 |
+
if edge_voxels is not None:
|
| 111 |
+
num_ev = len(edge_voxels)
|
| 112 |
+
# 为了放在同一个 vertex buffer,我们把 edge voxels 追加到 verts 后面
|
| 113 |
+
all_points = np.vstack([verts, edge_voxels])
|
| 114 |
+
ev_colors = np.tile([180, 180, 180], (num_ev, 1))
|
| 115 |
+
all_colors = np.vstack([vertex_colors, ev_colors])
|
| 116 |
+
else:
|
| 117 |
+
all_points = verts
|
| 118 |
+
all_colors = vertex_colors
|
| 119 |
+
num_ev = 0
|
| 120 |
+
|
| 121 |
+
num_total_points = len(all_points)
|
| 122 |
+
|
| 123 |
+
# 3. 边 (Edges)
|
| 124 |
+
# PLY 格式的 edge 索引是基于当前点列表的
|
| 125 |
+
# u, v 已经是基于 verts 的局部索引了,不需要偏移 (因为 verts 在最前面)
|
| 126 |
+
|
| 127 |
+
# 筛选:我们通常只想看 Positive 的边 (labels==1),或者用颜色区分
|
| 128 |
+
# 这里全部写入,用颜色区分
|
| 129 |
+
# 红色 = 负样本 (预测连接但实际上没连/负采样)
|
| 130 |
+
# 绿色 = 正样本 (GT连接)
|
| 131 |
+
|
| 132 |
+
edge_list = np.stack([u, v], axis=1)
|
| 133 |
+
num_edges = len(edge_list)
|
| 134 |
+
|
| 135 |
+
edge_colors = np.zeros((num_edges, 3), dtype=int)
|
| 136 |
+
edge_colors[labels == 1] = [0, 255, 0] # Green for Positive
|
| 137 |
+
edge_colors[labels == 0] = [255, 0, 0] # Red for Negative
|
| 138 |
+
|
| 139 |
+
# ------------------------------------------------
|
| 140 |
+
# 写入文件 Header
|
| 141 |
+
# ------------------------------------------------
|
| 142 |
+
with open(filename, 'w') as f:
|
| 143 |
+
f.write("ply\n")
|
| 144 |
+
f.write("format ascii 1.0\n")
|
| 145 |
+
|
| 146 |
+
# Vertex Element (包含 Nodes 和 EdgeVoxels)
|
| 147 |
+
f.write(f"element vertex {num_total_points}\n")
|
| 148 |
+
f.write("property float x\n")
|
| 149 |
+
f.write("property float y\n")
|
| 150 |
+
f.write("property float z\n")
|
| 151 |
+
f.write("property uchar red\n")
|
| 152 |
+
f.write("property uchar green\n")
|
| 153 |
+
f.write("property uchar blue\n")
|
| 154 |
+
|
| 155 |
+
# Edge Element (包含采样的连接)
|
| 156 |
+
f.write(f"element edge {num_edges}\n")
|
| 157 |
+
f.write("property list uchar int vertex_indices\n")
|
| 158 |
+
f.write("property uchar red\n")
|
| 159 |
+
f.write("property uchar green\n")
|
| 160 |
+
f.write("property uchar blue\n")
|
| 161 |
+
|
| 162 |
+
f.write("end_header\n")
|
| 163 |
+
|
| 164 |
+
# ------------------------------------------------
|
| 165 |
+
# 写入数据 Body
|
| 166 |
+
# ------------------------------------------------
|
| 167 |
+
|
| 168 |
+
# 1. Write Points (Vertices + Edge Voxels)
|
| 169 |
+
for i in range(num_total_points):
|
| 170 |
+
p = all_points[i]
|
| 171 |
+
c = all_colors[i]
|
| 172 |
+
f.write(f"{p[0]:.4f} {p[1]:.4f} {p[2]:.4f} {int(c[0])} {int(c[1])} {int(c[2])}\n")
|
| 173 |
+
|
| 174 |
+
# 2. Write Edges
|
| 175 |
+
for i in range(num_edges):
|
| 176 |
+
e = edge_list[i]
|
| 177 |
+
c = edge_colors[i]
|
| 178 |
+
f.write(f"2 {int(e[0])} {int(e[1])} {int(c[0])} {int(c[1])} {int(c[2])}\n")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def flatten_coords_4d(coords_4d: torch.Tensor):
|
| 182 |
+
coords_4d_long = coords_4d.long()
|
| 183 |
+
|
| 184 |
+
base_x = 512
|
| 185 |
+
base_y = 512 * 512
|
| 186 |
+
base_z = 512 * 512 * 512
|
| 187 |
+
|
| 188 |
+
flat_coords = coords_4d_long[:, 0] * base_z + \
|
| 189 |
+
coords_4d_long[:, 1] * base_y + \
|
| 190 |
+
coords_4d_long[:, 2] * base_x + \
|
| 191 |
+
coords_4d_long[:, 3]
|
| 192 |
+
return flat_coords
|
| 193 |
+
|
| 194 |
+
def downsample_voxels(
|
| 195 |
+
voxels: torch.Tensor,
|
| 196 |
+
input_resolution: int,
|
| 197 |
+
output_resolution: int
|
| 198 |
+
) -> torch.Tensor:
|
| 199 |
+
if input_resolution % output_resolution != 0:
|
| 200 |
+
raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
|
| 201 |
+
f"by output_resolution ({output_resolution}).")
|
| 202 |
+
|
| 203 |
+
factor = input_resolution // output_resolution
|
| 204 |
+
|
| 205 |
+
downsampled_voxels = voxels.clone().to(torch.long)
|
| 206 |
+
|
| 207 |
+
downsampled_voxels[:, 1:] //= factor
|
| 208 |
+
|
| 209 |
+
unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
|
| 210 |
+
return unique_downsampled_voxels
|
| 211 |
+
|
| 212 |
+
class Trainer:
|
| 213 |
+
def __init__(self, config_path, rank, world_size, local_rank):
|
| 214 |
+
self.rank = rank
|
| 215 |
+
self.world_size = world_size
|
| 216 |
+
self.local_rank = local_rank
|
| 217 |
+
self.is_master = self.rank == 0
|
| 218 |
+
|
| 219 |
+
self.load_config(config_path)
|
| 220 |
+
self.accum_steps = max(1, 4 // self.cfg['training']['batch_size'])
|
| 221 |
+
|
| 222 |
+
self.config_hash = self.save_config_with_hash()
|
| 223 |
+
self.init_device()
|
| 224 |
+
self.init_dirs()
|
| 225 |
+
self.init_components()
|
| 226 |
+
self.init_training()
|
| 227 |
+
|
| 228 |
+
self.train_loss_history = []
|
| 229 |
+
self.eval_loss_history = []
|
| 230 |
+
self.best_eval_loss = float('inf')
|
| 231 |
+
|
| 232 |
+
def save_config_with_hash(self):
|
| 233 |
+
import hashlib
|
| 234 |
+
|
| 235 |
+
# Serialize config to hash
|
| 236 |
+
config_str = yaml.dump(self.cfg)
|
| 237 |
+
config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
|
| 238 |
+
|
| 239 |
+
# Prepare all flags as string for formatting
|
| 240 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 241 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 242 |
+
|
| 243 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 244 |
+
|
| 245 |
+
# Format save_dir with all placeholders
|
| 246 |
+
self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
|
| 247 |
+
dataset_name=dataset_name,
|
| 248 |
+
config_hash=config_hash,
|
| 249 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 250 |
+
multires=self.cfg['model']['multires'],
|
| 251 |
+
add_block_embed=add_block_embed_flag,
|
| 252 |
+
using_attn=using_attn_flag,
|
| 253 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
if self.is_master:
|
| 257 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 258 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 259 |
+
with open(config_path, 'w') as f:
|
| 260 |
+
yaml.dump(self.cfg, f)
|
| 261 |
+
|
| 262 |
+
dist.barrier()
|
| 263 |
+
return config_hash
|
| 264 |
+
|
| 265 |
+
def save_checkpoint(self, epoch, avg_loss, batch_idx):
|
| 266 |
+
if not self.is_master:
|
| 267 |
+
return
|
| 268 |
+
checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
|
| 269 |
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
| 270 |
+
|
| 271 |
+
torch.save({
|
| 272 |
+
'voxel_encoder': self.voxel_encoder.module.state_dict(),
|
| 273 |
+
'vae': self.vae.module.state_dict(),
|
| 274 |
+
'connection_head': self.connection_head.module.state_dict(),
|
| 275 |
+
'epoch': epoch,
|
| 276 |
+
'loss': avg_loss,
|
| 277 |
+
'config': self.cfg
|
| 278 |
+
}, checkpoint_path)
|
| 279 |
+
|
| 280 |
+
def quoted_presenter(dumper, data):
|
| 281 |
+
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
|
| 282 |
+
|
| 283 |
+
yaml.add_representer(str, quoted_presenter)
|
| 284 |
+
|
| 285 |
+
with open(config_path, 'w') as f:
|
| 286 |
+
yaml.dump(self.cfg, f)
|
| 287 |
+
|
| 288 |
+
def load_config(self, config_path):
|
| 289 |
+
with open(config_path) as f:
|
| 290 |
+
self.cfg = yaml.safe_load(f)
|
| 291 |
+
|
| 292 |
+
# Extract and convert flags for formatting
|
| 293 |
+
add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
|
| 294 |
+
using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
|
| 295 |
+
|
| 296 |
+
dataset_name = os.path.basename(self.cfg['dataset']['path'])
|
| 297 |
+
|
| 298 |
+
self.save_dir = self.cfg['experiment']['save_dir'].format(
|
| 299 |
+
dataset_name=dataset_name,
|
| 300 |
+
n_train_samples=self.cfg['dataset']['n_train_samples'],
|
| 301 |
+
multires=self.cfg['model']['multires'],
|
| 302 |
+
add_block_embed=add_block_embed_flag,
|
| 303 |
+
using_attn=using_attn_flag,
|
| 304 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
if self.is_master:
|
| 308 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 309 |
+
dist.barrier()
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def init_device(self):
|
| 313 |
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
| 314 |
+
|
| 315 |
+
def init_dirs(self):
|
| 316 |
+
self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
|
| 317 |
+
if self.is_master:
|
| 318 |
+
with open(self.log_file, "a") as f:
|
| 319 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 320 |
+
f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
|
| 321 |
+
|
| 322 |
+
def init_components(self):
|
| 323 |
+
self.dataset = VoxelVertexDataset_edge(
|
| 324 |
+
root_dir=self.cfg['dataset']['path'],
|
| 325 |
+
base_resolution=self.cfg['dataset']['base_resolution'],
|
| 326 |
+
min_resolution=self.cfg['dataset']['min_resolution'],
|
| 327 |
+
cache_dir=self.cfg['dataset']['cache_dir'],
|
| 328 |
+
renders_dir=self.cfg['dataset']['renders_dir'],
|
| 329 |
+
|
| 330 |
+
filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
|
| 331 |
+
cache_filter_path=self.cfg['dataset']['cache_filter_path'],
|
| 332 |
+
|
| 333 |
+
active_voxel_res=128,
|
| 334 |
+
pc_sample_number=819200,
|
| 335 |
+
|
| 336 |
+
sample_type=self.cfg['dataset']['sample_type'],
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
self.sampler = DistributedSampler(
|
| 340 |
+
self.dataset,
|
| 341 |
+
num_replicas=self.world_size,
|
| 342 |
+
rank=self.rank,
|
| 343 |
+
shuffle=True,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
self.dataloader = DataLoader(
|
| 347 |
+
self.dataset,
|
| 348 |
+
batch_size=self.cfg['training']['batch_size'],
|
| 349 |
+
shuffle=False,
|
| 350 |
+
collate_fn=partial(collate_fn_pointnet,),
|
| 351 |
+
num_workers=self.cfg['training']['num_workers'],
|
| 352 |
+
pin_memory=True,
|
| 353 |
+
sampler=self.sampler,
|
| 354 |
+
# prefetch_factor=4,
|
| 355 |
+
persistent_workers=True,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
|
| 359 |
+
in_channels=15,
|
| 360 |
+
hidden_dim=256,
|
| 361 |
+
out_channels=1024,
|
| 362 |
+
scatter_type='mean',
|
| 363 |
+
n_blocks=5,
|
| 364 |
+
resolution=128,
|
| 365 |
+
add_label=False,
|
| 366 |
+
).to(self.device)
|
| 367 |
+
|
| 368 |
+
self.connection_head = ConnectionHead(
|
| 369 |
+
channels=128 * 2,
|
| 370 |
+
out_channels=1,
|
| 371 |
+
mlp_ratio=4,
|
| 372 |
+
).to(self.device)
|
| 373 |
+
|
| 374 |
+
# ablation 3: voxelvae_1volume, have tested
|
| 375 |
+
self.vae = VoxelVAE(
|
| 376 |
+
in_channels=self.cfg['model']['in_channels'],
|
| 377 |
+
latent_dim=self.cfg['model']['latent_dim'],
|
| 378 |
+
encoder_blocks=self.cfg['model']['encoder_blocks'],
|
| 379 |
+
decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
|
| 380 |
+
decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
|
| 381 |
+
num_heads=8,
|
| 382 |
+
num_head_channels=64,
|
| 383 |
+
mlp_ratio=4.0,
|
| 384 |
+
attn_mode="swin",
|
| 385 |
+
window_size=8,
|
| 386 |
+
pe_mode="ape",
|
| 387 |
+
use_fp16=False,
|
| 388 |
+
use_checkpoint=True,
|
| 389 |
+
qk_rms_norm=False,
|
| 390 |
+
using_subdivide=True,
|
| 391 |
+
using_attn=self.cfg['model']['using_attn'],
|
| 392 |
+
attn_first=self.cfg['model'].get('attn_first', True),
|
| 393 |
+
pred_direction=self.cfg['model'].get('pred_direction', False),
|
| 394 |
+
).to(self.device)
|
| 395 |
+
|
| 396 |
+
if self.cfg['training']['from_pretrained']:
|
| 397 |
+
load_pretrained_woself(
|
| 398 |
+
checkpoint_path=self.cfg['training']['checkpoint_path'],
|
| 399 |
+
voxel_encoder=self.voxel_encoder,
|
| 400 |
+
vae=self.vae,
|
| 401 |
+
connection_head=self.connection_head,
|
| 402 |
+
optimizer=None,
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 406 |
+
self.connection_head = DDP(self.connection_head, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 407 |
+
self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=False)
|
| 408 |
+
|
| 409 |
+
def init_training(self):
|
| 410 |
+
self.optimizer = AdamW(
|
| 411 |
+
list(self.vae.module.parameters()) +
|
| 412 |
+
list(self.voxel_encoder.module.parameters()) +
|
| 413 |
+
list(self.connection_head.module.parameters()),
|
| 414 |
+
lr=self.cfg['training']['lr'],
|
| 415 |
+
weight_decay=0.01,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
|
| 419 |
+
max_epochs = self.cfg['training']['max_epochs']
|
| 420 |
+
num_training_steps = max_epochs * num_update_steps_per_epoch
|
| 421 |
+
|
| 422 |
+
num_warmup_steps = 200
|
| 423 |
+
|
| 424 |
+
self.scheduler = get_cosine_schedule_with_warmup(
|
| 425 |
+
self.optimizer,
|
| 426 |
+
num_warmup_steps=num_warmup_steps,
|
| 427 |
+
num_training_steps=num_training_steps
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=10.0).to(self.device)
|
| 431 |
+
self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
|
| 432 |
+
self.asyloss = AsymmetricFocalLoss(
|
| 433 |
+
gamma_pos=0.0,
|
| 434 |
+
gamma_neg=4.0,
|
| 435 |
+
clip=0.05,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
self.bce_loss = torch.nn.BCEWithLogitsLoss()
|
| 439 |
+
|
| 440 |
+
self.dice_loss = DiceLoss()
|
| 441 |
+
self.scaler = GradScaler()
|
| 442 |
+
|
| 443 |
+
def train_step(self, batch):
|
| 444 |
+
"""Modified training step that handles vertex and edge voxels separately after initial prediction."""
|
| 445 |
+
# 1. Retrieve data from batch
|
| 446 |
+
combined_voxels_512 = batch['combined_voxels_512'].to(self.device)
|
| 447 |
+
combined_voxel_labels_512 = batch['combined_voxel_labels_512'].to(self.device)
|
| 448 |
+
gt_vertex_voxels_512 = batch['gt_vertex_voxels_512'].to(self.device)
|
| 449 |
+
|
| 450 |
+
gt_edges = batch['gt_vertex_edge_indices_512'].to(self.device)
|
| 451 |
+
|
| 452 |
+
edge_mask = (combined_voxel_labels_512 == 1)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
vtx_128 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=128)
|
| 456 |
+
vtx_256 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=256)
|
| 457 |
+
vtx_512 = gt_vertex_voxels_512
|
| 458 |
+
|
| 459 |
+
edge_128 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=128)
|
| 460 |
+
edge_256 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=256)
|
| 461 |
+
edge_512 = combined_voxels_512
|
| 462 |
+
|
| 463 |
+
active_coords = batch['active_voxels_128'].to(self.device)
|
| 464 |
+
point_cloud = batch['point_cloud_128'].to(self.device)
|
| 465 |
+
|
| 466 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 467 |
+
active_voxel_feats = self.voxel_encoder(
|
| 468 |
+
p=point_cloud,
|
| 469 |
+
sparse_coords=active_coords,
|
| 470 |
+
res=128,
|
| 471 |
+
bbox_size=(-0.5, 0.5),
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
sparse_input = SparseTensor(
|
| 475 |
+
feats=active_voxel_feats,
|
| 476 |
+
coords=active_coords.int()
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
gt_edge_voxels_list = [
|
| 480 |
+
edge_128,
|
| 481 |
+
edge_256,
|
| 482 |
+
edge_512,
|
| 483 |
+
]
|
| 484 |
+
|
| 485 |
+
gt_vertex_voxels_list = [
|
| 486 |
+
vtx_128,
|
| 487 |
+
vtx_256,
|
| 488 |
+
vtx_512,
|
| 489 |
+
]
|
| 490 |
+
|
| 491 |
+
results, posterior, latent_128 = self.vae(
|
| 492 |
+
sparse_input,
|
| 493 |
+
gt_vertex_voxels_list=gt_vertex_voxels_list,
|
| 494 |
+
gt_edge_voxels_list=gt_edge_voxels_list,
|
| 495 |
+
training=True,
|
| 496 |
+
sample_ratio=0.,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# print("results[-1]['edge']['coords_4d'][1827:1830]", results[-1]['edge']['coords_4d'][1827:1830])
|
| 500 |
+
total_loss = 0.
|
| 501 |
+
prune_loss_total = 0.
|
| 502 |
+
vertex_loss_total = 0.
|
| 503 |
+
edge_loss_total=0.
|
| 504 |
+
|
| 505 |
+
with autocast(dtype=torch.bfloat16):
|
| 506 |
+
initial_result = results[0]
|
| 507 |
+
vertex_mask = initial_result['vertex_mask']
|
| 508 |
+
vtx_logits = initial_result['vtx_feats']
|
| 509 |
+
vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
|
| 510 |
+
|
| 511 |
+
edge_mask = initial_result['edge_mask']
|
| 512 |
+
edge_logits = initial_result['edge_feats']
|
| 513 |
+
edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
|
| 514 |
+
|
| 515 |
+
vertex_loss_total += vertex_loss
|
| 516 |
+
edge_loss_total += edge_loss
|
| 517 |
+
|
| 518 |
+
total_loss += vertex_loss
|
| 519 |
+
total_loss += edge_loss
|
| 520 |
+
|
| 521 |
+
# Process each level's results
|
| 522 |
+
for idx, res_dict in enumerate(results[1:], start=1):
|
| 523 |
+
# Vertex branch losses
|
| 524 |
+
vertex_pred_coords = res_dict['vertex']['occ_coords']
|
| 525 |
+
vertex_occ_probs = res_dict['vertex']['occ_probs']
|
| 526 |
+
vertex_gt_coords = res_dict['vertex']['coords']
|
| 527 |
+
|
| 528 |
+
vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=512).float()
|
| 529 |
+
# print('vertex_labels.sum()', vertex_labels.sum(), idx)
|
| 530 |
+
vertex_logits = vertex_occ_probs.squeeze()
|
| 531 |
+
|
| 532 |
+
# if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
|
| 533 |
+
vertex_prune_loss = self.focal_loss(vertex_logits, vertex_labels)
|
| 534 |
+
# vertex_prune_loss = self.dice_loss(vertex_logits, vertex_labels)
|
| 535 |
+
|
| 536 |
+
# dilation 1: bce loss
|
| 537 |
+
# vertex_prune_loss = self.bce_loss(vertex_logits, vertex_labels,)
|
| 538 |
+
|
| 539 |
+
prune_loss_total += vertex_prune_loss
|
| 540 |
+
total_loss += vertex_prune_loss
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
# Edge branch losses
|
| 544 |
+
edge_pred_coords = res_dict['edge']['occ_coords']
|
| 545 |
+
edge_occ_probs = res_dict['edge']['occ_probs']
|
| 546 |
+
edge_gt_coords = res_dict['edge']['coords']
|
| 547 |
+
|
| 548 |
+
edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=512).float()
|
| 549 |
+
# print('edge_labels.sum()', edge_labels.sum(), idx)
|
| 550 |
+
edge_logits = edge_occ_probs.squeeze()
|
| 551 |
+
# if edge_labels.sum() > 0 and edge_labels.sum() < len(edge_labels):
|
| 552 |
+
edge_prune_loss = self.focal_loss(edge_logits, edge_labels)
|
| 553 |
+
|
| 554 |
+
# dilation 1: bce loss
|
| 555 |
+
# edge_prune_loss = self.bce_loss(edge_logits, edge_labels,)
|
| 556 |
+
|
| 557 |
+
prune_loss_total += edge_prune_loss
|
| 558 |
+
total_loss += edge_prune_loss
|
| 559 |
+
|
| 560 |
+
if idx == 2:
|
| 561 |
+
# pred_coords = res_dict['edge']['coords_4d'] # [N,4] (b,x,y,z)
|
| 562 |
+
# pred_feats = res_dict['edge']['predicted_offset_feats'] # [N,C]
|
| 563 |
+
|
| 564 |
+
# gt_coords = gt_edge_voxels_512.to(pred_coords.device) # [M,4]
|
| 565 |
+
# gt_feats = gt_edge_errors_512[:, 1:].to(pred_coords.device) # [M,C]
|
| 566 |
+
|
| 567 |
+
# pred_keys = flatten_coords_4d(pred_coords)
|
| 568 |
+
# gt_keys = flatten_coords_4d(gt_coords)
|
| 569 |
+
|
| 570 |
+
# sorted_pred_keys, pred_order = torch.sort(pred_keys)
|
| 571 |
+
# pred_coords_sorted = pred_coords[pred_order]
|
| 572 |
+
# pred_feats_sorted = pred_feats[pred_order]
|
| 573 |
+
|
| 574 |
+
# sorted_gt_keys, gt_order = torch.sort(gt_keys)
|
| 575 |
+
# gt_coords_sorted = gt_coords[gt_order]
|
| 576 |
+
# gt_feats_sorted = gt_feats[gt_order]
|
| 577 |
+
|
| 578 |
+
# pos = torch.searchsorted(sorted_gt_keys, sorted_pred_keys)
|
| 579 |
+
# valid_mask = (pos < len(sorted_gt_keys)) & (sorted_gt_keys[pos] == sorted_pred_keys)
|
| 580 |
+
|
| 581 |
+
# if valid_mask.any():
|
| 582 |
+
# # print('valid_mask.sum()', valid_mask.sum())
|
| 583 |
+
# matched_pred_feats = pred_feats_sorted[valid_mask]
|
| 584 |
+
# matched_gt_feats = gt_feats_sorted[pos[valid_mask]]
|
| 585 |
+
# mse_loss_feats = self.mse_loss(matched_pred_feats, matched_gt_feats * 2)
|
| 586 |
+
# total_loss += mse_loss_feats * 0.
|
| 587 |
+
|
| 588 |
+
# if self.cfg['model'].get('pred_direction', False):
|
| 589 |
+
# pred_dirs = res_dict['edge']['predicted_direction_feats']
|
| 590 |
+
# dir_gt_device = dir_gt.to(pred_coords.device)
|
| 591 |
+
|
| 592 |
+
# pred_dirs_sorted = pred_dirs[pred_order]
|
| 593 |
+
# dir_gt_sorted = dir_gt_device[gt_order]
|
| 594 |
+
|
| 595 |
+
# matched_pred_dirs = pred_dirs_sorted[valid_mask]
|
| 596 |
+
# matched_gt_dirs = dir_gt_sorted[pos[valid_mask]]
|
| 597 |
+
|
| 598 |
+
# mse_loss_dirs = self.mse_loss(matched_pred_dirs, matched_gt_dirs)
|
| 599 |
+
# total_loss += mse_loss_dirs * 0.
|
| 600 |
+
# else:
|
| 601 |
+
# mse_loss_feats = torch.tensor(0., device=pred_coords.device)
|
| 602 |
+
# if self.cfg['model'].get('pred_direction', False):
|
| 603 |
+
# mse_loss_dirs = torch.tensor(0., device=pred_coords.device)
|
| 604 |
+
|
| 605 |
+
mse_loss_dirs = torch.tensor(0., device=self.device)
|
| 606 |
+
mse_loss_feats = torch.tensor(0., device=self.device)
|
| 607 |
+
|
| 608 |
+
# --- Vertex Branch (Connection Loss 核心) ---
|
| 609 |
+
vtx_pred_coords = res_dict['vertex']['coords_4d'] # [N, 4]
|
| 610 |
+
vtx_pred_feats = res_dict['vertex']['feats'] # [N, C]
|
| 611 |
+
|
| 612 |
+
# 1.1 排序 (既用于��配 GT,也用于快速寻找空间邻居)
|
| 613 |
+
vtx_pred_keys = flatten_coords_4d(vtx_pred_coords)
|
| 614 |
+
vtx_pred_keys_sorted, vtx_pred_order = torch.sort(vtx_pred_keys)
|
| 615 |
+
|
| 616 |
+
# 1.2 匹配 GT
|
| 617 |
+
vtx_gt_keys = flatten_coords_4d(gt_vertex_voxels_512.to(self.device))
|
| 618 |
+
vtx_pos = torch.searchsorted(vtx_pred_keys_sorted, vtx_gt_keys)
|
| 619 |
+
vtx_pos = vtx_pos.clamp(max=len(vtx_pred_keys_sorted) - 1)
|
| 620 |
+
vtx_match_mask = (vtx_pred_keys_sorted[vtx_pos] == vtx_gt_keys)
|
| 621 |
+
|
| 622 |
+
gt_to_pred_mapping = torch.full((len(vtx_gt_keys),), -1, device=self.device, dtype=torch.long)
|
| 623 |
+
matched_pred_indices = vtx_pred_order[vtx_pos[vtx_match_mask]]
|
| 624 |
+
gt_to_pred_mapping[vtx_match_mask] = matched_pred_indices
|
| 625 |
+
|
| 626 |
+
# ====================================================
|
| 627 |
+
# 2. 构建核心数据:正样本 Hash 集合
|
| 628 |
+
# ====================================================
|
| 629 |
+
# 这里的 pos_u/pos_v 仅用于构建 "什么是真连接" 的查询表
|
| 630 |
+
u_gt, v_gt = gt_edges[:, 0], gt_edges[:, 1]
|
| 631 |
+
u_pred = gt_to_pred_mapping[u_gt]
|
| 632 |
+
v_pred = gt_to_pred_mapping[v_gt]
|
| 633 |
+
|
| 634 |
+
valid_edge_mask = (u_pred != -1) & (v_pred != -1)
|
| 635 |
+
real_pos_u = u_pred[valid_edge_mask]
|
| 636 |
+
real_pos_v = v_pred[valid_edge_mask]
|
| 637 |
+
|
| 638 |
+
num_real_pos = real_pos_u.shape[0]
|
| 639 |
+
num_total_nodes = vtx_pred_coords.shape[0]
|
| 640 |
+
|
| 641 |
+
# if num_real_pos > 0:
|
| 642 |
+
# # 2. 构建候选样本 (Candidate Generation)
|
| 643 |
+
# # ====================================================
|
| 644 |
+
# cand_u_list = []
|
| 645 |
+
# cand_v_list = []
|
| 646 |
+
|
| 647 |
+
# batch_ids = vtx_pred_coords[:, 0]
|
| 648 |
+
# unique_batches = torch.unique(batch_ids)
|
| 649 |
+
|
| 650 |
+
# RADIUS = 32
|
| 651 |
+
# MAX_PTS_FOR_DIST = 12000
|
| 652 |
+
# K_RANDOM = 32
|
| 653 |
+
|
| 654 |
+
# for b_id in unique_batches:
|
| 655 |
+
# mask_b = (batch_ids == b_id)
|
| 656 |
+
# indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
|
| 657 |
+
# coords_b = vtx_pred_coords[mask_b, 1:4].float() # (x,y,z)
|
| 658 |
+
# num_b = coords_b.shape[0]
|
| 659 |
+
|
| 660 |
+
# if num_b < 2: continue
|
| 661 |
+
|
| 662 |
+
# # --- A. Radius Graph (Hard Negatives) ---
|
| 663 |
+
# if num_b <= MAX_PTS_FOR_DIST:
|
| 664 |
+
# # 计算距离矩阵 [M, M]
|
| 665 |
+
# # 注意:autocast 下 float16 的 cdist 可能精度不够,建议转 float32
|
| 666 |
+
# dist_mat = torch.cdist(coords_b.float(), coords_b.float())
|
| 667 |
+
|
| 668 |
+
# # 找到距离小于 Radius 的点对 (排除自环)
|
| 669 |
+
# adj_mat = (dist_mat < RADIUS) & (dist_mat > 1e-6)
|
| 670 |
+
|
| 671 |
+
# # 提取索引 (local indices in batch)
|
| 672 |
+
# src_local, dst_local = torch.nonzero(adj_mat, as_tuple=True)
|
| 673 |
+
|
| 674 |
+
# # 映射回全局索引
|
| 675 |
+
# cand_u_list.append(indices_b[src_local])
|
| 676 |
+
# cand_v_list.append(indices_b[dst_local])
|
| 677 |
+
# else:
|
| 678 |
+
# # 如果点太多,显存不够,退化为随机局部采样或跳过
|
| 679 |
+
# # 这里简单处理:跳过 Radius Graph,依赖 Random
|
| 680 |
+
# pass
|
| 681 |
+
|
| 682 |
+
# # --- B. Random Sampling (Easy Negatives) ---
|
| 683 |
+
# # 随机生成 num_b * K 对
|
| 684 |
+
# n_rand = num_b * K_RANDOM
|
| 685 |
+
# rand_src_local = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 686 |
+
# rand_dst_local = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 687 |
+
|
| 688 |
+
# # 映射回全局索引
|
| 689 |
+
# cand_u_list.append(indices_b[rand_src_local])
|
| 690 |
+
# cand_v_list.append(indices_b[rand_dst_local])
|
| 691 |
+
|
| 692 |
+
# # 合并所有来源 (GT + Radius + Random)
|
| 693 |
+
# # 注意:我们把 real_pos 也加进来,确保正样本一定在列表里
|
| 694 |
+
# all_u = torch.cat([real_pos_u] + cand_u_list)
|
| 695 |
+
# all_v = torch.cat([real_pos_v] + cand_v_list)
|
| 696 |
+
|
| 697 |
+
# # 3. 去重与 Labeling (Deduplication & Labeling)
|
| 698 |
+
# # ====================================================
|
| 699 |
+
# # 构造无向边 Hash: min * N + max
|
| 700 |
+
# # 确保 MAX_NODES 足够大,比如 1000000 或 num_total_nodes
|
| 701 |
+
# HASH_BASE = num_total_nodes + 100
|
| 702 |
+
|
| 703 |
+
# p_min = torch.min(all_u, all_v)
|
| 704 |
+
# p_max = torch.max(all_u, all_v)
|
| 705 |
+
|
| 706 |
+
# # 过滤掉自环 (u==v)
|
| 707 |
+
# valid_pair = (p_min != p_max)
|
| 708 |
+
# p_min = p_min[valid_pair]
|
| 709 |
+
# p_max = p_max[valid_pair]
|
| 710 |
+
|
| 711 |
+
# all_hashes = p_min.long() * HASH_BASE + p_max.long()
|
| 712 |
+
|
| 713 |
+
# # --- 核心:去重 ---
|
| 714 |
+
# unique_hashes = torch.unique(all_hashes)
|
| 715 |
+
|
| 716 |
+
# # 解码回 u, v
|
| 717 |
+
# final_u = unique_hashes // HASH_BASE
|
| 718 |
+
# final_v = unique_hashes % HASH_BASE
|
| 719 |
+
|
| 720 |
+
# # --- Labeling ---
|
| 721 |
+
# # 构建 GT 的 Hash 表用于查询
|
| 722 |
+
# gt_min = torch.min(real_pos_u, real_pos_v)
|
| 723 |
+
# gt_max = torch.max(real_pos_u, real_pos_v)
|
| 724 |
+
# gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
|
| 725 |
+
# gt_hashes = torch.unique(gt_hashes) # GT 也去重一下保险
|
| 726 |
+
# gt_hashes_sorted, _ = torch.sort(gt_hashes)
|
| 727 |
+
|
| 728 |
+
# # 查询 unique_hashes 是否在 gt_hashes 中
|
| 729 |
+
# # 使用 searchsorted
|
| 730 |
+
# idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
|
| 731 |
+
# idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
|
| 732 |
+
# is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
|
| 733 |
+
|
| 734 |
+
# targets = is_connected.float().unsqueeze(-1) # [N_pairs, 1]
|
| 735 |
+
|
| 736 |
+
# # 4. 前向传播与 Loss
|
| 737 |
+
# # ====================================================
|
| 738 |
+
# feat_u = vtx_pred_feats[final_u]
|
| 739 |
+
# feat_v = vtx_pred_feats[final_v]
|
| 740 |
+
|
| 741 |
+
# # 对称特征融合
|
| 742 |
+
# feat_uv = torch.cat([feat_u, feat_v], dim=-1)
|
| 743 |
+
# feat_vu = torch.cat([feat_v, feat_u], dim=-1)
|
| 744 |
+
|
| 745 |
+
# logits_uv = self.connection_head(feat_uv)
|
| 746 |
+
# logits_vu = self.connection_head(feat_vu)
|
| 747 |
+
# logits = logits_uv + logits_vu
|
| 748 |
+
|
| 749 |
+
# # print('targets.sum()', targets.sum())
|
| 750 |
+
# # print('targets.shape', targets.shape)
|
| 751 |
+
|
| 752 |
+
# # viz_edge_voxels = combined_voxels_512
|
| 753 |
+
|
| 754 |
+
# # export_sampled_edges(
|
| 755 |
+
# # coords=vtx_pred_coords, # [N, 4] (Batch, X, Y, Z) - 顶点
|
| 756 |
+
# # u=final_u, # [E] - 连线起点索引
|
| 757 |
+
# # v=final_v, # [E] - 连线终点索引
|
| 758 |
+
# # labels=targets, # [E, 1] - 连线标签 (1=Pos, 0=Neg)
|
| 759 |
+
# # edge_voxels=viz_edge_voxels, # [M, 4] - 新增:512分辨率边缘体素
|
| 760 |
+
# # step_idx=0, # 或者传入当前的 step/epoch
|
| 761 |
+
# # save_dir="debug_viz" # 建议指定保存路径
|
| 762 |
+
# # )
|
| 763 |
+
# # exit()
|
| 764 |
+
|
| 765 |
+
# # asyloss Loss
|
| 766 |
+
# connection_loss = self.asyloss(logits, targets)
|
| 767 |
+
# total_loss += connection_loss
|
| 768 |
+
|
| 769 |
+
# else:
|
| 770 |
+
# connection_loss = torch.tensor(0., device=self.device)
|
| 771 |
+
|
| 772 |
+
if num_real_pos > 0:
|
| 773 |
+
# ====================================================
|
| 774 |
+
# 2. 构建候选样本 (Candidate Generation) - KNN + Random 版
|
| 775 |
+
# ====================================================
|
| 776 |
+
cand_u_list = []
|
| 777 |
+
cand_v_list = []
|
| 778 |
+
|
| 779 |
+
batch_ids = vtx_pred_coords[:, 0]
|
| 780 |
+
unique_batches = torch.unique(batch_ids)
|
| 781 |
+
|
| 782 |
+
# === 配置 ===
|
| 783 |
+
K_KNN = 64
|
| 784 |
+
K_RANDOM = 32
|
| 785 |
+
# 这样每个点大约产生 32 条边,总显存 = N * 32 * MLP大小,非常稳定
|
| 786 |
+
|
| 787 |
+
# 安全阈值:如果点数太多,cdist矩阵本身会爆,需要限制
|
| 788 |
+
MAX_PTS_FOR_KNN = 12000
|
| 789 |
+
|
| 790 |
+
for b_id in unique_batches:
|
| 791 |
+
mask_b = (batch_ids == b_id)
|
| 792 |
+
indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
|
| 793 |
+
coords_b = vtx_pred_coords[mask_b, 1:4].float()
|
| 794 |
+
num_b = coords_b.shape[0]
|
| 795 |
+
|
| 796 |
+
if num_b < 2: continue
|
| 797 |
+
|
| 798 |
+
# --- A. KNN Graph (距离最近的 K 个) ---
|
| 799 |
+
# 只有当点数在可接受范围内时才算矩阵,否则只用随机
|
| 800 |
+
if num_b <= MAX_PTS_FOR_KNN:
|
| 801 |
+
# 1. 计算距离矩阵 [N, N]
|
| 802 |
+
# 注意: 12000个点产生的矩阵约 576MB,非常安全
|
| 803 |
+
dist_mat = torch.cdist(coords_b, coords_b)
|
| 804 |
+
|
| 805 |
+
# 2. 取最近的 K+1 个 (包含自己)
|
| 806 |
+
# largest=False 表示取最小距离
|
| 807 |
+
k_val = min(K_KNN + 1, num_b)
|
| 808 |
+
_, knn_indices_local = torch.topk(dist_mat, k=k_val, dim=1, largest=False)
|
| 809 |
+
|
| 810 |
+
# 3. 去掉自己 (第0个通常是距离为0的自己)
|
| 811 |
+
knn_indices_local = knn_indices_local[:, 1:]
|
| 812 |
+
|
| 813 |
+
# 4. 构建边索引
|
| 814 |
+
# repeat_interleave 生成源点: [0,0,0, 1,1,1...]
|
| 815 |
+
src_local = torch.arange(num_b, device=self.device).repeat_interleave(knn_indices_local.shape[1])
|
| 816 |
+
dst_local = knn_indices_local.flatten()
|
| 817 |
+
|
| 818 |
+
cand_u_list.append(indices_b[src_local])
|
| 819 |
+
cand_v_list.append(indices_b[dst_local])
|
| 820 |
+
|
| 821 |
+
# --- B. Random Sampling (随机 K 个) ---
|
| 822 |
+
# 不管点多点少,都可以做随机
|
| 823 |
+
n_rand = num_b * K_RANDOM
|
| 824 |
+
if n_rand > 0:
|
| 825 |
+
rand_src = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 826 |
+
rand_dst = torch.randint(0, num_b, (n_rand,), device=self.device)
|
| 827 |
+
|
| 828 |
+
cand_u_list.append(indices_b[rand_src])
|
| 829 |
+
cand_v_list.append(indices_b[rand_dst])
|
| 830 |
+
|
| 831 |
+
if len(cand_u_list) > 0:
|
| 832 |
+
all_u = torch.cat([real_pos_u] + cand_u_list)
|
| 833 |
+
all_v = torch.cat([real_pos_v] + cand_v_list)
|
| 834 |
+
else:
|
| 835 |
+
all_u = real_pos_u
|
| 836 |
+
all_v = real_pos_v
|
| 837 |
+
|
| 838 |
+
# ====================================================
|
| 839 |
+
# 3. 去重与 Labeling (Logic 保持不变)
|
| 840 |
+
# ====================================================
|
| 841 |
+
HASH_BASE = num_total_nodes + 100
|
| 842 |
+
|
| 843 |
+
p_min = torch.min(all_u, all_v)
|
| 844 |
+
p_max = torch.max(all_u, all_v)
|
| 845 |
+
valid_pair = (p_min != p_max)
|
| 846 |
+
p_min = p_min[valid_pair]
|
| 847 |
+
p_max = p_max[valid_pair]
|
| 848 |
+
|
| 849 |
+
all_hashes = p_min.long() * HASH_BASE + p_max.long()
|
| 850 |
+
|
| 851 |
+
# 去重:因为 KNN 和 Random 可能会重复,或者和 GT 重复
|
| 852 |
+
unique_hashes = torch.unique(all_hashes)
|
| 853 |
+
|
| 854 |
+
# 【注意】这里不需要再做 Max Limit 截断了
|
| 855 |
+
# 因为边数严格受控于 (N * K),不会出现数量级爆炸的情况。
|
| 856 |
+
|
| 857 |
+
# 解码回 u, v
|
| 858 |
+
final_u = unique_hashes // HASH_BASE
|
| 859 |
+
final_v = unique_hashes % HASH_BASE
|
| 860 |
+
|
| 861 |
+
# --- Labeling ---
|
| 862 |
+
gt_min = torch.min(real_pos_u, real_pos_v)
|
| 863 |
+
gt_max = torch.max(real_pos_u, real_pos_v)
|
| 864 |
+
gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
|
| 865 |
+
gt_hashes = torch.unique(gt_hashes)
|
| 866 |
+
gt_hashes_sorted, _ = torch.sort(gt_hashes)
|
| 867 |
+
|
| 868 |
+
idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
|
| 869 |
+
idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
|
| 870 |
+
is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
|
| 871 |
+
|
| 872 |
+
targets = is_connected.float().unsqueeze(-1)
|
| 873 |
+
|
| 874 |
+
feat_u = vtx_pred_feats[final_u]
|
| 875 |
+
feat_v = vtx_pred_feats[final_v]
|
| 876 |
+
|
| 877 |
+
feat_uv = torch.cat([feat_u, feat_v], dim=-1)
|
| 878 |
+
feat_vu = torch.cat([feat_v, feat_u], dim=-1)
|
| 879 |
+
|
| 880 |
+
logits_uv = self.connection_head(feat_uv)
|
| 881 |
+
logits_vu = self.connection_head(feat_vu)
|
| 882 |
+
logits = logits_uv + logits_vu
|
| 883 |
+
|
| 884 |
+
# viz_edge_voxels = combined_voxels_512
|
| 885 |
+
# export_sampled_edges(
|
| 886 |
+
# coords=vtx_pred_coords, # [N, 4] (Batch, X, Y, Z) - 顶点
|
| 887 |
+
# u=final_u, # [E] - 连线起点索引
|
| 888 |
+
# v=final_v, # [E] - 连线终点索引
|
| 889 |
+
# labels=targets, # [E, 1] - 连线标签 (1=Pos, 0=Neg)
|
| 890 |
+
# edge_voxels=viz_edge_voxels, # [M, 4] - 新增:512分辨率边缘体素
|
| 891 |
+
# step_idx=0, # 或者传入当前的 step/epoch
|
| 892 |
+
# save_dir="debug_viz" # 建议指定保存路径
|
| 893 |
+
# )
|
| 894 |
+
# exit()
|
| 895 |
+
|
| 896 |
+
connection_loss = self.asyloss(logits, targets)
|
| 897 |
+
total_loss += connection_loss
|
| 898 |
+
|
| 899 |
+
else:
|
| 900 |
+
connection_loss = torch.tensor(0., device=self.device)
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
# KL loss
|
| 904 |
+
kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
|
| 905 |
+
total_loss += kl_loss
|
| 906 |
+
|
| 907 |
+
# Backpropagation
|
| 908 |
+
scaled_total_loss = total_loss / self.accum_steps
|
| 909 |
+
self.scaler.scale(scaled_total_loss).backward()
|
| 910 |
+
|
| 911 |
+
return {
|
| 912 |
+
'total_loss': total_loss.item(),
|
| 913 |
+
'kl_loss': kl_loss.item(),
|
| 914 |
+
'prune_loss': prune_loss_total.item(),
|
| 915 |
+
'vertex_loss': vertex_loss_total.item(),
|
| 916 |
+
'edge_loss': edge_loss_total.item(),
|
| 917 |
+
'offset_loss': mse_loss_feats.item(),
|
| 918 |
+
'direction_loss': mse_loss_dirs.item(),
|
| 919 |
+
'connection_loss': connection_loss.item(),
|
| 920 |
+
}
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
def train(self):
|
| 924 |
+
accum_steps = self.accum_steps
|
| 925 |
+
for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
|
| 926 |
+
self.dataloader.sampler.set_epoch(epoch)
|
| 927 |
+
# Initialize metrics
|
| 928 |
+
metrics = {
|
| 929 |
+
'total_loss': 0.0,
|
| 930 |
+
'kl_loss': 0.0,
|
| 931 |
+
'prune_loss': 0.0,
|
| 932 |
+
'vertex_loss': 0.0,
|
| 933 |
+
'edge_loss': 0.0,
|
| 934 |
+
'offset_loss': 0.0,
|
| 935 |
+
'direction_loss': 0.0,
|
| 936 |
+
'connection_loss': 0.0,
|
| 937 |
+
}
|
| 938 |
+
num_batches = 0
|
| 939 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 940 |
+
|
| 941 |
+
for i, batch in enumerate(self.dataloader):
|
| 942 |
+
# Get all losses from train_step
|
| 943 |
+
if batch is None:
|
| 944 |
+
continue
|
| 945 |
+
step_losses = self.train_step(batch)
|
| 946 |
+
|
| 947 |
+
# Accumulate losses
|
| 948 |
+
for key in metrics:
|
| 949 |
+
metrics[key] += step_losses[key]
|
| 950 |
+
num_batches += 1
|
| 951 |
+
|
| 952 |
+
if (i + 1) % accum_steps == 0:
|
| 953 |
+
self.scaler.unscale_(self.optimizer)
|
| 954 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
|
| 955 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
|
| 956 |
+
torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=1.0)
|
| 957 |
+
|
| 958 |
+
self.scaler.step(self.optimizer)
|
| 959 |
+
self.scaler.update()
|
| 960 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 961 |
+
|
| 962 |
+
self.scheduler.step()
|
| 963 |
+
|
| 964 |
+
# Print batch-level metrics
|
| 965 |
+
if self.is_master:
|
| 966 |
+
avg_metric = {key: value / num_batches for key, value in metrics.items()}
|
| 967 |
+
print(
|
| 968 |
+
f"[Epoch {epoch}] Batch:{num_batches} "
|
| 969 |
+
f"AvgL:{avg_metric['total_loss']:.4f} "
|
| 970 |
+
f"Loss: {step_losses['total_loss']:.4f}, "
|
| 971 |
+
f"KLL: {step_losses['kl_loss']:.4f}, "
|
| 972 |
+
f"PruneL: {step_losses['prune_loss']:.4f}, "
|
| 973 |
+
f"VertexL: {step_losses['vertex_loss']:.4f}, "
|
| 974 |
+
f"EdgeL: {step_losses['edge_loss']:.4f}, "
|
| 975 |
+
f"OffsetL: {step_losses['offset_loss']:.4f}, "
|
| 976 |
+
f"DireL: {step_losses['direction_loss']:.4f}, "
|
| 977 |
+
f"ConL: {step_losses['connection_loss']:.4f}, "
|
| 978 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 979 |
+
)
|
| 980 |
+
|
| 981 |
+
if i % 2000 == 0 and i != 0:
|
| 982 |
+
self.save_checkpoint(epoch, avg_metric['total_loss'], i)
|
| 983 |
+
with open(self.log_file, "a") as f:
|
| 984 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 985 |
+
log_line = (
|
| 986 |
+
f"Epoch {epoch:05d} | "
|
| 987 |
+
f"Batch {i:05d} | "
|
| 988 |
+
f"Loss: {avg_metric['total_loss']:.6f} "
|
| 989 |
+
f"Avg KLL: {avg_metric['kl_loss']:.4f} "
|
| 990 |
+
f"Avg PruneL: {avg_metric['prune_loss']:.4f} "
|
| 991 |
+
f"Avg VertexL: {avg_metric['vertex_loss']:.4f} "
|
| 992 |
+
f"Avg EdgeL: {avg_metric['edge_loss']:.4f} "
|
| 993 |
+
f"Avg OffsetL: {avg_metric['offset_loss']:.4f} "
|
| 994 |
+
f"Avg DireL: {avg_metric['direction_loss']:.4f} "
|
| 995 |
+
f"Avg ConL: {avg_metric['connection_loss']:.4f} "
|
| 996 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 997 |
+
f"[{current_time}]\n"
|
| 998 |
+
)
|
| 999 |
+
f.write(log_line)
|
| 1000 |
+
|
| 1001 |
+
if num_batches % accum_steps != 0:
|
| 1002 |
+
self.scaler.unscale_(self.optimizer)
|
| 1003 |
+
torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
|
| 1004 |
+
torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
|
| 1005 |
+
torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=1.0)
|
| 1006 |
+
|
| 1007 |
+
self.scaler.step(self.optimizer)
|
| 1008 |
+
self.scaler.update()
|
| 1009 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 1010 |
+
|
| 1011 |
+
self.scheduler.step()
|
| 1012 |
+
|
| 1013 |
+
# Calculate epoch averages
|
| 1014 |
+
avg_metrics = {key: value / num_batches for key, value in metrics.items()}
|
| 1015 |
+
self.train_loss_history.append(avg_metrics['total_loss'])
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
# Log to file
|
| 1019 |
+
if self.is_master:
|
| 1020 |
+
with open(self.log_file, "a") as f:
|
| 1021 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 1022 |
+
log_line = (
|
| 1023 |
+
f"Epoch {epoch:05d} | "
|
| 1024 |
+
f"Loss: {avg_metrics['total_loss']:.6f} "
|
| 1025 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 1026 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 1027 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 1028 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 1029 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 1030 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 1031 |
+
f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
|
| 1032 |
+
f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
|
| 1033 |
+
f"[{current_time}]\n"
|
| 1034 |
+
)
|
| 1035 |
+
f.write(log_line)
|
| 1036 |
+
|
| 1037 |
+
# Print epoch summary
|
| 1038 |
+
print(
|
| 1039 |
+
f"[Epoch {epoch}] "
|
| 1040 |
+
f"Avg Loss: {avg_metrics['total_loss']:.4f} "
|
| 1041 |
+
f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
|
| 1042 |
+
f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
|
| 1043 |
+
f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
|
| 1044 |
+
f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
|
| 1045 |
+
f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
|
| 1046 |
+
f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
|
| 1047 |
+
f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
|
| 1048 |
+
f"[{current_time}]\n"
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
# Save checkpoint
|
| 1052 |
+
if epoch % self.cfg['training']['save_every'] == 0:
|
| 1053 |
+
self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
|
| 1054 |
+
|
| 1055 |
+
# Update learning rate
|
| 1056 |
+
if self.is_master:
|
| 1057 |
+
current_lr = self.optimizer.param_groups[0]['lr']
|
| 1058 |
+
print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
|
| 1059 |
+
|
| 1060 |
+
dist.barrier()
|
| 1061 |
+
|
| 1062 |
+
def main():
|
| 1063 |
+
# Initialize the process group
|
| 1064 |
+
dist.init_process_group(backend='nccl')
|
| 1065 |
+
|
| 1066 |
+
# Get rank and world size from environment variables set by the launcher
|
| 1067 |
+
rank = int(os.environ['RANK'])
|
| 1068 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 1069 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 1070 |
+
|
| 1071 |
+
# Set the device for the current process. This is crucial.
|
| 1072 |
+
torch.cuda.set_device(local_rank)
|
| 1073 |
+
torch.manual_seed(42+rank)
|
| 1074 |
+
|
| 1075 |
+
# with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 1076 |
+
# Pass the distributed info to the Trainer
|
| 1077 |
+
trainer = Trainer(
|
| 1078 |
+
config_path="/home/tiger/yy/src/Michelangelo-master/config_edge_1024_error_8enc_8dec_woself_finetune_128to512.yaml",
|
| 1079 |
+
rank=rank,
|
| 1080 |
+
world_size=world_size,
|
| 1081 |
+
local_rank=local_rank
|
| 1082 |
+
)
|
| 1083 |
+
trainer.train()
|
| 1084 |
+
|
| 1085 |
+
# Clean up the process group
|
| 1086 |
+
dist.destroy_process_group()
|
| 1087 |
+
|
| 1088 |
+
|
| 1089 |
+
if __name__ == '__main__':
|
| 1090 |
+
main()
|
trellis/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import models
|
| 2 |
+
from . import modules
|
| 3 |
+
from . import pipelines
|
| 4 |
+
from . import renderers
|
| 5 |
+
from . import representations
|
| 6 |
+
from . import utils
|
trellis/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (344 Bytes). View file
|
|
|
trellis/datasets/__init__.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
|
| 3 |
+
__attributes = {
|
| 4 |
+
'SparseStructure': 'sparse_structure',
|
| 5 |
+
|
| 6 |
+
'SparseFeat2Render': 'sparse_feat2render',
|
| 7 |
+
'SLat2Render':'structured_latent2render',
|
| 8 |
+
'Slat2RenderGeo':'structured_latent2render',
|
| 9 |
+
|
| 10 |
+
'SparseStructureLatent': 'sparse_structure_latent',
|
| 11 |
+
'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
|
| 12 |
+
'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
|
| 13 |
+
|
| 14 |
+
'SLat': 'structured_latent',
|
| 15 |
+
'TextConditionedSLat': 'structured_latent',
|
| 16 |
+
'ImageConditionedSLat': 'structured_latent',
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
__submodules = []
|
| 20 |
+
|
| 21 |
+
__all__ = list(__attributes.keys()) + __submodules
|
| 22 |
+
|
| 23 |
+
def __getattr__(name):
|
| 24 |
+
if name not in globals():
|
| 25 |
+
if name in __attributes:
|
| 26 |
+
module_name = __attributes[name]
|
| 27 |
+
module = importlib.import_module(f".{module_name}", __name__)
|
| 28 |
+
globals()[name] = getattr(module, name)
|
| 29 |
+
elif name in __submodules:
|
| 30 |
+
module = importlib.import_module(f".{name}", __name__)
|
| 31 |
+
globals()[name] = module
|
| 32 |
+
else:
|
| 33 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
| 34 |
+
return globals()[name]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# For Pylance
|
| 38 |
+
if __name__ == '__main__':
|
| 39 |
+
from .sparse_structure import SparseStructure
|
| 40 |
+
|
| 41 |
+
from .sparse_feat2render import SparseFeat2Render
|
| 42 |
+
from .structured_latent2render import (
|
| 43 |
+
SLat2Render,
|
| 44 |
+
Slat2RenderGeo,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
from .sparse_structure_latent import (
|
| 48 |
+
SparseStructureLatent,
|
| 49 |
+
TextConditionedSparseStructureLatent,
|
| 50 |
+
ImageConditionedSparseStructureLatent,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
from .structured_latent import (
|
| 54 |
+
SLat,
|
| 55 |
+
TextConditionedSLat,
|
| 56 |
+
ImageConditionedSLat,
|
| 57 |
+
)
|
| 58 |
+
|
trellis/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.24 kB). View file
|
|
|
trellis/datasets/__pycache__/components.cpython-310.pyc
ADDED
|
Binary file (5.46 kB). View file
|
|
|
trellis/datasets/__pycache__/sparse_structure_latent.cpython-310.pyc
ADDED
|
Binary file (6.94 kB). View file
|
|
|
trellis/datasets/components.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StandardDatasetBase(Dataset):
|
| 13 |
+
"""
|
| 14 |
+
Base class for standard datasets.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
roots (str): paths to the dataset
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self,
|
| 21 |
+
roots: str,
|
| 22 |
+
):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.roots = roots.split(',')
|
| 25 |
+
self.instances = []
|
| 26 |
+
self.metadata = pd.DataFrame()
|
| 27 |
+
|
| 28 |
+
self._stats = {}
|
| 29 |
+
for root in self.roots:
|
| 30 |
+
key = os.path.basename(root)
|
| 31 |
+
self._stats[key] = {}
|
| 32 |
+
metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
|
| 33 |
+
self._stats[key]['Total'] = len(metadata)
|
| 34 |
+
metadata, stats = self.filter_metadata(metadata)
|
| 35 |
+
self._stats[key].update(stats)
|
| 36 |
+
self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
|
| 37 |
+
metadata.set_index('sha256', inplace=True)
|
| 38 |
+
self.metadata = pd.concat([self.metadata, metadata])
|
| 39 |
+
|
| 40 |
+
@abstractmethod
|
| 41 |
+
def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
|
| 42 |
+
pass
|
| 43 |
+
|
| 44 |
+
@abstractmethod
|
| 45 |
+
def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
def __len__(self):
|
| 49 |
+
return len(self.instances)
|
| 50 |
+
|
| 51 |
+
def __getitem__(self, index) -> Dict[str, Any]:
|
| 52 |
+
try:
|
| 53 |
+
root, instance = self.instances[index]
|
| 54 |
+
return self.get_instance(root, instance)
|
| 55 |
+
except Exception as e:
|
| 56 |
+
print(e)
|
| 57 |
+
return self.__getitem__(np.random.randint(0, len(self)))
|
| 58 |
+
|
| 59 |
+
def __str__(self):
|
| 60 |
+
lines = []
|
| 61 |
+
lines.append(self.__class__.__name__)
|
| 62 |
+
lines.append(f' - Total instances: {len(self)}')
|
| 63 |
+
lines.append(f' - Sources:')
|
| 64 |
+
for key, stats in self._stats.items():
|
| 65 |
+
lines.append(f' - {key}:')
|
| 66 |
+
for k, v in stats.items():
|
| 67 |
+
lines.append(f' - {k}: {v}')
|
| 68 |
+
return '\n'.join(lines)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TextConditionedMixin:
|
| 72 |
+
def __init__(self, roots, **kwargs):
|
| 73 |
+
super().__init__(roots, **kwargs)
|
| 74 |
+
self.captions = {}
|
| 75 |
+
for instance in self.instances:
|
| 76 |
+
sha256 = instance[1]
|
| 77 |
+
self.captions[sha256] = json.loads(self.metadata.loc[sha256]['captions'])
|
| 78 |
+
|
| 79 |
+
def filter_metadata(self, metadata):
|
| 80 |
+
metadata, stats = super().filter_metadata(metadata)
|
| 81 |
+
metadata = metadata[metadata['captions'].notna()]
|
| 82 |
+
stats['With captions'] = len(metadata)
|
| 83 |
+
return metadata, stats
|
| 84 |
+
|
| 85 |
+
def get_instance(self, root, instance):
|
| 86 |
+
pack = super().get_instance(root, instance)
|
| 87 |
+
text = np.random.choice(self.captions[instance])
|
| 88 |
+
pack['cond'] = text
|
| 89 |
+
return pack
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ImageConditionedMixin:
|
| 93 |
+
def __init__(self, roots, *, image_size=518, **kwargs):
|
| 94 |
+
self.image_size = image_size
|
| 95 |
+
super().__init__(roots, **kwargs)
|
| 96 |
+
|
| 97 |
+
def filter_metadata(self, metadata):
|
| 98 |
+
metadata, stats = super().filter_metadata(metadata)
|
| 99 |
+
metadata = metadata[metadata[f'cond_rendered']]
|
| 100 |
+
stats['Cond rendered'] = len(metadata)
|
| 101 |
+
return metadata, stats
|
| 102 |
+
|
| 103 |
+
def get_instance(self, root, instance):
|
| 104 |
+
pack = super().get_instance(root, instance)
|
| 105 |
+
|
| 106 |
+
image_root = os.path.join(root, 'renders_cond', instance)
|
| 107 |
+
with open(os.path.join(image_root, 'transforms.json')) as f:
|
| 108 |
+
metadata = json.load(f)
|
| 109 |
+
n_views = len(metadata['frames'])
|
| 110 |
+
view = np.random.randint(n_views)
|
| 111 |
+
metadata = metadata['frames'][view]
|
| 112 |
+
|
| 113 |
+
image_path = os.path.join(image_root, metadata['file_path'])
|
| 114 |
+
image = Image.open(image_path)
|
| 115 |
+
|
| 116 |
+
alpha = np.array(image.getchannel(3))
|
| 117 |
+
bbox = np.array(alpha).nonzero()
|
| 118 |
+
bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
|
| 119 |
+
center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
|
| 120 |
+
hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
|
| 121 |
+
aug_size_ratio = 1.2
|
| 122 |
+
aug_hsize = hsize * aug_size_ratio
|
| 123 |
+
aug_center_offset = [0, 0]
|
| 124 |
+
aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
|
| 125 |
+
aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
|
| 126 |
+
image = image.crop(aug_bbox)
|
| 127 |
+
|
| 128 |
+
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
| 129 |
+
alpha = image.getchannel(3)
|
| 130 |
+
image = image.convert('RGB')
|
| 131 |
+
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
|
| 132 |
+
alpha = torch.tensor(np.array(alpha)).float() / 255.0
|
| 133 |
+
image = image * alpha.unsqueeze(0)
|
| 134 |
+
pack['cond'] = image
|
| 135 |
+
|
| 136 |
+
return pack
|
| 137 |
+
|
trellis/datasets/sparse_feat2render.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
import utils3d.torch
|
| 8 |
+
from ..modules.sparse.basic import SparseTensor
|
| 9 |
+
from .components import StandardDatasetBase
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SparseFeat2Render(StandardDatasetBase):
|
| 13 |
+
"""
|
| 14 |
+
SparseFeat2Render dataset.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
roots (str): paths to the dataset
|
| 18 |
+
image_size (int): size of the image
|
| 19 |
+
model (str): model name
|
| 20 |
+
resolution (int): resolution of the data
|
| 21 |
+
min_aesthetic_score (float): minimum aesthetic score
|
| 22 |
+
max_num_voxels (int): maximum number of voxels
|
| 23 |
+
"""
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
roots: str,
|
| 27 |
+
image_size: int,
|
| 28 |
+
model: str = 'dinov2_vitl14_reg',
|
| 29 |
+
resolution: int = 64,
|
| 30 |
+
min_aesthetic_score: float = 5.0,
|
| 31 |
+
max_num_voxels: int = 32768,
|
| 32 |
+
):
|
| 33 |
+
self.image_size = image_size
|
| 34 |
+
self.model = model
|
| 35 |
+
self.resolution = resolution
|
| 36 |
+
self.min_aesthetic_score = min_aesthetic_score
|
| 37 |
+
self.max_num_voxels = max_num_voxels
|
| 38 |
+
self.value_range = (0, 1)
|
| 39 |
+
|
| 40 |
+
super().__init__(roots)
|
| 41 |
+
|
| 42 |
+
def filter_metadata(self, metadata):
|
| 43 |
+
stats = {}
|
| 44 |
+
metadata = metadata[metadata[f'feature_{self.model}']]
|
| 45 |
+
stats['With features'] = len(metadata)
|
| 46 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
| 47 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
| 48 |
+
metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
|
| 49 |
+
stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
|
| 50 |
+
return metadata, stats
|
| 51 |
+
|
| 52 |
+
def _get_image(self, root, instance):
|
| 53 |
+
with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
|
| 54 |
+
metadata = json.load(f)
|
| 55 |
+
n_views = len(metadata['frames'])
|
| 56 |
+
view = np.random.randint(n_views)
|
| 57 |
+
metadata = metadata['frames'][view]
|
| 58 |
+
fov = metadata['camera_angle_x']
|
| 59 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
|
| 60 |
+
c2w = torch.tensor(metadata['transform_matrix'])
|
| 61 |
+
c2w[:3, 1:3] *= -1
|
| 62 |
+
extrinsics = torch.inverse(c2w)
|
| 63 |
+
|
| 64 |
+
image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
|
| 65 |
+
image = Image.open(image_path)
|
| 66 |
+
alpha = image.getchannel(3)
|
| 67 |
+
image = image.convert('RGB')
|
| 68 |
+
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
| 69 |
+
alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
| 70 |
+
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
|
| 71 |
+
alpha = torch.tensor(np.array(alpha)).float() / 255.0
|
| 72 |
+
|
| 73 |
+
return {
|
| 74 |
+
'image': image,
|
| 75 |
+
'alpha': alpha,
|
| 76 |
+
'extrinsics': extrinsics,
|
| 77 |
+
'intrinsics': intrinsics,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
def _get_feat(self, root, instance):
|
| 81 |
+
DATA_RESOLUTION = 64
|
| 82 |
+
feats_path = os.path.join(root, 'features', self.model, f'{instance}.npz')
|
| 83 |
+
feats = np.load(feats_path, allow_pickle=True)
|
| 84 |
+
coords = torch.tensor(feats['indices']).int()
|
| 85 |
+
feats = torch.tensor(feats['patchtokens']).float()
|
| 86 |
+
|
| 87 |
+
if self.resolution != DATA_RESOLUTION:
|
| 88 |
+
factor = DATA_RESOLUTION // self.resolution
|
| 89 |
+
coords = coords // factor
|
| 90 |
+
coords, idx = coords.unique(return_inverse=True, dim=0)
|
| 91 |
+
feats = torch.scatter_reduce(
|
| 92 |
+
torch.zeros(coords.shape[0], feats.shape[1], device=feats.device),
|
| 93 |
+
dim=0,
|
| 94 |
+
index=idx.unsqueeze(-1).expand(-1, feats.shape[1]),
|
| 95 |
+
src=feats,
|
| 96 |
+
reduce='mean'
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
return {
|
| 100 |
+
'coords': coords,
|
| 101 |
+
'feats': feats,
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
@torch.no_grad()
|
| 105 |
+
def visualize_sample(self, sample: dict):
|
| 106 |
+
return sample['image']
|
| 107 |
+
|
| 108 |
+
@staticmethod
|
| 109 |
+
def collate_fn(batch):
|
| 110 |
+
pack = {}
|
| 111 |
+
coords = []
|
| 112 |
+
for i, b in enumerate(batch):
|
| 113 |
+
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
|
| 114 |
+
coords = torch.cat(coords)
|
| 115 |
+
feats = torch.cat([b['feats'] for b in batch])
|
| 116 |
+
pack['feats'] = SparseTensor(
|
| 117 |
+
coords=coords,
|
| 118 |
+
feats=feats,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
pack['image'] = torch.stack([b['image'] for b in batch])
|
| 122 |
+
pack['alpha'] = torch.stack([b['alpha'] for b in batch])
|
| 123 |
+
pack['extrinsics'] = torch.stack([b['extrinsics'] for b in batch])
|
| 124 |
+
pack['intrinsics'] = torch.stack([b['intrinsics'] for b in batch])
|
| 125 |
+
|
| 126 |
+
return pack
|
| 127 |
+
|
| 128 |
+
def get_instance(self, root, instance):
|
| 129 |
+
image = self._get_image(root, instance)
|
| 130 |
+
feat = self._get_feat(root, instance)
|
| 131 |
+
return {
|
| 132 |
+
**image,
|
| 133 |
+
**feat,
|
| 134 |
+
}
|
trellis/datasets/sparse_structure.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from typing import Union
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
import utils3d
|
| 9 |
+
from .components import StandardDatasetBase
|
| 10 |
+
from ..representations.octree import DfsOctree as Octree
|
| 11 |
+
from ..renderers import OctreeRenderer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SparseStructure(StandardDatasetBase):
|
| 15 |
+
"""
|
| 16 |
+
Sparse structure dataset
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
roots (str): path to the dataset
|
| 20 |
+
resolution (int): resolution of the voxel grid
|
| 21 |
+
min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self,
|
| 25 |
+
roots,
|
| 26 |
+
resolution: int = 64,
|
| 27 |
+
min_aesthetic_score: float = 5.0,
|
| 28 |
+
):
|
| 29 |
+
self.resolution = resolution
|
| 30 |
+
self.min_aesthetic_score = min_aesthetic_score
|
| 31 |
+
self.value_range = (0, 1)
|
| 32 |
+
|
| 33 |
+
super().__init__(roots)
|
| 34 |
+
|
| 35 |
+
def filter_metadata(self, metadata):
|
| 36 |
+
stats = {}
|
| 37 |
+
metadata = metadata[metadata[f'voxelized']]
|
| 38 |
+
stats['Voxelized'] = len(metadata)
|
| 39 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
| 40 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
| 41 |
+
return metadata, stats
|
| 42 |
+
|
| 43 |
+
def get_instance(self, root, instance):
|
| 44 |
+
position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0]
|
| 45 |
+
coords = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
|
| 46 |
+
ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
|
| 47 |
+
ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
|
| 48 |
+
return {'ss': ss}
|
| 49 |
+
|
| 50 |
+
@torch.no_grad()
|
| 51 |
+
def visualize_sample(self, ss: Union[torch.Tensor, dict]):
|
| 52 |
+
ss = ss if isinstance(ss, torch.Tensor) else ss['ss']
|
| 53 |
+
|
| 54 |
+
renderer = OctreeRenderer()
|
| 55 |
+
renderer.rendering_options.resolution = 512
|
| 56 |
+
renderer.rendering_options.near = 0.8
|
| 57 |
+
renderer.rendering_options.far = 1.6
|
| 58 |
+
renderer.rendering_options.bg_color = (0, 0, 0)
|
| 59 |
+
renderer.rendering_options.ssaa = 4
|
| 60 |
+
renderer.pipe.primitive = 'voxel'
|
| 61 |
+
|
| 62 |
+
# Build camera
|
| 63 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
| 64 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
| 65 |
+
yaws = [y + yaws_offset for y in yaws]
|
| 66 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
| 67 |
+
|
| 68 |
+
exts = []
|
| 69 |
+
ints = []
|
| 70 |
+
for yaw, pitch in zip(yaws, pitch):
|
| 71 |
+
orig = torch.tensor([
|
| 72 |
+
np.sin(yaw) * np.cos(pitch),
|
| 73 |
+
np.cos(yaw) * np.cos(pitch),
|
| 74 |
+
np.sin(pitch),
|
| 75 |
+
]).float().cuda() * 2
|
| 76 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
| 77 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
| 78 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
| 79 |
+
exts.append(extrinsics)
|
| 80 |
+
ints.append(intrinsics)
|
| 81 |
+
|
| 82 |
+
images = []
|
| 83 |
+
|
| 84 |
+
# Build each representation
|
| 85 |
+
ss = ss.cuda()
|
| 86 |
+
for i in range(ss.shape[0]):
|
| 87 |
+
representation = Octree(
|
| 88 |
+
depth=10,
|
| 89 |
+
aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
|
| 90 |
+
device='cuda',
|
| 91 |
+
primitive='voxel',
|
| 92 |
+
sh_degree=0,
|
| 93 |
+
primitive_config={'solid': True},
|
| 94 |
+
)
|
| 95 |
+
coords = torch.nonzero(ss[i, 0], as_tuple=False)
|
| 96 |
+
representation.position = coords.float() / self.resolution
|
| 97 |
+
representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
|
| 98 |
+
|
| 99 |
+
image = torch.zeros(3, 1024, 1024).cuda()
|
| 100 |
+
tile = [2, 2]
|
| 101 |
+
for j, (ext, intr) in enumerate(zip(exts, ints)):
|
| 102 |
+
res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
|
| 103 |
+
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
|
| 104 |
+
images.append(image)
|
| 105 |
+
|
| 106 |
+
return torch.stack(images)
|
| 107 |
+
|