udbbdh commited on
Commit
7340df2
·
verified ·
1 Parent(s): 799d5e6

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. 40w_2000-100000edge_2000-75000active.txt +3 -0
  3. 40w_2000-200000edge_2000-100000active.txt +3 -0
  4. MERGED_DATASET_count_200_2000_10000_train_2000min_100000max.txt +0 -0
  5. MERGED_DATASET_filtered_2000-100000edge_2000-32678active.txt +0 -0
  6. __pycache__/dataset_triposf.cpython-310.pyc +0 -0
  7. __pycache__/dataset_triposf_head.cpython-310.pyc +0 -0
  8. __pycache__/query_point.cpython-310.pyc +0 -0
  9. __pycache__/utils.bresenham_3d_array-192.py310.1.nbc +3 -0
  10. __pycache__/utils.bresenham_3d_array-192.py310.nbi +0 -0
  11. __pycache__/utils.cpython-310.pyc +0 -0
  12. __pycache__/vertex_encoder.cpython-310.pyc +0 -0
  13. config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_addhead.yaml +82 -0
  14. config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_head_woca.yaml +97 -0
  15. config_edge_1024_error_8enc_8dec_woself_finetune_128to512.yaml +101 -0
  16. config_slat_flow_128to256_pointnet_test.yaml +124 -0
  17. dataset_triposf.py +924 -0
  18. dataset_triposf_head.py +1000 -0
  19. debug_viz/step_0_batch_0.ply +0 -0
  20. debug_viz/step_0_batch_1.ply +0 -0
  21. filter_active_voxels.py +106 -0
  22. generate_npz.py +118 -0
  23. mesh_augment.py +79 -0
  24. metric.py +300 -0
  25. metric_cd.py +190 -0
  26. query_point.py +259 -0
  27. test_slat_flow_128to1024_pointnet.py +403 -0
  28. test_slat_flow_128to256_pointnet.py +403 -0
  29. test_slat_vae_128to1024_pointnet.py +0 -0
  30. test_slat_vae_128to1024_pointnet_vae.py +0 -0
  31. test_slat_vae_128to1024_pointnet_vae_addhead.py +0 -0
  32. test_slat_vae_128to1024_pointnet_vae_head.py +1339 -0
  33. test_slat_vae_128to1024_pointnet_vae_head_woca.py +0 -0
  34. test_slat_vae_128to256_pointnet_vae_head.py +1349 -0
  35. test_slat_vae_128to512_pointnet_vae_head.py +1636 -0
  36. train_slat_flow_128to1024_pointnet.py +484 -0
  37. train_slat_vae_512_128to1024_pointnet.py +682 -0
  38. train_slat_vae_512_128to1024_pointnet_addhead.py +788 -0
  39. train_slat_vae_512_128to1024_pointnet_head.py +930 -0
  40. train_slat_vae_512_128to256_pointnet_head.py +917 -0
  41. train_slat_vae_512_128to512_pointnet_head.py +1090 -0
  42. trellis/__init__.py +6 -0
  43. trellis/__pycache__/__init__.cpython-310.pyc +0 -0
  44. trellis/datasets/__init__.py +58 -0
  45. trellis/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  46. trellis/datasets/__pycache__/components.cpython-310.pyc +0 -0
  47. trellis/datasets/__pycache__/sparse_structure_latent.cpython-310.pyc +0 -0
  48. trellis/datasets/components.py +137 -0
  49. trellis/datasets/sparse_feat2render.py +134 -0
  50. trellis/datasets/sparse_structure.py +107 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 40w_2000-100000edge_2000-75000active.txt filter=lfs diff=lfs merge=lfs -text
37
+ 40w_2000-200000edge_2000-100000active.txt filter=lfs diff=lfs merge=lfs -text
38
+ __pycache__/utils.bresenham_3d_array-192.py310.1.nbc filter=lfs diff=lfs merge=lfs -text
40w_2000-100000edge_2000-75000active.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e283bf011c720c1300f7c4621c15af8360df05f0bcfbd71faf984e2d61f6e17
3
+ size 38231752
40w_2000-200000edge_2000-100000active.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46aae1f5d9bce421cb402a69b3b112daf586ed8f470c7de4b1e6d73adc82ef56
3
+ size 44315087
MERGED_DATASET_count_200_2000_10000_train_2000min_100000max.txt ADDED
The diff for this file is too large to render. See raw diff
 
MERGED_DATASET_filtered_2000-100000edge_2000-32678active.txt ADDED
File without changes
__pycache__/dataset_triposf.cpython-310.pyc ADDED
Binary file (24.8 kB). View file
 
__pycache__/dataset_triposf_head.cpython-310.pyc ADDED
Binary file (26.1 kB). View file
 
__pycache__/query_point.cpython-310.pyc ADDED
Binary file (8.77 kB). View file
 
__pycache__/utils.bresenham_3d_array-192.py310.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d783005c5692d0ddfb1fc45298776ac767bde1326be014a6e65e76e41aeef898
3
+ size 113601
__pycache__/utils.bresenham_3d_array-192.py310.nbi ADDED
Binary file (1.59 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (24.5 kB). View file
 
__pycache__/vertex_encoder.cpython-310.pyc ADDED
Binary file (19.2 kB). View file
 
config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_addhead.yaml ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "dataset":
2
+ "base_resolution": 1024
3
+ "cache_dir": "/gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to1024_819200_head"
4
+ "cache_filter_path": "/gemini/user/private/zhaotianhao/Triposf/MERGED_DATASET_filtered_2000-75000edge_2000-326780active.txt"
5
+ "filter_active_voxels": true
6
+ "min_resolution": 128
7
+ "n_train_samples": 1024
8
+ "path": "/gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/train"
9
+ "renders_dir": "None"
10
+ "sample_type": "dora"
11
+ "experiment":
12
+ "save_dir": "/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs1_128to1024_wolabel_dir_sorted_dora_bigger_addhead"
13
+ "model":
14
+ "add_block_embed": true
15
+ "add_direction": false
16
+ "add_edge_glb_feats": true
17
+ "attn_first": false
18
+ "block_size": 16
19
+ "decoder_blocks_edge":
20
+ - "in_channels": 768
21
+ "model_channels": 768
22
+ "num_blocks": 0
23
+ "num_heads": 0
24
+ "out_channels": 128
25
+ "resolution": 128
26
+ - "in_channels": 128
27
+ "model_channels": 128
28
+ "num_blocks": 0
29
+ "num_heads": 0
30
+ "out_channels": 64
31
+ "resolution": 256
32
+ - "in_channels": 64
33
+ "model_channels": 64
34
+ "num_blocks": 0
35
+ "num_heads": 0
36
+ "out_channels": 32
37
+ "resolution": 512
38
+ "decoder_blocks_vtx":
39
+ - "in_channels": 768
40
+ "model_channels": 768
41
+ "num_blocks": 0
42
+ "num_heads": 0
43
+ "out_channels": 128
44
+ "resolution": 128
45
+ - "in_channels": 128
46
+ "model_channels": 128
47
+ "num_blocks": 0
48
+ "num_heads": 0
49
+ "out_channels": 64
50
+ "resolution": 256
51
+ - "in_channels": 64
52
+ "model_channels": 64
53
+ "num_blocks": 0
54
+ "num_heads": 0
55
+ "out_channels": 32
56
+ "resolution": 512
57
+ "embed_dim": 1024
58
+ "encoder_blocks":
59
+ - "in_channels": 1024
60
+ "model_channels": 768
61
+ "num_blocks": 12
62
+ "num_heads": 12
63
+ "out_channels": 768
64
+ "in_channels": 1024
65
+ "latent_dim": 16
66
+ "model_channels": 384
67
+ "multires": 12
68
+ "pos_encoding": "nerf"
69
+ "pred_direction": true
70
+ "relative_embed": true
71
+ "using_attn": false
72
+ "training":
73
+ "batch_size": 1
74
+ "checkpoint_path": "/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs1_128to1024_wolabel_dir_sorted_dora_bigger/checkpoint_epoch14_batch10432_loss0.1438.pt"
75
+ "from_pretrained": true
76
+ "gamma": 0.95
77
+ "lr": 0.0001
78
+ "max_epochs": 100
79
+ "num_workers": 12
80
+ "save_every": 1
81
+ "start_epoch": 1
82
+ "step_size": 1
config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_head_woca.yaml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ path: /gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/train
3
+ cache_dir: /gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to1024_819200_head
4
+
5
+ renders_dir: None
6
+
7
+ filter_active_voxels: true
8
+ cache_filter_path: /gemini/user/private/zhaotianhao/Triposf/MERGED_DATASET_filtered_2000-75000edge_2000-326780active.txt
9
+
10
+ base_resolution: 1024
11
+ min_resolution: 128
12
+
13
+ n_train_samples: 1024
14
+ sample_type: dora
15
+
16
+ model:
17
+ pred_direction: false
18
+ relative_embed: true
19
+ using_attn: false
20
+ add_block_embed: true
21
+ multires: 12
22
+
23
+ embed_dim: 1024 #64
24
+ in_channels: 1024 #64
25
+ model_channels: 384
26
+ latent_dim: 16
27
+
28
+ block_size: 16
29
+ pos_encoding: 'nerf'
30
+ # pos_encoding: 'embedding'
31
+ attn_first: false
32
+
33
+ add_edge_glb_feats: true
34
+ add_direction: false
35
+
36
+ encoder_blocks:
37
+ - in_channels: 1024
38
+ model_channels: 512
39
+ num_blocks: 8
40
+ num_heads: 8
41
+ out_channels: 512
42
+
43
+ decoder_blocks_edge:
44
+ - in_channels: 512
45
+ model_channels: 512
46
+ num_blocks: 0
47
+ num_heads: 0
48
+ out_channels: 256
49
+ resolution: 128
50
+ - in_channels: 256
51
+ model_channels: 256
52
+ num_blocks: 0
53
+ num_heads: 0
54
+ out_channels: 128
55
+ resolution: 256
56
+ - in_channels: 128
57
+ model_channels: 128
58
+ num_blocks: 0
59
+ num_heads: 0
60
+ out_channels: 64
61
+ resolution: 512
62
+
63
+ decoder_blocks_vtx:
64
+ - in_channels: 512
65
+ model_channels: 512
66
+ num_blocks: 0
67
+ num_heads: 0
68
+ out_channels: 256
69
+ resolution: 128
70
+ - in_channels: 256
71
+ model_channels: 256
72
+ num_blocks: 0
73
+ num_heads: 0
74
+ out_channels: 128
75
+ resolution: 256
76
+ - in_channels: 128
77
+ model_channels: 128
78
+ num_blocks: 0
79
+ num_heads: 0
80
+ out_channels: 64
81
+ resolution: 512
82
+
83
+
84
+ training:
85
+ batch_size: 2
86
+ lr: 1.e-4
87
+ step_size: 1
88
+ gamma: 0.95
89
+ save_every: 5
90
+ start_epoch: 0
91
+ max_epochs: 200
92
+ num_workers: 4
93
+ from_pretrained: false
94
+ checkpoint_path: /root/Trisf/experiments_edge/vae/train_9w_200_2000face/9w_128to1024/checkpoint_epoch14_batch5216_loss0.2745.pt
95
+
96
+ experiment:
97
+ save_dir: "/root/Trisf/experiments_edge/vae/{dataset_name}_9w_200_2000face/shapenet_bs{batch_size}_128to1024_wolabel_dir_sorted_dora_small_allasyloss"
config_edge_1024_error_8enc_8dec_woself_finetune_128to512.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ # path: /gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/train
3
+ # cache_dir: /gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to1024_819200_head
4
+
5
+ path: /home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01
6
+ cache_dir: /home/tiger/yy/src/dataset_cache/unique_files_glb_under6000face_2degree_30ratio_0.01
7
+
8
+ renders_dir: None
9
+
10
+ filter_active_voxels: true
11
+ cache_filter_path: /home/tiger/yy/src/Michelangelo-master/40w_2000-200000edge_2000-100000active.txt # 5epoch
12
+ # cache_filter_path: /home/tiger/yy/src/Michelangelo-master/40w_2000-100000edge_2000-75000active.txt # 0-4 epoch
13
+
14
+ base_resolution: 512
15
+ min_resolution: 128
16
+
17
+ n_train_samples: 1024
18
+ sample_type: dora
19
+
20
+ model:
21
+ pred_direction: false
22
+ relative_embed: true
23
+ using_attn: false
24
+ add_block_embed: true
25
+ multires: 12
26
+
27
+ embed_dim: 1024 #64
28
+ in_channels: 1024 #64
29
+ model_channels: 384
30
+ latent_dim: 16
31
+
32
+ block_size: 16
33
+ pos_encoding: 'nerf'
34
+ attn_first: false
35
+
36
+ add_edge_glb_feats: true
37
+ add_direction: false
38
+
39
+ encoder_blocks:
40
+ - in_channels: 1024
41
+ model_channels: 512
42
+ num_blocks: 8
43
+ num_heads: 8
44
+ out_channels: 512
45
+
46
+ decoder_blocks_edge:
47
+ - in_channels: 512
48
+ model_channels: 512
49
+ num_blocks: 0
50
+ num_heads: 0
51
+ out_channels: 256
52
+ resolution: 128
53
+ - in_channels: 256
54
+ model_channels: 256
55
+ num_blocks: 0
56
+ num_heads: 0
57
+ out_channels: 128
58
+ resolution: 256
59
+ # - in_channels: 128
60
+ # model_channels: 128
61
+ # num_blocks: 0
62
+ # num_heads: 0
63
+ # out_channels: 64
64
+ # resolution: 512
65
+
66
+ decoder_blocks_vtx:
67
+ - in_channels: 512
68
+ model_channels: 512
69
+ num_blocks: 0
70
+ num_heads: 0
71
+ out_channels: 256
72
+ resolution: 128
73
+ - in_channels: 256
74
+ model_channels: 256
75
+ num_blocks: 0
76
+ num_heads: 0
77
+ out_channels: 128
78
+ resolution: 256
79
+ # - in_channels: 128
80
+ # model_channels: 128
81
+ # num_blocks: 0
82
+ # num_heads: 0
83
+ # out_channels: 64
84
+ # resolution: 512
85
+
86
+
87
+ training:
88
+ batch_size: 2
89
+ lr: 4.e-5
90
+ step_size: 1
91
+ gamma: 0.95
92
+ save_every: 1
93
+ start_epoch: 0
94
+ max_epochs: 10
95
+ num_workers: 32
96
+ from_pretrained: true
97
+ # checkpoint_path: /home/tiger/yy/src/checkpoint_epoch2_batch30000_loss0.1829.pt
98
+ checkpoint_path: /home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch7_batch10000_loss0.1175.pt
99
+
100
+ experiment:
101
+ save_dir: "/home/tiger/yy/src/checkpoints/vae/{dataset_name}/shapenet_bs{batch_size}_128to512_wolabel_dir_sorted_dora_small_lowlr"
config_slat_flow_128to256_pointnet_test.yaml ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ pred_direction: true
3
+ relative_embed: true
4
+ using_attn: false
5
+ add_block_embed: true
6
+ multires: 12
7
+
8
+ embed_dim: 1024
9
+ in_channels: 1024
10
+ model_channels: 384
11
+ latent_dim: 16
12
+
13
+ block_size: 16
14
+ pos_encoding: 'nerf'
15
+ attn_first: false
16
+
17
+ add_edge_glb_feats: true
18
+ add_direction: false
19
+
20
+ encoder_blocks:
21
+ - in_channels: 1024
22
+ model_channels: 512
23
+ num_blocks: 8
24
+ num_heads: 8
25
+ out_channels: 512
26
+
27
+ decoder_blocks_edge:
28
+ - in_channels: 512
29
+ model_channels: 512
30
+ num_blocks: 0
31
+ num_heads: 0
32
+ out_channels: 128
33
+ resolution: 128
34
+ # - in_channels: 128
35
+ # model_channels: 128
36
+ # num_blocks: 0
37
+ # num_heads: 0
38
+ # out_channels: 64
39
+ # resolution: 256
40
+ # - in_channels: 64
41
+ # model_channels: 64
42
+ # num_blocks: 0
43
+ # num_heads: 0
44
+ # out_channels: 32
45
+ # resolution: 512
46
+
47
+ decoder_blocks_vtx:
48
+ - in_channels: 512
49
+ model_channels: 512
50
+ num_blocks: 0
51
+ num_heads: 0
52
+ out_channels: 128
53
+ resolution: 128
54
+ # - in_channels: 128
55
+ # model_channels: 128
56
+ # num_blocks: 0
57
+ # num_heads: 0
58
+ # out_channels: 64
59
+ # resolution: 256
60
+ # - in_channels: 64
61
+ # model_channels: 64
62
+ # num_blocks: 0
63
+ # num_heads: 0
64
+ # out_channels: 32
65
+ # resolution: 512
66
+
67
+ "t_schedule":
68
+ "name": "logitNormal"
69
+ "args":
70
+ "mean": 1.0
71
+ "std": 1.0
72
+
73
+ "sigma_min": 1.e-5
74
+
75
+ training:
76
+ batch_size: 1
77
+ lr: 1.e-4
78
+ step_size: 20
79
+ gamma: 0.95
80
+ save_every: 2000
81
+ start_epoch: 0
82
+ max_epochs: 300000
83
+ num_workers: 4
84
+
85
+ output_dir: /gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope
86
+ clip_model_path: "/gemini/user/private/zhaotianhao/clip-vit-large-patch14"
87
+ dinov2_model_path: "/gemini/user/private/zhaotianhao/dinov2-large"
88
+
89
+ vae_path: /gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to256_dir_sorted_dora_head_small/checkpoint_epoch13_batch6000_loss0.1381.pt
90
+ denoiser_checkpoint_path: false
91
+
92
+
93
+ dataset:
94
+ path: /gemini/user/private/zhaotianhao/data/trellis_clean_mesh
95
+ path: /gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/train
96
+
97
+ cache_dir: /gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to256_819200_head
98
+
99
+ renders_dir: None
100
+ filter_active_voxels: false
101
+ cache_filter_path: None
102
+
103
+ base_resolution: 1024
104
+ min_resolution: 128
105
+
106
+ n_train_samples: 1024
107
+ sample_type: dora
108
+
109
+ flow:
110
+ "resolution": 128
111
+ "in_channels": 16
112
+ "out_channels": 16
113
+ "model_channels": 768
114
+ "cond_channels": 1024
115
+ "num_blocks": 8
116
+ "num_heads": 8
117
+ "mlp_ratio": 4
118
+ "patch_size": 2
119
+ "num_io_res_blocks": 2
120
+ "io_block_channels": [128]
121
+ "pe_mode": "rope"
122
+ "qk_rms_norm": true
123
+ "qk_rms_norm_cross": false
124
+ "use_fp16": false
dataset_triposf.py ADDED
@@ -0,0 +1,924 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from typing import *
4
+ import trimesh
5
+ import numpy as np
6
+ from torch.utils.data import Dataset
7
+ import torch.nn.functional as F
8
+ # from utils import quantize_vertices
9
+ from utils import get_voxel_line
10
+ import random
11
+ import hashlib
12
+ import json
13
+ from tqdm import tqdm
14
+ from torchvision import transforms
15
+ from PIL import Image
16
+ import rembg
17
+ import open3d as o3d
18
+ from trimesh import grouping
19
+
20
+ def normalize_mesh(mesh_path):
21
+ scene = trimesh.load(mesh_path, process=False, force='scene')
22
+ meshes = []
23
+ for node_name in scene.graph.nodes_geometry:
24
+ geom_name = scene.graph[node_name][1]
25
+ geometry = scene.geometry[geom_name]
26
+ transform = scene.graph[node_name][0]
27
+ if isinstance(geometry, trimesh.Trimesh):
28
+ geometry.apply_transform(transform)
29
+ meshes.append(geometry)
30
+
31
+ mesh = trimesh.util.concatenate(meshes)
32
+
33
+ center = mesh.bounding_box.centroid
34
+ mesh.apply_translation(-center)
35
+ scale = max(mesh.bounding_box.extents)
36
+ mesh.apply_scale(2.0 / scale * 0.5)
37
+
38
+ return mesh
39
+
40
+ def quantize_vertices(vertices: torch.Tensor, res: int):
41
+ """
42
+ Quantize normalized vertices (range approx [-0.5, 0.5]) to integer grid [0, res-1].
43
+ """
44
+ normalized = vertices + 0.5
45
+ scaled = normalized * res
46
+ quantized = torch.floor(scaled).clamp(0, res - 1).int()
47
+ return quantized
48
+
49
+ def sample_edges_dora(tm_mesh, n_samples):
50
+ adj_faces = tm_mesh.face_adjacency
51
+ adj_edges = tm_mesh.face_adjacency_edges
52
+
53
+ internal_data = None
54
+ if len(adj_faces) > 0:
55
+ n0 = tm_mesh.face_normals[adj_faces[:, 0]]
56
+ n1 = tm_mesh.face_normals[adj_faces[:, 1]]
57
+ sum_normals = n0 + n1
58
+ norms = np.linalg.norm(sum_normals, axis=1, keepdims=True)
59
+ norms[norms < 1e-6] = 1.0
60
+ int_normals = sum_normals / norms
61
+
62
+ int_v_start = tm_mesh.vertices[adj_edges[:, 0]]
63
+ int_v_end = tm_mesh.vertices[adj_edges[:, 1]]
64
+
65
+ faces_pair = tm_mesh.faces[adj_faces]
66
+ sum_face_indices = np.sum(faces_pair, axis=2)
67
+ sum_edge_indices = np.sum(adj_edges, axis=1)
68
+
69
+ unique_idx_0 = sum_face_indices[:, 0] - sum_edge_indices
70
+ unique_idx_1 = sum_face_indices[:, 1] - sum_edge_indices
71
+
72
+ v_unique_0 = tm_mesh.vertices[unique_idx_0]
73
+ v_unique_1 = tm_mesh.vertices[unique_idx_1]
74
+ int_v_virtual = (v_unique_0 + v_unique_1) * 0.5
75
+
76
+ internal_data = (int_v_start, int_v_end, int_normals, int_v_virtual)
77
+
78
+ edges_sorted = tm_mesh.edges_sorted
79
+
80
+ if len(edges_sorted) == 0:
81
+ boundary_data = None
82
+ else:
83
+ boundary_group = grouping.group_rows(edges_sorted, require_count=1)
84
+
85
+ boundary_data = None
86
+ if len(boundary_group) > 0:
87
+ boundary_indices = np.concatenate([np.atleast_1d(g) for g in boundary_group])
88
+
89
+ boundary_face_indices = boundary_indices // 3
90
+
91
+ bnd_normals = tm_mesh.face_normals[boundary_face_indices]
92
+
93
+ bnd_edge_v_indices = edges_sorted[boundary_indices]
94
+ bnd_v_start = tm_mesh.vertices[bnd_edge_v_indices[:, 0]]
95
+ bnd_v_end = tm_mesh.vertices[bnd_edge_v_indices[:, 1]]
96
+
97
+ boundary_face_v_indices = tm_mesh.faces[boundary_face_indices]
98
+
99
+ sum_face = np.sum(boundary_face_v_indices, axis=1)
100
+ sum_edge = np.sum(bnd_edge_v_indices, axis=1)
101
+ unique_idx = sum_face - sum_edge
102
+
103
+ bnd_v_virtual = tm_mesh.vertices[unique_idx]
104
+
105
+ boundary_data = (bnd_v_start, bnd_v_end, bnd_normals, bnd_v_virtual)
106
+
107
+ if internal_data is None and boundary_data is None:
108
+ return None, None, None
109
+
110
+ parts_start, parts_end, parts_norm, parts_virt = [], [], [], []
111
+
112
+ if internal_data is not None:
113
+ parts_start.append(internal_data[0])
114
+ parts_end.append(internal_data[1])
115
+ parts_norm.append(internal_data[2])
116
+ parts_virt.append(internal_data[3])
117
+
118
+ if boundary_data is not None:
119
+ parts_start.append(boundary_data[0])
120
+ parts_end.append(boundary_data[1])
121
+ parts_norm.append(boundary_data[2])
122
+ parts_virt.append(boundary_data[3])
123
+
124
+ if not parts_start:
125
+ return None, None, None
126
+
127
+ all_v_start = np.concatenate(parts_start, axis=0)
128
+ all_v_end = np.concatenate(parts_end, axis=0)
129
+ all_normals = np.concatenate(parts_norm, axis=0)
130
+ all_v_virtual = np.concatenate(parts_virt, axis=0)
131
+
132
+ edge_vectors = all_v_end - all_v_start
133
+ edge_lengths = np.linalg.norm(edge_vectors, axis=1)
134
+ total_length = np.sum(edge_lengths)
135
+
136
+ if total_length < 1e-9:
137
+ probs = np.ones(len(edge_lengths)) / len(edge_lengths)
138
+ else:
139
+ probs = edge_lengths / total_length
140
+ probs = probs / probs.sum()
141
+
142
+ chosen_indices = np.random.choice(len(edge_lengths), size=n_samples, p=probs)
143
+
144
+ t = np.random.rand(n_samples, 1)
145
+
146
+ sel_v_start = all_v_start[chosen_indices]
147
+ sel_v_end = all_v_end[chosen_indices]
148
+ sel_normals = all_normals[chosen_indices]
149
+ sel_v_virtual = all_v_virtual[chosen_indices]
150
+
151
+ sampled_points = sel_v_start + (sel_v_end - sel_v_start) * t
152
+
153
+ vertex_triplets = np.stack([sel_v_start, sel_v_end, sel_v_virtual], axis=1).astype(np.float32)
154
+
155
+ return sampled_points.astype(np.float32), sel_normals.astype(np.float32), vertex_triplets
156
+
157
+ def load_quantized_mesh_dora(
158
+ mesh_path,
159
+ mesh_load=None,
160
+ volume_resolution=256,
161
+ use_normals=True,
162
+ pc_sample_number=4096000,
163
+ edge_sample_ratio=0.2
164
+ ):
165
+ cube_dilate = np.array([
166
+ [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, -1], [0, -1, 0], [0, 1, 1], [0, -1, 1], [0, 1, -1], [0, -1, -1],
167
+ [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 0, -1], [1, -1, 0], [1, 1, 1], [1, -1, 1], [1, 1, -1], [1, -1, -1],
168
+ [-1, 0, 0], [-1, 0, 1], [-1, 1, 0], [-1, 0, -1], [-1, -1, 0], [-1, 1, 1], [-1, -1, 1], [-1, 1, -1], [-1, -1, -1]
169
+ ]) / (volume_resolution * 4 - 1)
170
+
171
+ if mesh_load is None:
172
+ mesh_o3d = o3d.io.read_triangle_mesh(mesh_path)
173
+ vertices = np.clip(np.asarray(mesh_o3d.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
174
+ faces = np.asarray(mesh_o3d.triangles)
175
+ mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices)
176
+ else:
177
+ vertices = np.clip(np.asarray(mesh_load.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
178
+ faces = np.asarray(mesh_load.faces)
179
+ mesh_o3d = o3d.geometry.TriangleMesh()
180
+ mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices)
181
+ mesh_o3d.triangles = o3d.utility.Vector3iVector(faces)
182
+
183
+ tm_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
184
+
185
+ n_edge_samples = int(pc_sample_number * edge_sample_ratio)
186
+ p_edge, n_edge, triplets_edge = sample_edges_dora(tm_mesh, n_edge_samples)
187
+
188
+ if p_edge is None:
189
+ # print('p_edge is none!')
190
+ n_surface_samples = pc_sample_number
191
+ else:
192
+ # print('p_edge is right!')
193
+ n_surface_samples = pc_sample_number - n_edge_samples
194
+
195
+ p_surf, idx_surf = tm_mesh.sample(n_surface_samples, return_index=True)
196
+ p_surf = p_surf.astype(np.float32)
197
+ n_surf = tm_mesh.face_normals[idx_surf].astype(np.float32)
198
+
199
+ v_indices_surf = faces[idx_surf]
200
+ triplets_surf = vertices[v_indices_surf]
201
+
202
+ if p_edge is None:
203
+ final_points = p_surf
204
+ final_normals = n_surf
205
+ final_triplets = triplets_surf
206
+ else:
207
+ final_points = np.concatenate([p_surf, p_edge.astype(np.float32)], axis=0)
208
+ if use_normals:
209
+ final_normals = np.concatenate([n_surf, n_edge.astype(np.float32)], axis=0)
210
+
211
+ final_triplets = np.concatenate([triplets_surf, triplets_edge], axis=0)
212
+
213
+ voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
214
+ mesh_o3d,
215
+ voxel_size=1. / volume_resolution,
216
+ min_bound=[-0.5, -0.5, -0.5],
217
+ max_bound=[0.5, 0.5, 0.5]
218
+ )
219
+ voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
220
+
221
+ voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
222
+ o3d.geometry.PointCloud(
223
+ o3d.utility.Vector3dVector(
224
+ np.clip(
225
+ (final_points[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
226
+ -0.5 + 1e-6, 0.5 - 1e-6)
227
+ )
228
+ ),
229
+ voxel_size=1. / volume_resolution,
230
+ min_bound=[-0.5, -0.5, -0.5],
231
+ max_bound=[0.5, 0.5, 0.5]
232
+ )
233
+ voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
234
+ voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
235
+
236
+ features_list = [torch.from_numpy(final_points)]
237
+
238
+ if use_normals:
239
+ features_list.append(torch.from_numpy(final_normals))
240
+
241
+ view_dtype = np.dtype((np.void, final_triplets.dtype.itemsize * final_triplets.shape[-1]))
242
+ v_view = final_triplets.view(view_dtype).squeeze(-1)
243
+
244
+ sort_idx = np.argsort(v_view, axis=1)
245
+
246
+ batch_indices = np.arange(final_triplets.shape[0])[:, None]
247
+ v_sorted = final_triplets[batch_indices, sort_idx]
248
+
249
+ v1 = v_sorted[:, 0, :]
250
+ v2 = v_sorted[:, 1, :]
251
+ v3 = v_sorted[:, 2, :]
252
+
253
+ dir1 = v1 - final_points
254
+ dir2 = v2 - final_points
255
+ dir3 = v3 - final_points
256
+
257
+ features_list.append(torch.Tensor(dir1.astype(np.float32)))
258
+ features_list.append(torch.Tensor(dir2.astype(np.float32)))
259
+ features_list.append(torch.Tensor(dir3.astype(np.float32)))
260
+
261
+ points_sample = torch.cat(features_list, axis=-1)
262
+
263
+ return voxels, points_sample
264
+
265
+ def load_quantized_mesh_original(
266
+ mesh_path,
267
+ mesh_load=None,
268
+ volume_resolution=256,
269
+ use_normals=True,
270
+ pc_sample_number=4096000,
271
+ ):
272
+ cube_dilate = np.array(
273
+ [
274
+ [0, 0, 0],
275
+ [0, 0, 1],
276
+ [0, 1, 0],
277
+ [0, 0, -1],
278
+ [0, -1, 0],
279
+ [0, 1, 1],
280
+ [0, -1, 1],
281
+ [0, 1, -1],
282
+ [0, -1, -1],
283
+
284
+ [1, 0, 0],
285
+ [1, 0, 1],
286
+ [1, 1, 0],
287
+ [1, 0, -1],
288
+ [1, -1, 0],
289
+ [1, 1, 1],
290
+ [1, -1, 1],
291
+ [1, 1, -1],
292
+ [1, -1, -1],
293
+
294
+ [-1, 0, 0],
295
+ [-1, 0, 1],
296
+ [-1, 1, 0],
297
+ [-1, 0, -1],
298
+ [-1, -1, 0],
299
+ [-1, 1, 1],
300
+ [-1, -1, 1],
301
+ [-1, 1, -1],
302
+ [-1, -1, -1],
303
+ ]
304
+ ) / (volume_resolution * 4 - 1)
305
+
306
+ if mesh_load is None:
307
+ mesh = o3d.io.read_triangle_mesh(mesh_path)
308
+ vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
309
+ faces = np.asarray(mesh.triangles)
310
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
311
+ else:
312
+ vertices = np.clip(np.asarray(mesh_load.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
313
+ faces = np.asarray(mesh_load.faces)
314
+
315
+ mesh = o3d.geometry.TriangleMesh()
316
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
317
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
318
+
319
+
320
+ voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
321
+ mesh,
322
+ voxel_size=1. / volume_resolution,
323
+ min_bound=[-0.5, -0.5, -0.5],
324
+ max_bound=[0.5, 0.5, 0.5]
325
+ )
326
+ voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
327
+
328
+ points_normals_sample = trimesh.Trimesh(vertices=vertices, faces=faces).sample(count=pc_sample_number, return_index=True)
329
+ points_xyz_np = points_normals_sample[0].astype(np.float32)
330
+ points_sample = points_normals_sample[0].astype(np.float32)
331
+ face_indices = points_normals_sample[1]
332
+
333
+ voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
334
+ o3d.geometry.PointCloud(
335
+ o3d.utility.Vector3dVector(
336
+ np.clip(
337
+ (points_sample[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
338
+ -0.5 + 1e-6, 0.5 - 1e-6)
339
+ )
340
+ ),
341
+ voxel_size=1. / volume_resolution,
342
+ min_bound=[-0.5, -0.5, -0.5],
343
+ max_bound=[0.5, 0.5, 0.5]
344
+ )
345
+ voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
346
+ voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
347
+
348
+ features_list = [torch.from_numpy(points_xyz_np)]
349
+
350
+ if use_normals:
351
+ mesh.compute_triangle_normals()
352
+ normals_sample = np.asarray(
353
+ mesh.triangle_normals
354
+ )[points_normals_sample[1]].astype(np.float32)
355
+ # points_sample = torch.cat((torch.Tensor(points_sample), torch.Tensor(normals_sample)), axis=-1)
356
+ features_list.append(torch.from_numpy(normals_sample))
357
+
358
+
359
+ ########################################
360
+ # add direction to three vtx
361
+ ########################################
362
+
363
+ ## wo sort
364
+ # sampled_face_v_indices = faces[face_indices]
365
+ # v1 = vertices[sampled_face_v_indices[:, 0]]
366
+ # v2 = vertices[sampled_face_v_indices[:, 1]]
367
+ # v3 = vertices[sampled_face_v_indices[:, 2]]
368
+
369
+ # w sort
370
+ sampled_face_v_indices = faces[face_indices]
371
+ v_batch = np.stack([
372
+ vertices[sampled_face_v_indices[:, 0]],
373
+ vertices[sampled_face_v_indices[:, 1]],
374
+ vertices[sampled_face_v_indices[:, 2]]
375
+ ], axis=1)
376
+
377
+ view_dtype = np.dtype((np.void, v_batch.dtype.itemsize * v_batch.shape[-1]))
378
+ v_view = v_batch.view(view_dtype).squeeze(-1) # 变成 (N, 3) 的 void
379
+
380
+ sort_idx = np.argsort(v_view, axis=1) # (N, 3)
381
+
382
+ batch_indices = np.arange(v_batch.shape[0])[:, None]
383
+ v_sorted = v_batch[batch_indices, sort_idx] # (N, 3, 3)
384
+
385
+ v1 = v_sorted[:, 0, :]
386
+ v2 = v_sorted[:, 1, :]
387
+ v3 = v_sorted[:, 2, :]
388
+ # --------------------
389
+
390
+ dir1 = v1 - points_xyz_np
391
+ dir2 = v2 - points_xyz_np
392
+ dir3 = v3 - points_xyz_np
393
+
394
+ features_list.append(torch.Tensor(dir1.astype(np.float32)))
395
+ features_list.append(torch.Tensor(dir2.astype(np.float32)))
396
+ features_list.append(torch.Tensor(dir3.astype(np.float32)))
397
+
398
+ points_sample = torch.cat(features_list, axis=-1)
399
+ ########################################
400
+ # add direction to three vtx
401
+ ########################################
402
+
403
+ return voxels, points_sample
404
+
405
+ def get_sha256(filepath: str) -> str:
406
+ sha256_hash = hashlib.sha256()
407
+ with open(filepath, "rb") as f:
408
+ for byte_block in iter(lambda: f.read(4096), b""):
409
+ sha256_hash.update(byte_block)
410
+ return sha256_hash.hexdigest()
411
+
412
+ class VoxelVertexDataset_edge(Dataset):
413
+ def __init__(self,
414
+ root_dir: str,
415
+ base_resolution: int = 256,
416
+ min_resolution: int = 128,
417
+ img_res: int = 518,
418
+ cache_dir: str = "dataset_cache_test",
419
+ renders_dir: str = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond',
420
+ process_img: bool = False,
421
+ n_pre_samples: int = 1024,
422
+
423
+ active_voxel_res: int = 64,
424
+ pc_sample_number: int = 409600,
425
+
426
+ filter_active_voxels: bool = False, #####
427
+ min_active_voxels: int = 2000,
428
+ max_active_voxels: int = 40000,
429
+
430
+ cache_filter_path: str = "/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/data/filter_name/objaverse_200_2000_2000min_25000max.txt",
431
+
432
+ sample_type: str = 'uniform',
433
+ ):
434
+ self.root_dir = root_dir
435
+ self.cache_dir = cache_dir
436
+ self.img_res = img_res
437
+ self.renders_dir = renders_dir
438
+ self.process_img = process_img
439
+ self.filter_active_voxels=filter_active_voxels
440
+ self.min_active_voxels=min_active_voxels
441
+ self.max_active_voxels=max_active_voxels
442
+
443
+ self.active_voxel_res = active_voxel_res
444
+ self.pc_sample_number = pc_sample_number
445
+
446
+ self.sample_type = sample_type
447
+
448
+ # self.image_transform = transforms.ToTensor()
449
+ self.image_transform = transforms.Compose([
450
+ transforms.ToTensor(),
451
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
452
+ ])
453
+
454
+ os.makedirs(cache_dir, exist_ok=True)
455
+ assert (base_resolution & (base_resolution - 1)) == 0, "Resolution must be power of 2"
456
+ assert (min_resolution & (min_resolution - 1)) == 0, "Resolution must be power of 2"
457
+ self.res_levels = [
458
+ 2**i for i in range(
459
+ int(np.log2(min_resolution)),
460
+ int(np.log2(base_resolution)) + 1
461
+ )
462
+ ]
463
+
464
+ if self.active_voxel_res is not None and self.active_voxel_res not in self.res_levels:
465
+ self.res_levels.append(active_voxel_res)
466
+ self.res_levels.sort()
467
+
468
+ all_obj_files = sorted([f for f in os.listdir(root_dir) if f.endswith(('.obj', '.ply', '.glb'))])
469
+ if not all_obj_files:
470
+ raise ValueError(f"No OBJ files found in {root_dir}")
471
+
472
+ if self.process_img:
473
+ map_file_path = os.path.join(os.path.dirname(self.renders_dir), 'map.json')
474
+ if os.path.exists(map_file_path):
475
+ print(f"Loading pre-computed hash map from {map_file_path}")
476
+ with open(map_file_path, 'r') as f:
477
+ file_map = json.load(f)
478
+ filename_to_hash = {item['filename']: item['sha256'] for item in file_map}
479
+ all_obj_hashes = [filename_to_hash.get(fname) for fname in all_obj_files]
480
+ else:
481
+ print("No hash map found. Calculating SHA256 hashes on the fly... (This may take a moment)")
482
+ all_obj_hashes = []
483
+ for fname in tqdm(all_obj_files, desc="Hashing .obj files"):
484
+ fpath = os.path.join(self.root_dir, fname)
485
+ all_obj_hashes.append(get_sha256(fpath))
486
+
487
+ else:
488
+ print("process_img is False, skipping SHA256 hash calculation.")
489
+ all_obj_hashes = [None] * len(all_obj_files)
490
+
491
+
492
+ if self.filter_active_voxels and cache_filter_path:
493
+ filtered_list_cache_path = cache_filter_path
494
+
495
+ if os.path.exists(filtered_list_cache_path):
496
+ print(f"Loading filtered BASENAMES from: {filtered_list_cache_path}")
497
+ basename_to_fullname_map = {os.path.splitext(f)[0]: f for f in all_obj_files}
498
+
499
+ with open(filtered_list_cache_path, 'r') as f:
500
+ filtered_basenames = [line.strip() for line in f if line.strip()]
501
+
502
+ self.obj_files = []
503
+ for basename in filtered_basenames:
504
+ if basename in basename_to_fullname_map:
505
+ self.obj_files.append(basename_to_fullname_map[basename])
506
+ else:
507
+ print(f"[WARN] Basename '{basename}' from filter list not found in directory '{self.root_dir}'. Skipping.")
508
+
509
+ file_to_hash_map = dict(zip(all_obj_files, all_obj_hashes))
510
+ self.obj_hashes = [file_to_hash_map.get(fname) for fname in self.obj_files] # 使用 .get 更安全
511
+
512
+ print(f"Loaded and matched {len(self.obj_files)} samples from the filter list.")
513
+
514
+ else:
515
+ print(f"Cache filter file not found: {filtered_list_cache_path}. Proceeding with on-the-fly filtering...")
516
+
517
+ else:
518
+ self.obj_files = all_obj_files
519
+ self.obj_hashes = all_obj_hashes
520
+
521
+
522
+ if not self.obj_files:
523
+ raise ValueError(f"No OBJ files found in {root_dir}")
524
+
525
+ self.rembg_session = None
526
+
527
+ def _init_rembg_session_if_needed(self):
528
+ if self.rembg_session is None:
529
+ print(f"Initializing rembg session for worker {os.getpid()}...")
530
+ self.rembg_session = rembg.new_session(model_name='u2net')
531
+
532
+ def preprocess_image(self, input: Image.Image) -> Image.Image:
533
+ self._init_rembg_session_if_needed()
534
+ has_alpha = False
535
+ if input.mode == 'RGBA':
536
+ alpha = np.array(input)[:, :, 3]
537
+ if not np.all(alpha == 255):
538
+ has_alpha = True
539
+ if has_alpha:
540
+ output = input
541
+ else:
542
+ input = input.convert('RGB')
543
+ max_size = max(input.size)
544
+ scale = min(1, 1024 / max_size)
545
+ if scale < 1:
546
+ input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
547
+ if getattr(self, 'rembg_session', None) is None:
548
+ self.rembg_session = rembg.new_session('u2net')
549
+ output = rembg.remove(input, session=self.rembg_session)
550
+ output_np = np.array(output)
551
+ alpha = output_np[:, :, 3]
552
+ bbox = np.argwhere(alpha > 0.8 * 255)
553
+ bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
554
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
555
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
556
+ size = int(size * 1.2)
557
+ bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
558
+ output = output.crop(bbox) # type: ignore
559
+ output = output.resize((518, 518), Image.Resampling.LANCZOS)
560
+ output = np.array(output).astype(np.float32) / 255
561
+ output = output[:, :, :3] * output[:, :, 3:4]
562
+ output = Image.fromarray((output * 255).astype(np.uint8))
563
+ return output
564
+
565
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
566
+ name = os.path.splitext(self.obj_files[idx])[0]
567
+ cache_path = os.path.join(self.cache_dir, f"{name}_precombined.npz")
568
+
569
+ sha256_hash = self.obj_hashes[idx]
570
+ mesh_render_dir = os.path.join(self.renders_dir, sha256_hash) if sha256_hash else ""
571
+
572
+ image_path = ''
573
+ if mesh_render_dir and os.path.isdir(mesh_render_dir):
574
+ try:
575
+ render_files = [f for f in os.listdir(mesh_render_dir) if f.endswith('.png')]
576
+ if render_files:
577
+ image_path = os.path.join(mesh_render_dir, random.choice(render_files))
578
+ except OSError as e:
579
+ print(f"[WARN] Could not access render directory {mesh_render_dir}: {e}")
580
+
581
+ if self.process_img:
582
+ try:
583
+ if image_path and os.path.exists(image_path):
584
+ image_obj = self.image_transform(self.preprocess_image(Image.open(image_path)).convert('RGB'))
585
+ else:
586
+ image_obj = self.image_transform(Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)).convert('RGB'))
587
+ except Exception as e:
588
+ image_obj = self.image_transform(Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)).convert('RGB'))
589
+ print(f'Error processing image {image_path}: {e}')
590
+
591
+ if os.path.exists(cache_path):
592
+ try:
593
+ loaded = np.load(cache_path, allow_pickle=True)
594
+ data = {
595
+ 'original_faces': torch.from_numpy(loaded['original_faces']),
596
+ 'original_vertices': torch.from_numpy(loaded['original_vertices']),
597
+ }
598
+ for res in self.res_levels:
599
+ # Load standard voxel data
600
+ if f'combined_voxels_{res}' in loaded:
601
+ data[f'combined_voxels_{res}'] = torch.from_numpy(loaded[f'combined_voxels_{res}'])
602
+ data[f'combined_voxel_labels_{res}'] = torch.from_numpy(loaded[f'combined_voxel_labels_{res}'])
603
+ data[f'gt_combined_endpoints_{res}'] = torch.from_numpy(loaded[f'gt_combined_endpoints_{res}'])
604
+
605
+ data[f'gt_vertex_voxels_{res}'] = torch.from_numpy(loaded[f'gt_vertex_voxels_{res}'])
606
+ data[f'gt_edge_voxels_{res}'] = torch.from_numpy(loaded[f'gt_edge_voxels_{res}'])
607
+ data[f'gt_edge_endpoints_{res}'] = torch.from_numpy(loaded[f'gt_edge_endpoints_{res}'])
608
+ data[f'gt_edge_errors_{res}'] = torch.from_numpy(loaded[f'gt_edge_errors_{res}'])
609
+
610
+ # Load Active Voxels and Point Cloud for Local Pooling
611
+ if res == self.active_voxel_res:
612
+ if f'active_voxels_{res}' in loaded:
613
+ data[f'active_voxels_{res}'] = torch.from_numpy(loaded[f'active_voxels_{res}'])
614
+ if f'point_cloud_{res}' in loaded:
615
+ data[f'point_cloud_{res}'] = torch.from_numpy(loaded[f'point_cloud_{res}'])
616
+
617
+ if self.process_img:
618
+ data['image'] = image_obj
619
+ data['image_path'] = image_path
620
+ return data
621
+
622
+ except Exception as e:
623
+ print(f"[WARN] Corrupted NPZ cache {cache_path}, regenerating... {e}")
624
+ os.remove(cache_path)
625
+
626
+ try:
627
+ mesh_path = os.path.join(self.root_dir, self.obj_files[idx])
628
+ mesh = normalize_mesh(mesh_path)
629
+ if mesh.is_empty or not hasattr(mesh.vertices, 'shape') or mesh.vertices.shape[0] < 3 or not hasattr(mesh.faces, 'shape') or mesh.faces.shape[0] < 1:
630
+ raise ValueError("Invalid or empty mesh")
631
+ except Exception as e:
632
+ print(f"[ERROR] Failed to load mesh: {self.obj_files[idx]} | {e}")
633
+ return self.__getitem__((idx + 1) % len(self))
634
+
635
+ vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
636
+ faces = torch.tensor(mesh.faces, dtype=torch.long)
637
+
638
+ data = {'original_faces': faces.clone(), 'original_vertices': vertices.clone()}
639
+
640
+ for res in self.res_levels:
641
+ quantized = quantize_vertices(vertices, res)
642
+ tmesh = trimesh.Trimesh(vertices=quantized.numpy(), faces=faces.numpy())
643
+ tmesh.merge_vertices()
644
+
645
+ vertex_voxels_raw = torch.from_numpy(tmesh.vertices.astype(np.int32))
646
+ edges_raw = tmesh.edges_unique
647
+
648
+ vertex_labels_raw = torch.zeros(vertex_voxels_raw.shape[0], dtype=torch.long)
649
+
650
+ all_edge_voxels = []
651
+ edge_endpoints = []
652
+ edge_errors = []
653
+
654
+ for u_idx, v_idx in edges_raw:
655
+ p1_grid, p2_grid = vertex_voxels_raw[u_idx].float(), vertex_voxels_raw[v_idx].float()
656
+ v, ep, err = get_voxel_line(p1_grid, p2_grid)
657
+ all_edge_voxels.extend(v)
658
+ edge_endpoints.extend(ep)
659
+ edge_errors.extend(err)
660
+
661
+ if all_edge_voxels:
662
+ edge_voxels_np = np.array(all_edge_voxels, dtype=np.int32)
663
+ edge_endpoints_np = np.array([np.stack(pair) for pair in edge_endpoints], dtype=np.float32)
664
+ edge_errors_np = np.array(edge_errors, dtype=np.float32)
665
+
666
+ unique_edge_voxels_np, first_indices = np.unique(edge_voxels_np, axis=0, return_index=True)
667
+ edge_voxels_raw = torch.from_numpy(unique_edge_voxels_np)
668
+ edge_labels_raw = torch.ones(len(edge_voxels_raw), dtype=torch.long)
669
+ edge_endpoints_raw = torch.from_numpy(edge_endpoints_np[first_indices])
670
+ edge_errors_raw = torch.from_numpy(edge_errors_np[first_indices])
671
+ else:
672
+ edge_voxels_raw = torch.empty(0, 3, dtype=torch.int32)
673
+ edge_labels_raw = torch.empty(0, dtype=torch.long)
674
+ edge_endpoints_raw = torch.empty(0, 2, 3, dtype=torch.float32)
675
+ edge_errors_raw = torch.empty(0, 3, dtype=torch.float32)
676
+
677
+
678
+ if res == self.active_voxel_res:
679
+ try:
680
+ if self.sample_type == 'uniform':
681
+ # triposf-style, normilize wrong
682
+ ts_voxels, ts_points = load_quantized_mesh_original(
683
+ mesh_path=os.path.join(self.root_dir, self.obj_files[idx]),
684
+ mesh_load=mesh,
685
+ volume_resolution=res,
686
+ use_normals=True,
687
+ pc_sample_number=self.pc_sample_number,
688
+ )
689
+ else:
690
+ ts_voxels, ts_points = load_quantized_mesh_dora(
691
+ mesh_path=os.path.join(self.root_dir, self.obj_files[idx]),
692
+ mesh_load=mesh,
693
+ volume_resolution=res,
694
+ use_normals=True,
695
+ pc_sample_number=self.pc_sample_number,
696
+ edge_sample_ratio=0.5,
697
+ )
698
+
699
+ # Convert types
700
+ # Voxels from TripoSF are float Tensor (N, 3), convert to int32
701
+ data[f'active_voxels_{res}'] = ts_voxels.int()
702
+ data[f'point_cloud_{res}'] = ts_points
703
+
704
+
705
+ except Exception as e:
706
+ print(f"[ERROR] Failed to compute active voxels/points for {name} at res {res}: {e}")
707
+ data[f'active_voxels_{res}'] = torch.empty(0, 3, dtype=torch.int32)
708
+ data[f'point_cloud_{res}'] = torch.empty(0, 6, dtype=torch.float32)
709
+
710
+ combined_voxels = torch.cat([vertex_voxels_raw, edge_voxels_raw], dim=0)
711
+ combined_labels = torch.cat([vertex_labels_raw, edge_labels_raw], dim=0)
712
+
713
+ if combined_voxels.numel() > 0:
714
+ unique_voxels, inverse_indices = torch.unique(combined_voxels, dim=0, return_inverse=True)
715
+
716
+ zero_mask = (combined_labels == 0)
717
+ if zero_mask.any():
718
+ zero_per_unique = torch.zeros(len(unique_voxels), dtype=torch.bool)
719
+ zero_per_unique.scatter_(0, inverse_indices[zero_mask], True)
720
+ final_combined_labels = torch.where(zero_per_unique, 0, 1).long()
721
+ else:
722
+ final_combined_labels = torch.ones(len(unique_voxels), dtype=torch.long)
723
+
724
+ if edge_voxels_raw.numel() > 0:
725
+ edge_endpoint_map = {tuple(coord): ep for coord, ep in zip(edge_voxels_raw.numpy(), edge_endpoints_raw.numpy())}
726
+
727
+ endpoints_arr = np.empty((len(unique_voxels), 2, 3), dtype=np.float32)
728
+ unique_voxels_np = unique_voxels.numpy()
729
+
730
+ for j, coord in enumerate(unique_voxels_np):
731
+ coord_tuple = tuple(coord)
732
+ if coord_tuple in edge_endpoint_map:
733
+ endpoints_arr[j] = edge_endpoint_map[coord_tuple]
734
+ else:
735
+ endpoints_arr[j, 0, :] = coord
736
+ endpoints_arr[j, 1, :] = coord
737
+ final_combined_endpoints = torch.from_numpy(endpoints_arr)
738
+ else:
739
+ final_combined_endpoints = unique_voxels.float().unsqueeze(1).repeat(1, 2, 1)
740
+ else:
741
+ unique_voxels = torch.empty(0, 3, dtype=torch.int32)
742
+ final_combined_labels = torch.empty(0, dtype=torch.long)
743
+ final_combined_endpoints = torch.empty(0, 2, 3, dtype=torch.float32)
744
+
745
+ data[f'combined_voxels_{res}'] = unique_voxels
746
+ data[f'combined_voxel_labels_{res}'] = final_combined_labels
747
+ data[f'gt_combined_endpoints_{res}'] = final_combined_endpoints.reshape(-1, 6)
748
+
749
+ data[f'gt_vertex_voxels_{res}'] = vertex_voxels_raw
750
+ data[f'gt_edge_voxels_{res}'] = edge_voxels_raw
751
+ data[f'gt_edge_endpoints_{res}'] = edge_endpoints_raw.reshape(-1, 6)
752
+ data[f'gt_edge_errors_{res}'] = edge_errors_raw
753
+
754
+
755
+ save_dict = {
756
+ 'original_faces': data['original_faces'].numpy(),
757
+ 'original_vertices': data['original_vertices'].numpy(),
758
+ }
759
+ for res in self.res_levels:
760
+ for key_suffix in [
761
+ 'combined_voxels', 'combined_voxel_labels', 'gt_combined_endpoints',
762
+ 'gt_vertex_voxels', 'gt_edge_voxels', 'gt_edge_endpoints', 'gt_edge_errors',
763
+ ]:
764
+ full_key = f'{key_suffix}_{res}'
765
+ if full_key in data:
766
+ save_dict[full_key] = data[full_key].numpy()
767
+
768
+ if f'active_voxels_{res}' in data:
769
+ save_dict[f'active_voxels_{res}'] = data[f'active_voxels_{res}'].numpy()
770
+
771
+ if f'point_cloud_{res}' in data:
772
+ save_dict[f'point_cloud_{res}'] = data[f'point_cloud_{res}'].numpy()
773
+
774
+ try:
775
+ np.savez_compressed(cache_path, **save_dict)
776
+ except Exception as e:
777
+ print(f"[ERROR] Failed to save cache {cache_path}: {e}")
778
+ if os.path.exists(cache_path): os.remove(cache_path)
779
+
780
+ if self.process_img:
781
+ data['image'] = image_obj
782
+ data['image_path'] = image_path
783
+
784
+ return data
785
+
786
+ def __len__(self) -> int:
787
+ return len(self.obj_files)
788
+
789
+ def collate_fn_pointnet(
790
+ batch: List[Dict[str, torch.Tensor]],
791
+ ) -> Dict[str, torch.Tensor]:
792
+
793
+ if not batch:
794
+ return {}
795
+
796
+ batch = [b for b in batch if b is not None]
797
+ if not batch:
798
+ return {}
799
+
800
+ collated = {
801
+ 'original_faces': [b['original_faces'] for b in batch],
802
+ 'original_vertices': [b['original_vertices'] for b in batch],
803
+ 'image_path': [b['image_path'] for b in batch],
804
+ }
805
+
806
+ if 'image' in batch[0] and batch[0]['image'] is not None:
807
+ collated['image'] = torch.stack([b['image'] for b in batch])
808
+
809
+ res_levels = []
810
+ for k in batch[0].keys():
811
+ if k.startswith('gt_vertex_voxels_'):
812
+ try:
813
+ res_levels.append(int(k.split('_')[-1]))
814
+ except ValueError:
815
+ pass
816
+ res_levels.sort()
817
+
818
+ for res in res_levels:
819
+ all_active_voxels_list = []
820
+ all_point_clouds_list = []
821
+
822
+ all_combined_voxels_list = []
823
+ all_combined_labels_list = []
824
+ all_vertex_voxels_only = []
825
+ all_edge_voxels_only = []
826
+ all_edge_endpoints_only = []
827
+ all_combined_endpoints = []
828
+ all_combined_errors_list = []
829
+ layout = []
830
+
831
+ vtx_offset = 0
832
+ adj_flat_offset = 0
833
+ start_idx = 0
834
+
835
+ # Attempt to find device from first tensor
836
+ device = torch.device('cpu')
837
+ for v in batch[0].values():
838
+ if isinstance(v, torch.Tensor):
839
+ device = v.device
840
+ break
841
+
842
+ for i, sample in enumerate(batch):
843
+ vertex_voxels = sample.get(f'gt_vertex_voxels_{res}', torch.empty(0,3,dtype=torch.int32)).to(device)
844
+ vertex_labels = torch.zeros(vertex_voxels.shape[0], dtype=torch.long, device=device)
845
+ edge_voxels = sample.get(f'gt_edge_voxels_{res}', torch.empty(0,3,dtype=torch.int32)).to(device)
846
+ edge_labels = torch.ones(edge_voxels.shape[0], dtype=torch.long, device=device)
847
+ edge_endpoints= sample.get(f'gt_edge_endpoints_{res}', torch.empty(0,6,dtype=torch.float32)).to(device)
848
+ edge_errors = sample.get(f'gt_edge_errors_{res}', torch.empty(0,3,dtype=torch.float32)).to(device)
849
+
850
+ vertex_errors = sample.get(f'gt_vertex_errors_{res}', torch.zeros_like(vertex_voxels, dtype=torch.float32)).to(device)
851
+
852
+ if vertex_voxels.numel() > 0:
853
+ idx = torch.full((vertex_voxels.shape[0],1), i, dtype=torch.int32, device=device)
854
+ all_vertex_voxels_only.append(torch.cat([idx, vertex_voxels], dim=1))
855
+
856
+ if edge_voxels.numel() > 0:
857
+ idx = torch.full((edge_voxels.shape[0],1), i, dtype=torch.int32, device=device)
858
+ all_edge_voxels_only.append(torch.cat([idx, edge_voxels], dim=1))
859
+ all_edge_endpoints_only.append(
860
+ torch.cat([idx.to(torch.float32), edge_endpoints], dim=1))
861
+
862
+ if vertex_voxels.numel() + edge_voxels.numel() > 0:
863
+ combined_voxels = torch.cat([vertex_voxels, edge_voxels], dim=0)
864
+ combined_labels = torch.cat([vertex_labels, edge_labels], dim=0)
865
+
866
+ endpoints = torch.zeros(combined_voxels.size(0), 6, dtype=torch.float32, device=device)
867
+ if edge_voxels.numel() > 0:
868
+ endpoints[-edge_voxels.size(0):] = edge_endpoints
869
+ if vertex_voxels.numel() > 0:
870
+ endpoints[:vertex_voxels.size(0)] = vertex_voxels.repeat(1,2).float()
871
+
872
+ combined_errors = torch.cat([vertex_errors, edge_errors], dim=0)
873
+
874
+ batch_idx_int = torch.full((combined_voxels.shape[0],1), i, dtype=torch.int32, device=device)
875
+ all_combined_voxels_list.append(torch.cat([batch_idx_int, combined_voxels], dim=1))
876
+ all_combined_labels_list.append(combined_labels)
877
+
878
+ batch_idx_float = batch_idx_int.to(torch.float32)
879
+ all_combined_endpoints.append(torch.cat([batch_idx_float, endpoints], dim=1))
880
+ all_combined_errors_list.append(torch.cat([batch_idx_float, combined_errors], dim=1))
881
+
882
+ layout.append(slice(start_idx, start_idx + combined_voxels.shape[0]))
883
+ start_idx += combined_voxels.shape[0]
884
+ else:
885
+ layout.append(slice(start_idx, start_idx))
886
+
887
+ # Active Voxels (Sparse Coords)
888
+ active_voxels = sample.get(f'active_voxels_{res}', torch.empty(0, 3, dtype=torch.int32)).to(device)
889
+ if active_voxels.numel() > 0:
890
+ idx = torch.full((active_voxels.shape[0], 1), i, dtype=torch.int32, device=device)
891
+ all_active_voxels_list.append(torch.cat([idx, active_voxels], dim=1))
892
+
893
+ # ==========================================
894
+ # Modified Section: Collect Point Clouds
895
+ # ==========================================
896
+ # pc = sample.get(f'point_cloud_{res}', torch.empty(0, 6, dtype=torch.float32)).to(device)
897
+ pc = sample.get(f'point_cloud_{res}', torch.empty(0, 15, dtype=torch.float32)).to(device)
898
+ # We expect all samples to have point clouds if res == active_voxel_res
899
+ if pc.numel() > 0:
900
+ all_point_clouds_list.append(pc)
901
+
902
+ collated[f'layout_{res}'] = layout
903
+
904
+ def cat_or_empty(lst, shape, dtype):
905
+ return torch.cat(lst, dim=0) if lst else torch.empty(shape, dtype=dtype, device=device)
906
+
907
+ collated[f'combined_voxels_{res}'] = cat_or_empty(all_combined_voxels_list,(0,4),torch.int32)
908
+ collated[f'combined_voxel_labels_{res}'] = cat_or_empty(all_combined_labels_list,(0,),torch.long)
909
+ collated[f'gt_vertex_voxels_{res}'] = cat_or_empty(all_vertex_voxels_only,(0,4),torch.int32)
910
+ collated[f'gt_edge_voxels_{res}'] = cat_or_empty(all_edge_voxels_only,(0,4),torch.int32)
911
+ collated[f'gt_edge_endpoints_{res}'] = cat_or_empty(all_edge_endpoints_only,(0,7),torch.float32)
912
+ collated[f'gt_combined_endpoints_{res}'] = cat_or_empty(all_combined_endpoints,(0,7),torch.float32)
913
+ collated[f'gt_combined_errors_{res}'] = cat_or_empty(all_combined_errors_list,(0,4),torch.float32)
914
+
915
+ collated[f'active_voxels_{res}'] = cat_or_empty(all_active_voxels_list, (0, 4), torch.int32)
916
+
917
+
918
+ if all_point_clouds_list:
919
+ collated[f'point_cloud_{res}'] = torch.stack(all_point_clouds_list, dim=0)
920
+ else:
921
+ # collated[f'point_cloud_{res}'] = torch.empty((0, 6), dtype=torch.float32, device=device)
922
+ collated[f'point_cloud_{res}'] = torch.empty((0, 15), dtype=torch.float32, device=device)
923
+
924
+ return collated
dataset_triposf_head.py ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from typing import *
4
+ import trimesh
5
+ import numpy as np
6
+ from torch.utils.data import Dataset
7
+ import torch.nn.functional as F
8
+ # from utils import quantize_vertices
9
+ from utils import get_voxel_line
10
+ import random
11
+ import hashlib
12
+ import json
13
+ from tqdm import tqdm
14
+ from torchvision import transforms
15
+ from PIL import Image
16
+ import rembg
17
+ import open3d as o3d
18
+ from trimesh import grouping
19
+
20
+ def normalize_mesh(mesh_path):
21
+ scene = trimesh.load(mesh_path, process=False, force='scene')
22
+ meshes = []
23
+ for node_name in scene.graph.nodes_geometry:
24
+ geom_name = scene.graph[node_name][1]
25
+ geometry = scene.geometry[geom_name]
26
+ transform = scene.graph[node_name][0]
27
+ if isinstance(geometry, trimesh.Trimesh):
28
+ geometry.apply_transform(transform)
29
+ meshes.append(geometry)
30
+
31
+ mesh = trimesh.util.concatenate(meshes)
32
+
33
+ center = mesh.bounding_box.centroid
34
+ mesh.apply_translation(-center)
35
+ scale = max(mesh.bounding_box.extents)
36
+ mesh.apply_scale(2.0 / scale * 0.5)
37
+
38
+ return mesh
39
+
40
+ def quantize_vertices(vertices: torch.Tensor, res: int):
41
+ """
42
+ Quantize normalized vertices (range approx [-0.5, 0.5]) to integer grid [0, res-1].
43
+ """
44
+ normalized = vertices + 0.5
45
+ scaled = normalized * res
46
+ quantized = torch.floor(scaled).clamp(0, res - 1).int()
47
+ return quantized
48
+
49
+ def sample_edges_dora(tm_mesh, n_samples):
50
+ adj_faces = tm_mesh.face_adjacency
51
+ adj_edges = tm_mesh.face_adjacency_edges
52
+
53
+ internal_data = None
54
+ if len(adj_faces) > 0:
55
+ n0 = tm_mesh.face_normals[adj_faces[:, 0]]
56
+ n1 = tm_mesh.face_normals[adj_faces[:, 1]]
57
+ sum_normals = n0 + n1
58
+ norms = np.linalg.norm(sum_normals, axis=1, keepdims=True)
59
+ norms[norms < 1e-6] = 1.0
60
+ int_normals = sum_normals / norms
61
+
62
+ int_v_start = tm_mesh.vertices[adj_edges[:, 0]]
63
+ int_v_end = tm_mesh.vertices[adj_edges[:, 1]]
64
+
65
+ faces_pair = tm_mesh.faces[adj_faces]
66
+ sum_face_indices = np.sum(faces_pair, axis=2)
67
+ sum_edge_indices = np.sum(adj_edges, axis=1)
68
+
69
+ unique_idx_0 = sum_face_indices[:, 0] - sum_edge_indices
70
+ unique_idx_1 = sum_face_indices[:, 1] - sum_edge_indices
71
+
72
+ v_unique_0 = tm_mesh.vertices[unique_idx_0]
73
+ v_unique_1 = tm_mesh.vertices[unique_idx_1]
74
+ int_v_virtual = (v_unique_0 + v_unique_1) * 0.5
75
+
76
+ internal_data = (int_v_start, int_v_end, int_normals, int_v_virtual)
77
+
78
+ edges_sorted = tm_mesh.edges_sorted
79
+
80
+ if len(edges_sorted) == 0:
81
+ boundary_data = None
82
+ else:
83
+ boundary_group = grouping.group_rows(edges_sorted, require_count=1)
84
+
85
+ boundary_data = None
86
+ if len(boundary_group) > 0:
87
+ boundary_indices = np.concatenate([np.atleast_1d(g) for g in boundary_group])
88
+
89
+ boundary_face_indices = boundary_indices // 3
90
+
91
+ bnd_normals = tm_mesh.face_normals[boundary_face_indices]
92
+
93
+ bnd_edge_v_indices = edges_sorted[boundary_indices]
94
+ bnd_v_start = tm_mesh.vertices[bnd_edge_v_indices[:, 0]]
95
+ bnd_v_end = tm_mesh.vertices[bnd_edge_v_indices[:, 1]]
96
+
97
+ boundary_face_v_indices = tm_mesh.faces[boundary_face_indices]
98
+
99
+ sum_face = np.sum(boundary_face_v_indices, axis=1)
100
+ sum_edge = np.sum(bnd_edge_v_indices, axis=1)
101
+ unique_idx = sum_face - sum_edge
102
+
103
+ bnd_v_virtual = tm_mesh.vertices[unique_idx]
104
+
105
+ boundary_data = (bnd_v_start, bnd_v_end, bnd_normals, bnd_v_virtual)
106
+
107
+ if internal_data is None and boundary_data is None:
108
+ return None, None, None
109
+
110
+ parts_start, parts_end, parts_norm, parts_virt = [], [], [], []
111
+
112
+ if internal_data is not None:
113
+ parts_start.append(internal_data[0])
114
+ parts_end.append(internal_data[1])
115
+ parts_norm.append(internal_data[2])
116
+ parts_virt.append(internal_data[3])
117
+
118
+ if boundary_data is not None:
119
+ parts_start.append(boundary_data[0])
120
+ parts_end.append(boundary_data[1])
121
+ parts_norm.append(boundary_data[2])
122
+ parts_virt.append(boundary_data[3])
123
+
124
+ if not parts_start:
125
+ return None, None, None
126
+
127
+ all_v_start = np.concatenate(parts_start, axis=0)
128
+ all_v_end = np.concatenate(parts_end, axis=0)
129
+ all_normals = np.concatenate(parts_norm, axis=0)
130
+ all_v_virtual = np.concatenate(parts_virt, axis=0)
131
+
132
+ edge_vectors = all_v_end - all_v_start
133
+ edge_lengths = np.linalg.norm(edge_vectors, axis=1)
134
+ total_length = np.sum(edge_lengths)
135
+
136
+ if total_length < 1e-9:
137
+ probs = np.ones(len(edge_lengths)) / len(edge_lengths)
138
+ else:
139
+ probs = edge_lengths / total_length
140
+ probs = probs / probs.sum()
141
+
142
+ chosen_indices = np.random.choice(len(edge_lengths), size=n_samples, p=probs)
143
+
144
+ t = np.random.rand(n_samples, 1)
145
+
146
+ sel_v_start = all_v_start[chosen_indices]
147
+ sel_v_end = all_v_end[chosen_indices]
148
+ sel_normals = all_normals[chosen_indices]
149
+ sel_v_virtual = all_v_virtual[chosen_indices]
150
+
151
+ sampled_points = sel_v_start + (sel_v_end - sel_v_start) * t
152
+
153
+ vertex_triplets = np.stack([sel_v_start, sel_v_end, sel_v_virtual], axis=1).astype(np.float32)
154
+
155
+ return sampled_points.astype(np.float32), sel_normals.astype(np.float32), vertex_triplets
156
+
157
+ def sample_edges_hunyuan(tm_mesh, n_samples):
158
+ if tm_mesh.vertex_normals is None:
159
+ tm_mesh.compute_vertex_normals()
160
+
161
+ V = tm_mesh.vertices
162
+ F = tm_mesh.faces
163
+ VN = tm_mesh.vertex_normals
164
+
165
+ # Edge 0: v0 -> v1 (Virtual: v2)
166
+ # Edge 1: v1 -> v2 (Virtual: v0)
167
+ # Edge 2: v2 -> v0 (Virtual: v1)
168
+
169
+ idx_start = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
170
+ idx_end = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
171
+ idx_virt = np.concatenate([F[:, 2], F[:, 0], F[:, 1]])
172
+
173
+ v_start = V[idx_start]
174
+ v_end = V[idx_end]
175
+
176
+ edge_vectors = v_end - v_start
177
+ edge_lengths = np.linalg.norm(edge_vectors, axis=1)
178
+
179
+ total_length = np.sum(edge_lengths)
180
+
181
+ if total_length < 1e-9:
182
+ probs = np.ones(len(edge_lengths)) / len(edge_lengths)
183
+ else:
184
+ probs = edge_lengths / total_length
185
+
186
+ chosen_indices = np.random.choice(len(edge_lengths), size=n_samples, p=probs)
187
+
188
+ sel_v_start = v_start[chosen_indices]
189
+ sel_v_end = v_end[chosen_indices]
190
+ sel_v_virt = V[idx_virt[chosen_indices]]
191
+
192
+ sel_vn_start = VN[idx_start[chosen_indices]]
193
+ sel_vn_end = VN[idx_end[chosen_indices]]
194
+
195
+ t = np.random.rand(n_samples, 1)
196
+
197
+ sampled_points = sel_v_start + (sel_v_end - sel_v_start) * t
198
+
199
+ sampled_normals = sel_vn_start * (1 - t) + sel_vn_end * t
200
+
201
+ norm_vals = np.linalg.norm(sampled_normals, axis=1, keepdims=True)
202
+ norm_vals[norm_vals < 1e-6] = 1.0
203
+ sampled_normals = sampled_normals / norm_vals
204
+
205
+ vertex_triplets = np.stack([sel_v_start, sel_v_end, sel_v_virt], axis=1)
206
+
207
+ return sampled_points.astype(np.float32), sampled_normals.astype(np.float32), vertex_triplets.astype(np.float32)
208
+
209
+ def load_quantized_mesh_dora(
210
+ mesh_path,
211
+ mesh_load=None,
212
+ volume_resolution=256,
213
+ use_normals=True,
214
+ pc_sample_number=4096000,
215
+ edge_sample_ratio=0.2
216
+ ):
217
+ cube_dilate = np.array([
218
+ [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, -1], [0, -1, 0], [0, 1, 1], [0, -1, 1], [0, 1, -1], [0, -1, -1],
219
+ [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 0, -1], [1, -1, 0], [1, 1, 1], [1, -1, 1], [1, 1, -1], [1, -1, -1],
220
+ [-1, 0, 0], [-1, 0, 1], [-1, 1, 0], [-1, 0, -1], [-1, -1, 0], [-1, 1, 1], [-1, -1, 1], [-1, 1, -1], [-1, -1, -1]
221
+ ]) / (volume_resolution * 4 - 1)
222
+
223
+ if mesh_load is None:
224
+ mesh_o3d = o3d.io.read_triangle_mesh(mesh_path)
225
+ vertices = np.clip(np.asarray(mesh_o3d.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
226
+ faces = np.asarray(mesh_o3d.triangles)
227
+ mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices)
228
+ else:
229
+ vertices = np.clip(np.asarray(mesh_load.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
230
+ faces = np.asarray(mesh_load.faces)
231
+ mesh_o3d = o3d.geometry.TriangleMesh()
232
+ mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices)
233
+ mesh_o3d.triangles = o3d.utility.Vector3iVector(faces)
234
+
235
+ tm_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
236
+
237
+ n_edge_samples = int(pc_sample_number * edge_sample_ratio)
238
+ p_edge, n_edge, triplets_edge = sample_edges_dora(tm_mesh, n_edge_samples)
239
+
240
+ if p_edge is None:
241
+ # print('p_edge is none!')
242
+ n_surface_samples = pc_sample_number
243
+ else:
244
+ # print('p_edge is right!')
245
+ n_surface_samples = pc_sample_number - n_edge_samples
246
+
247
+ p_surf, idx_surf = tm_mesh.sample(n_surface_samples, return_index=True)
248
+ p_surf = p_surf.astype(np.float32)
249
+ n_surf = tm_mesh.face_normals[idx_surf].astype(np.float32)
250
+
251
+ v_indices_surf = faces[idx_surf]
252
+ triplets_surf = vertices[v_indices_surf]
253
+
254
+ if p_edge is None:
255
+ final_points = p_surf
256
+ final_normals = n_surf
257
+ final_triplets = triplets_surf
258
+ else:
259
+ final_points = np.concatenate([p_surf, p_edge.astype(np.float32)], axis=0)
260
+ if use_normals:
261
+ final_normals = np.concatenate([n_surf, n_edge.astype(np.float32)], axis=0)
262
+
263
+ final_triplets = np.concatenate([triplets_surf, triplets_edge], axis=0)
264
+
265
+ voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
266
+ mesh_o3d,
267
+ voxel_size=1. / volume_resolution,
268
+ min_bound=[-0.5, -0.5, -0.5],
269
+ max_bound=[0.5, 0.5, 0.5]
270
+ )
271
+ voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
272
+
273
+ voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
274
+ o3d.geometry.PointCloud(
275
+ o3d.utility.Vector3dVector(
276
+ np.clip(
277
+ (final_points[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
278
+ -0.5 + 1e-6, 0.5 - 1e-6)
279
+ )
280
+ ),
281
+ voxel_size=1. / volume_resolution,
282
+ min_bound=[-0.5, -0.5, -0.5],
283
+ max_bound=[0.5, 0.5, 0.5]
284
+ )
285
+ voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
286
+ voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
287
+
288
+ features_list = [torch.from_numpy(final_points)]
289
+
290
+ if use_normals:
291
+ features_list.append(torch.from_numpy(final_normals))
292
+
293
+ view_dtype = np.dtype((np.void, final_triplets.dtype.itemsize * final_triplets.shape[-1]))
294
+ v_view = final_triplets.view(view_dtype).squeeze(-1)
295
+
296
+ sort_idx = np.argsort(v_view, axis=1)
297
+
298
+ batch_indices = np.arange(final_triplets.shape[0])[:, None]
299
+ v_sorted = final_triplets[batch_indices, sort_idx]
300
+
301
+ v1 = v_sorted[:, 0, :]
302
+ v2 = v_sorted[:, 1, :]
303
+ v3 = v_sorted[:, 2, :]
304
+
305
+ dir1 = v1 - final_points
306
+ dir2 = v2 - final_points
307
+ dir3 = v3 - final_points
308
+
309
+ features_list.append(torch.Tensor(dir1.astype(np.float32)))
310
+ features_list.append(torch.Tensor(dir2.astype(np.float32)))
311
+ features_list.append(torch.Tensor(dir3.astype(np.float32)))
312
+
313
+ points_sample = torch.cat(features_list, axis=-1)
314
+
315
+ return voxels, points_sample
316
+
317
+ def load_quantized_mesh_original(
318
+ mesh_path,
319
+ mesh_load=None,
320
+ volume_resolution=256,
321
+ use_normals=True,
322
+ pc_sample_number=4096000,
323
+ ):
324
+ cube_dilate = np.array(
325
+ [
326
+ [0, 0, 0],
327
+ [0, 0, 1],
328
+ [0, 1, 0],
329
+ [0, 0, -1],
330
+ [0, -1, 0],
331
+ [0, 1, 1],
332
+ [0, -1, 1],
333
+ [0, 1, -1],
334
+ [0, -1, -1],
335
+
336
+ [1, 0, 0],
337
+ [1, 0, 1],
338
+ [1, 1, 0],
339
+ [1, 0, -1],
340
+ [1, -1, 0],
341
+ [1, 1, 1],
342
+ [1, -1, 1],
343
+ [1, 1, -1],
344
+ [1, -1, -1],
345
+
346
+ [-1, 0, 0],
347
+ [-1, 0, 1],
348
+ [-1, 1, 0],
349
+ [-1, 0, -1],
350
+ [-1, -1, 0],
351
+ [-1, 1, 1],
352
+ [-1, -1, 1],
353
+ [-1, 1, -1],
354
+ [-1, -1, -1],
355
+ ]
356
+ ) / (volume_resolution * 4 - 1)
357
+
358
+ if mesh_load is None:
359
+ mesh = o3d.io.read_triangle_mesh(mesh_path)
360
+ vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
361
+ faces = np.asarray(mesh.triangles)
362
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
363
+ else:
364
+ vertices = np.clip(np.asarray(mesh_load.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
365
+ faces = np.asarray(mesh_load.faces)
366
+
367
+ mesh = o3d.geometry.TriangleMesh()
368
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
369
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
370
+
371
+
372
+ voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
373
+ mesh,
374
+ voxel_size=1. / volume_resolution,
375
+ min_bound=[-0.5, -0.5, -0.5],
376
+ max_bound=[0.5, 0.5, 0.5]
377
+ )
378
+ voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
379
+
380
+ points_normals_sample = trimesh.Trimesh(vertices=vertices, faces=faces).sample(count=pc_sample_number, return_index=True)
381
+ points_xyz_np = points_normals_sample[0].astype(np.float32)
382
+ points_sample = points_normals_sample[0].astype(np.float32)
383
+ face_indices = points_normals_sample[1]
384
+
385
+ voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
386
+ o3d.geometry.PointCloud(
387
+ o3d.utility.Vector3dVector(
388
+ np.clip(
389
+ (points_sample[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
390
+ -0.5 + 1e-6, 0.5 - 1e-6)
391
+ )
392
+ ),
393
+ voxel_size=1. / volume_resolution,
394
+ min_bound=[-0.5, -0.5, -0.5],
395
+ max_bound=[0.5, 0.5, 0.5]
396
+ )
397
+ voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
398
+ voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
399
+
400
+ features_list = [torch.from_numpy(points_xyz_np)]
401
+
402
+ if use_normals:
403
+ mesh.compute_triangle_normals()
404
+ normals_sample = np.asarray(
405
+ mesh.triangle_normals
406
+ )[points_normals_sample[1]].astype(np.float32)
407
+ # points_sample = torch.cat((torch.Tensor(points_sample), torch.Tensor(normals_sample)), axis=-1)
408
+ features_list.append(torch.from_numpy(normals_sample))
409
+
410
+
411
+ ########################################
412
+ # add direction to three vtx
413
+ ########################################
414
+
415
+ ## wo sort
416
+ # sampled_face_v_indices = faces[face_indices]
417
+ # v1 = vertices[sampled_face_v_indices[:, 0]]
418
+ # v2 = vertices[sampled_face_v_indices[:, 1]]
419
+ # v3 = vertices[sampled_face_v_indices[:, 2]]
420
+
421
+ # w sort
422
+ sampled_face_v_indices = faces[face_indices]
423
+ v_batch = np.stack([
424
+ vertices[sampled_face_v_indices[:, 0]],
425
+ vertices[sampled_face_v_indices[:, 1]],
426
+ vertices[sampled_face_v_indices[:, 2]]
427
+ ], axis=1)
428
+
429
+ view_dtype = np.dtype((np.void, v_batch.dtype.itemsize * v_batch.shape[-1]))
430
+ v_view = v_batch.view(view_dtype).squeeze(-1) # 变成 (N, 3) 的 void
431
+
432
+ sort_idx = np.argsort(v_view, axis=1) # (N, 3)
433
+
434
+ batch_indices = np.arange(v_batch.shape[0])[:, None]
435
+ v_sorted = v_batch[batch_indices, sort_idx] # (N, 3, 3)
436
+
437
+ v1 = v_sorted[:, 0, :]
438
+ v2 = v_sorted[:, 1, :]
439
+ v3 = v_sorted[:, 2, :]
440
+ # --------------------
441
+
442
+ dir1 = v1 - points_xyz_np
443
+ dir2 = v2 - points_xyz_np
444
+ dir3 = v3 - points_xyz_np
445
+
446
+ features_list.append(torch.Tensor(dir1.astype(np.float32)))
447
+ features_list.append(torch.Tensor(dir2.astype(np.float32)))
448
+ features_list.append(torch.Tensor(dir3.astype(np.float32)))
449
+
450
+ points_sample = torch.cat(features_list, axis=-1)
451
+ ########################################
452
+ # add direction to three vtx
453
+ ########################################
454
+
455
+ return voxels, points_sample
456
+
457
+ def get_sha256(filepath: str) -> str:
458
+ sha256_hash = hashlib.sha256()
459
+ with open(filepath, "rb") as f:
460
+ for byte_block in iter(lambda: f.read(4096), b""):
461
+ sha256_hash.update(byte_block)
462
+ return sha256_hash.hexdigest()
463
+
464
+ class VoxelVertexDataset_edge(Dataset):
465
+ def __init__(self,
466
+ root_dir: str,
467
+ base_resolution: int = 256,
468
+ min_resolution: int = 128,
469
+ img_res: int = 518,
470
+ cache_dir: str = "dataset_cache_test",
471
+ renders_dir: str = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond',
472
+ process_img: bool = False,
473
+ n_pre_samples: int = 1024,
474
+
475
+ active_voxel_res: int = 64,
476
+ pc_sample_number: int = 409600,
477
+
478
+ filter_active_voxels: bool = False, #####
479
+ min_active_voxels: int = 2000,
480
+ max_active_voxels: int = 40000,
481
+
482
+ cache_filter_path: str = "/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/data/filter_name/objaverse_200_2000_2000min_25000max.txt",
483
+
484
+ sample_type: str = 'uniform',
485
+ ):
486
+ self.root_dir = root_dir
487
+ self.cache_dir = cache_dir
488
+ self.img_res = img_res
489
+ self.renders_dir = renders_dir
490
+ self.process_img = process_img
491
+ self.filter_active_voxels=filter_active_voxels
492
+ self.min_active_voxels=min_active_voxels
493
+ self.max_active_voxels=max_active_voxels
494
+
495
+ self.active_voxel_res = active_voxel_res
496
+ self.pc_sample_number = pc_sample_number
497
+
498
+ self.sample_type = sample_type
499
+
500
+ # self.image_transform = transforms.ToTensor()
501
+ self.image_transform = transforms.Compose([
502
+ transforms.ToTensor(),
503
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
504
+ ])
505
+
506
+ os.makedirs(cache_dir, exist_ok=True)
507
+ assert (base_resolution & (base_resolution - 1)) == 0, "Resolution must be power of 2"
508
+ assert (min_resolution & (min_resolution - 1)) == 0, "Resolution must be power of 2"
509
+ self.res_levels = [
510
+ 2**i for i in range(
511
+ int(np.log2(min_resolution)),
512
+ int(np.log2(base_resolution)) + 1
513
+ )
514
+ ]
515
+
516
+ if self.active_voxel_res is not None and self.active_voxel_res not in self.res_levels:
517
+ self.res_levels.append(active_voxel_res)
518
+ self.res_levels.sort()
519
+
520
+ all_obj_files = sorted([f for f in os.listdir(root_dir) if f.endswith(('.obj', '.ply', '.glb'))])
521
+ if not all_obj_files:
522
+ raise ValueError(f"No OBJ files found in {root_dir}")
523
+
524
+ if self.process_img:
525
+ map_file_path = os.path.join(os.path.dirname(self.renders_dir), 'map.json')
526
+ if os.path.exists(map_file_path):
527
+ print(f"Loading pre-computed hash map from {map_file_path}")
528
+ with open(map_file_path, 'r') as f:
529
+ file_map = json.load(f)
530
+ filename_to_hash = {item['filename']: item['sha256'] for item in file_map}
531
+ all_obj_hashes = [filename_to_hash.get(fname) for fname in all_obj_files]
532
+ else:
533
+ print("No hash map found. Calculating SHA256 hashes on the fly... (This may take a moment)")
534
+ all_obj_hashes = []
535
+ for fname in tqdm(all_obj_files, desc="Hashing .obj files"):
536
+ fpath = os.path.join(self.root_dir, fname)
537
+ all_obj_hashes.append(get_sha256(fpath))
538
+
539
+ else:
540
+ print("process_img is False, skipping SHA256 hash calculation.")
541
+ all_obj_hashes = [None] * len(all_obj_files)
542
+
543
+
544
+ if self.filter_active_voxels and cache_filter_path:
545
+ filtered_list_cache_path = cache_filter_path
546
+
547
+ if os.path.exists(filtered_list_cache_path):
548
+ print(f"Loading filtered BASENAMES from: {filtered_list_cache_path}")
549
+ basename_to_fullname_map = {os.path.splitext(f)[0]: f for f in all_obj_files}
550
+
551
+ with open(filtered_list_cache_path, 'r') as f:
552
+ filtered_basenames = [line.strip() for line in f if line.strip()]
553
+
554
+ self.obj_files = []
555
+ for basename in filtered_basenames:
556
+ if basename in basename_to_fullname_map:
557
+ self.obj_files.append(basename_to_fullname_map[basename])
558
+ else:
559
+ print(f"[WARN] Basename '{basename}' from filter list not found in directory '{self.root_dir}'. Skipping.")
560
+
561
+ file_to_hash_map = dict(zip(all_obj_files, all_obj_hashes))
562
+ self.obj_hashes = [file_to_hash_map.get(fname) for fname in self.obj_files] # 使用 .get 更安全
563
+
564
+ print(f"Loaded and matched {len(self.obj_files)} samples from the filter list.")
565
+
566
+ else:
567
+ print(f"Cache filter file not found: {filtered_list_cache_path}. Proceeding with on-the-fly filtering...")
568
+
569
+ else:
570
+ self.obj_files = all_obj_files
571
+ self.obj_hashes = all_obj_hashes
572
+
573
+
574
+ if not self.obj_files:
575
+ raise ValueError(f"No OBJ files found in {root_dir}")
576
+
577
+ self.rembg_session = None
578
+
579
+ def _init_rembg_session_if_needed(self):
580
+ if self.rembg_session is None:
581
+ print(f"Initializing rembg session for worker {os.getpid()}...")
582
+ self.rembg_session = rembg.new_session(model_name='u2net')
583
+
584
+ def preprocess_image(self, input: Image.Image) -> Image.Image:
585
+ self._init_rembg_session_if_needed()
586
+ has_alpha = False
587
+ if input.mode == 'RGBA':
588
+ alpha = np.array(input)[:, :, 3]
589
+ if not np.all(alpha == 255):
590
+ has_alpha = True
591
+ if has_alpha:
592
+ output = input
593
+ else:
594
+ input = input.convert('RGB')
595
+ max_size = max(input.size)
596
+ scale = min(1, 1024 / max_size)
597
+ if scale < 1:
598
+ input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
599
+ if getattr(self, 'rembg_session', None) is None:
600
+ self.rembg_session = rembg.new_session('u2net')
601
+ output = rembg.remove(input, session=self.rembg_session)
602
+ output_np = np.array(output)
603
+ alpha = output_np[:, :, 3]
604
+ bbox = np.argwhere(alpha > 0.8 * 255)
605
+ bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
606
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
607
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
608
+ size = int(size * 1.2)
609
+ bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
610
+ output = output.crop(bbox) # type: ignore
611
+ output = output.resize((518, 518), Image.Resampling.LANCZOS)
612
+ output = np.array(output).astype(np.float32) / 255
613
+ output = output[:, :, :3] * output[:, :, 3:4]
614
+ output = Image.fromarray((output * 255).astype(np.uint8))
615
+ return output
616
+
617
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
618
+ name = os.path.splitext(self.obj_files[idx])[0]
619
+ cache_path = os.path.join(self.cache_dir, f"{name}_precombined.npz")
620
+
621
+ sha256_hash = self.obj_hashes[idx]
622
+ mesh_render_dir = os.path.join(self.renders_dir, sha256_hash) if sha256_hash else ""
623
+
624
+ image_path = ''
625
+ if mesh_render_dir and os.path.isdir(mesh_render_dir):
626
+ try:
627
+ render_files = [f for f in os.listdir(mesh_render_dir) if f.endswith('.png')]
628
+ if render_files:
629
+ image_path = os.path.join(mesh_render_dir, random.choice(render_files))
630
+ except OSError as e:
631
+ print(f"[WARN] Could not access render directory {mesh_render_dir}: {e}")
632
+
633
+ if self.process_img:
634
+ try:
635
+ if image_path and os.path.exists(image_path):
636
+ image_obj = self.image_transform(self.preprocess_image(Image.open(image_path)).convert('RGB'))
637
+ else:
638
+ image_obj = self.image_transform(Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)).convert('RGB'))
639
+ except Exception as e:
640
+ image_obj = self.image_transform(Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)).convert('RGB'))
641
+ print(f'Error processing image {image_path}: {e}')
642
+
643
+ if os.path.exists(cache_path):
644
+ try:
645
+ loaded = np.load(cache_path, allow_pickle=True)
646
+ data = {
647
+ 'original_faces': torch.from_numpy(loaded['original_faces']),
648
+ 'original_vertices': torch.from_numpy(loaded['original_vertices']),
649
+ }
650
+ for res in self.res_levels:
651
+ # Load standard voxel data
652
+ if f'combined_voxels_{res}' in loaded:
653
+ data[f'combined_voxels_{res}'] = torch.from_numpy(loaded[f'combined_voxels_{res}'])
654
+ data[f'combined_voxel_labels_{res}'] = torch.from_numpy(loaded[f'combined_voxel_labels_{res}'])
655
+ data[f'gt_combined_endpoints_{res}'] = torch.from_numpy(loaded[f'gt_combined_endpoints_{res}'])
656
+
657
+ data[f'gt_vertex_voxels_{res}'] = torch.from_numpy(loaded[f'gt_vertex_voxels_{res}'])
658
+ data[f'gt_edge_voxels_{res}'] = torch.from_numpy(loaded[f'gt_edge_voxels_{res}'])
659
+ data[f'gt_edge_endpoints_{res}'] = torch.from_numpy(loaded[f'gt_edge_endpoints_{res}'])
660
+ data[f'gt_edge_errors_{res}'] = torch.from_numpy(loaded[f'gt_edge_errors_{res}'])
661
+
662
+ # Load Active Voxels and Point Cloud for Local Pooling
663
+ if res == self.active_voxel_res:
664
+ if f'active_voxels_{res}' in loaded:
665
+ data[f'active_voxels_{res}'] = torch.from_numpy(loaded[f'active_voxels_{res}'])
666
+ if f'point_cloud_{res}' in loaded:
667
+ data[f'point_cloud_{res}'] = torch.from_numpy(loaded[f'point_cloud_{res}'])
668
+
669
+ if f'gt_vertex_edge_indices_{res}' in loaded:
670
+ data[f'gt_vertex_edge_indices_{res}'] = torch.from_numpy(loaded[f'gt_vertex_edge_indices_{res}'])
671
+
672
+ if self.process_img:
673
+ data['image'] = image_obj
674
+ data['image_path'] = image_path
675
+ return data
676
+
677
+ except Exception as e:
678
+ print(f"[WARN] Corrupted NPZ cache {cache_path}, regenerating... {e}")
679
+ os.remove(cache_path)
680
+
681
+ try:
682
+ mesh_path = os.path.join(self.root_dir, self.obj_files[idx])
683
+ mesh = normalize_mesh(mesh_path)
684
+ if mesh.is_empty or not hasattr(mesh.vertices, 'shape') or mesh.vertices.shape[0] < 3 or not hasattr(mesh.faces, 'shape') or mesh.faces.shape[0] < 1:
685
+ raise ValueError("Invalid or empty mesh")
686
+ except Exception as e:
687
+ print(f"[ERROR] Failed to load mesh: {self.obj_files[idx]} | {e}")
688
+ return self.__getitem__((idx + 1) % len(self))
689
+
690
+ vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
691
+ faces = torch.tensor(mesh.faces, dtype=torch.long)
692
+
693
+ data = {'original_faces': faces.clone(), 'original_vertices': vertices.clone()}
694
+
695
+ for res in self.res_levels:
696
+ quantized = quantize_vertices(vertices, res)
697
+ tmesh = trimesh.Trimesh(vertices=quantized.numpy(), faces=faces.numpy())
698
+ tmesh.merge_vertices()
699
+
700
+ vertex_voxels_raw = torch.from_numpy(tmesh.vertices.astype(np.int32))
701
+ edges_raw = tmesh.edges_unique
702
+
703
+ edges_indices_raw = torch.from_numpy(tmesh.edges_unique.astype(np.long))
704
+ data[f'gt_vertex_edge_indices_{res}'] = edges_indices_raw
705
+
706
+ vertex_labels_raw = torch.zeros(vertex_voxels_raw.shape[0], dtype=torch.long)
707
+
708
+ all_edge_voxels = []
709
+ edge_endpoints = []
710
+ edge_errors = []
711
+
712
+ for u_idx, v_idx in edges_raw:
713
+ p1_grid, p2_grid = vertex_voxels_raw[u_idx].float(), vertex_voxels_raw[v_idx].float()
714
+ v, ep, err = get_voxel_line(p1_grid, p2_grid, mode='cpu')
715
+ all_edge_voxels.extend(v)
716
+ edge_endpoints.extend(ep)
717
+ edge_errors.extend(err)
718
+
719
+ if all_edge_voxels:
720
+ edge_voxels_np = np.array(all_edge_voxels, dtype=np.int32)
721
+ edge_endpoints_np = np.array([np.stack(pair) for pair in edge_endpoints], dtype=np.float32)
722
+ edge_errors_np = np.array(edge_errors, dtype=np.float32)
723
+
724
+ unique_edge_voxels_np, first_indices = np.unique(edge_voxels_np, axis=0, return_index=True)
725
+ edge_voxels_raw = torch.from_numpy(unique_edge_voxels_np)
726
+ edge_labels_raw = torch.ones(len(edge_voxels_raw), dtype=torch.long)
727
+ edge_endpoints_raw = torch.from_numpy(edge_endpoints_np[first_indices])
728
+ edge_errors_raw = torch.from_numpy(edge_errors_np[first_indices])
729
+ else:
730
+ edge_voxels_raw = torch.empty(0, 3, dtype=torch.int32)
731
+ edge_labels_raw = torch.empty(0, dtype=torch.long)
732
+ edge_endpoints_raw = torch.empty(0, 2, 3, dtype=torch.float32)
733
+ edge_errors_raw = torch.empty(0, 3, dtype=torch.float32)
734
+
735
+
736
+ if res == self.active_voxel_res:
737
+ try:
738
+ if self.sample_type == 'uniform':
739
+ # triposf-style, normilize wrong
740
+ ts_voxels, ts_points = load_quantized_mesh_original(
741
+ mesh_path=os.path.join(self.root_dir, self.obj_files[idx]),
742
+ mesh_load=mesh,
743
+ volume_resolution=res,
744
+ use_normals=True,
745
+ pc_sample_number=self.pc_sample_number,
746
+ )
747
+ else:
748
+ ts_voxels, ts_points = load_quantized_mesh_dora(
749
+ mesh_path=os.path.join(self.root_dir, self.obj_files[idx]),
750
+ mesh_load=mesh,
751
+ volume_resolution=res,
752
+ use_normals=True,
753
+ pc_sample_number=self.pc_sample_number,
754
+ edge_sample_ratio=0.5,
755
+ )
756
+
757
+ # Convert types
758
+ # Voxels from TripoSF are float Tensor (N, 3), convert to int32
759
+ data[f'active_voxels_{res}'] = ts_voxels.int()
760
+ data[f'point_cloud_{res}'] = ts_points
761
+
762
+
763
+ except Exception as e:
764
+ print(f"[ERROR] Failed to compute active voxels/points for {name} at res {res}: {e}")
765
+ data[f'active_voxels_{res}'] = torch.empty(0, 3, dtype=torch.int32)
766
+ data[f'point_cloud_{res}'] = torch.empty(0, 6, dtype=torch.float32)
767
+
768
+ combined_voxels = torch.cat([vertex_voxels_raw, edge_voxels_raw], dim=0)
769
+ combined_labels = torch.cat([vertex_labels_raw, edge_labels_raw], dim=0)
770
+
771
+ if combined_voxels.numel() > 0:
772
+ unique_voxels, inverse_indices = torch.unique(combined_voxels, dim=0, return_inverse=True)
773
+
774
+ zero_mask = (combined_labels == 0)
775
+ if zero_mask.any():
776
+ zero_per_unique = torch.zeros(len(unique_voxels), dtype=torch.bool)
777
+ zero_per_unique.scatter_(0, inverse_indices[zero_mask], True)
778
+ final_combined_labels = torch.where(zero_per_unique, 0, 1).long()
779
+ else:
780
+ final_combined_labels = torch.ones(len(unique_voxels), dtype=torch.long)
781
+
782
+ if edge_voxels_raw.numel() > 0:
783
+ edge_endpoint_map = {tuple(coord): ep for coord, ep in zip(edge_voxels_raw.numpy(), edge_endpoints_raw.numpy())}
784
+
785
+ endpoints_arr = np.empty((len(unique_voxels), 2, 3), dtype=np.float32)
786
+ unique_voxels_np = unique_voxels.numpy()
787
+
788
+ for j, coord in enumerate(unique_voxels_np):
789
+ coord_tuple = tuple(coord)
790
+ if coord_tuple in edge_endpoint_map:
791
+ endpoints_arr[j] = edge_endpoint_map[coord_tuple]
792
+ else:
793
+ endpoints_arr[j, 0, :] = coord
794
+ endpoints_arr[j, 1, :] = coord
795
+ final_combined_endpoints = torch.from_numpy(endpoints_arr)
796
+ else:
797
+ final_combined_endpoints = unique_voxels.float().unsqueeze(1).repeat(1, 2, 1)
798
+ else:
799
+ unique_voxels = torch.empty(0, 3, dtype=torch.int32)
800
+ final_combined_labels = torch.empty(0, dtype=torch.long)
801
+ final_combined_endpoints = torch.empty(0, 2, 3, dtype=torch.float32)
802
+
803
+ data[f'combined_voxels_{res}'] = unique_voxels
804
+ data[f'combined_voxel_labels_{res}'] = final_combined_labels
805
+ data[f'gt_combined_endpoints_{res}'] = final_combined_endpoints.reshape(-1, 6)
806
+
807
+ data[f'gt_vertex_voxels_{res}'] = vertex_voxels_raw
808
+ data[f'gt_edge_voxels_{res}'] = edge_voxels_raw
809
+ data[f'gt_edge_endpoints_{res}'] = edge_endpoints_raw.reshape(-1, 6)
810
+ data[f'gt_edge_errors_{res}'] = edge_errors_raw
811
+
812
+
813
+ save_dict = {
814
+ 'original_faces': data['original_faces'].numpy(),
815
+ 'original_vertices': data['original_vertices'].numpy(),
816
+ }
817
+ for res in self.res_levels:
818
+ for key_suffix in [
819
+ 'combined_voxels', 'combined_voxel_labels', 'gt_combined_endpoints',
820
+ 'gt_vertex_voxels', 'gt_edge_voxels', 'gt_edge_endpoints', 'gt_edge_errors',
821
+ 'gt_vertex_edge_indices',
822
+ ]:
823
+ full_key = f'{key_suffix}_{res}'
824
+ if full_key in data:
825
+ save_dict[full_key] = data[full_key].numpy()
826
+
827
+ if f'active_voxels_{res}' in data:
828
+ save_dict[f'active_voxels_{res}'] = data[f'active_voxels_{res}'].numpy()
829
+
830
+ if f'point_cloud_{res}' in data:
831
+ save_dict[f'point_cloud_{res}'] = data[f'point_cloud_{res}'].numpy()
832
+
833
+ # try:
834
+ # np.savez_compressed(cache_path, **save_dict)
835
+ # except Exception as e:
836
+ # print(f"[ERROR] Failed to save cache {cache_path}: {e}")
837
+ # if os.path.exists(cache_path): os.remove(cache_path)
838
+
839
+ if self.process_img:
840
+ data['image'] = image_obj
841
+ data['image_path'] = image_path
842
+
843
+ return data
844
+
845
+ def __len__(self) -> int:
846
+ return len(self.obj_files)
847
+
848
+ def collate_fn_pointnet(
849
+ batch: List[Dict[str, torch.Tensor]],
850
+ ) -> Dict[str, torch.Tensor]:
851
+
852
+ if not batch:
853
+ return {}
854
+
855
+ batch = [b for b in batch if b is not None]
856
+ if not batch:
857
+ return {}
858
+
859
+ collated = {
860
+ 'original_faces': [b['original_faces'] for b in batch],
861
+ 'original_vertices': [b['original_vertices'] for b in batch],
862
+ 'image_path': [b['image_path'] for b in batch],
863
+ }
864
+
865
+ if 'image' in batch[0] and batch[0]['image'] is not None:
866
+ collated['image'] = torch.stack([b['image'] for b in batch])
867
+
868
+ res_levels = []
869
+ for k in batch[0].keys():
870
+ if k.startswith('gt_vertex_voxels_'):
871
+ try:
872
+ res_levels.append(int(k.split('_')[-1]))
873
+ except ValueError:
874
+ pass
875
+ res_levels.sort()
876
+
877
+ for res in res_levels:
878
+ all_active_voxels_list = []
879
+ all_point_clouds_list = []
880
+
881
+ all_combined_voxels_list = []
882
+ all_combined_labels_list = []
883
+ all_vertex_voxels_only = []
884
+ all_edge_voxels_only = []
885
+ all_edge_endpoints_only = []
886
+ all_combined_endpoints = []
887
+ all_combined_errors_list = []
888
+ layout = []
889
+
890
+ vtx_offset = 0
891
+ adj_flat_offset = 0
892
+ start_idx = 0
893
+
894
+ # Attempt to find device from first tensor
895
+ device = torch.device('cpu')
896
+ for v in batch[0].values():
897
+ if isinstance(v, torch.Tensor):
898
+ device = v.device
899
+ break
900
+
901
+ all_edge_indices_list = []
902
+ vertex_count_offset = 0
903
+
904
+ for i, sample in enumerate(batch):
905
+ vertex_voxels = sample.get(f'gt_vertex_voxels_{res}', torch.empty(0,3,dtype=torch.int32)).to(device)
906
+
907
+ num_vertices = vertex_voxels.shape[0]
908
+
909
+ vertex_labels = torch.zeros(vertex_voxels.shape[0], dtype=torch.long, device=device)
910
+ edge_voxels = sample.get(f'gt_edge_voxels_{res}', torch.empty(0,3,dtype=torch.int32)).to(device)
911
+ edge_labels = torch.ones(edge_voxels.shape[0], dtype=torch.long, device=device)
912
+ edge_endpoints= sample.get(f'gt_edge_endpoints_{res}', torch.empty(0,6,dtype=torch.float32)).to(device)
913
+ edge_errors = sample.get(f'gt_edge_errors_{res}', torch.empty(0,3,dtype=torch.float32)).to(device)
914
+
915
+ vertex_errors = sample.get(f'gt_vertex_errors_{res}', torch.zeros_like(vertex_voxels, dtype=torch.float32)).to(device)
916
+
917
+ if vertex_voxels.numel() > 0:
918
+ idx = torch.full((vertex_voxels.shape[0],1), i, dtype=torch.int32, device=device)
919
+ all_vertex_voxels_only.append(torch.cat([idx, vertex_voxels], dim=1))
920
+
921
+ edge_indices = sample.get(f'gt_vertex_edge_indices_{res}', torch.empty(0, 2, dtype=torch.long)).to(device)
922
+ if edge_indices.numel() > 0:
923
+ shifted_indices = edge_indices + vertex_count_offset
924
+ all_edge_indices_list.append(shifted_indices)
925
+ vertex_count_offset += num_vertices
926
+
927
+ if edge_voxels.numel() > 0:
928
+ idx = torch.full((edge_voxels.shape[0],1), i, dtype=torch.int32, device=device)
929
+ all_edge_voxels_only.append(torch.cat([idx, edge_voxels], dim=1))
930
+ all_edge_endpoints_only.append(
931
+ torch.cat([idx.to(torch.float32), edge_endpoints], dim=1))
932
+
933
+
934
+ if vertex_voxels.numel() + edge_voxels.numel() > 0:
935
+ combined_voxels = torch.cat([vertex_voxels, edge_voxels], dim=0)
936
+ combined_labels = torch.cat([vertex_labels, edge_labels], dim=0)
937
+
938
+ endpoints = torch.zeros(combined_voxels.size(0), 6, dtype=torch.float32, device=device)
939
+ if edge_voxels.numel() > 0:
940
+ endpoints[-edge_voxels.size(0):] = edge_endpoints
941
+ if vertex_voxels.numel() > 0:
942
+ endpoints[:vertex_voxels.size(0)] = vertex_voxels.repeat(1,2).float()
943
+
944
+ combined_errors = torch.cat([vertex_errors, edge_errors], dim=0)
945
+
946
+ batch_idx_int = torch.full((combined_voxels.shape[0],1), i, dtype=torch.int32, device=device)
947
+ all_combined_voxels_list.append(torch.cat([batch_idx_int, combined_voxels], dim=1))
948
+ all_combined_labels_list.append(combined_labels)
949
+
950
+ batch_idx_float = batch_idx_int.to(torch.float32)
951
+ all_combined_endpoints.append(torch.cat([batch_idx_float, endpoints], dim=1))
952
+ all_combined_errors_list.append(torch.cat([batch_idx_float, combined_errors], dim=1))
953
+
954
+ layout.append(slice(start_idx, start_idx + combined_voxels.shape[0]))
955
+ start_idx += combined_voxels.shape[0]
956
+ else:
957
+ layout.append(slice(start_idx, start_idx))
958
+
959
+ # Active Voxels (Sparse Coords)
960
+ active_voxels = sample.get(f'active_voxels_{res}', torch.empty(0, 3, dtype=torch.int32)).to(device)
961
+ if active_voxels.numel() > 0:
962
+ idx = torch.full((active_voxels.shape[0], 1), i, dtype=torch.int32, device=device)
963
+ all_active_voxels_list.append(torch.cat([idx, active_voxels], dim=1))
964
+
965
+ # ==========================================
966
+ # Modified Section: Collect Point Clouds
967
+ # ==========================================
968
+ # pc = sample.get(f'point_cloud_{res}', torch.empty(0, 6, dtype=torch.float32)).to(device)
969
+ pc = sample.get(f'point_cloud_{res}', torch.empty(0, 15, dtype=torch.float32)).to(device)
970
+ # We expect all samples to have point clouds if res == active_voxel_res
971
+ if pc.numel() > 0:
972
+ all_point_clouds_list.append(pc)
973
+
974
+ collated[f'layout_{res}'] = layout
975
+
976
+ def cat_or_empty(lst, shape, dtype):
977
+ return torch.cat(lst, dim=0) if lst else torch.empty(shape, dtype=dtype, device=device)
978
+
979
+ collated[f'combined_voxels_{res}'] = cat_or_empty(all_combined_voxels_list,(0,4),torch.int32)
980
+ collated[f'combined_voxel_labels_{res}'] = cat_or_empty(all_combined_labels_list,(0,),torch.long)
981
+ collated[f'gt_vertex_voxels_{res}'] = cat_or_empty(all_vertex_voxels_only,(0,4),torch.int32)
982
+ collated[f'gt_edge_voxels_{res}'] = cat_or_empty(all_edge_voxels_only,(0,4),torch.int32)
983
+ collated[f'gt_edge_endpoints_{res}'] = cat_or_empty(all_edge_endpoints_only,(0,7),torch.float32)
984
+ collated[f'gt_combined_endpoints_{res}'] = cat_or_empty(all_combined_endpoints,(0,7),torch.float32)
985
+ collated[f'gt_combined_errors_{res}'] = cat_or_empty(all_combined_errors_list,(0,4),torch.float32)
986
+
987
+ collated[f'active_voxels_{res}'] = cat_or_empty(all_active_voxels_list, (0, 4), torch.int32)
988
+
989
+ if all_edge_indices_list:
990
+ collated[f'gt_vertex_edge_indices_{res}'] = torch.cat(all_edge_indices_list, dim=0)
991
+ else:
992
+ collated[f'gt_vertex_edge_indices_{res}'] = torch.empty((0, 2), dtype=torch.long, device=device)
993
+
994
+ if all_point_clouds_list:
995
+ collated[f'point_cloud_{res}'] = torch.stack(all_point_clouds_list, dim=0)
996
+ else:
997
+ # collated[f'point_cloud_{res}'] = torch.empty((0, 6), dtype=torch.float32, device=device)
998
+ collated[f'point_cloud_{res}'] = torch.empty((0, 15), dtype=torch.float32, device=device)
999
+
1000
+ return collated
debug_viz/step_0_batch_0.ply ADDED
The diff for this file is too large to render. See raw diff
 
debug_viz/step_0_batch_1.ply ADDED
The diff for this file is too large to render. See raw diff
 
filter_active_voxels.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from concurrent.futures import ProcessPoolExecutor, as_completed
5
+
6
+ # === 配置 ===
7
+ cache_dir = "/gemini/user/private/zhaotianhao/dataset_cache/MERGED_DATASET_count_200_2000_100000_128to1024_819200"
8
+
9
+ # 1. Edge Voxels (512分辨率) 的筛选阈值
10
+ target_res_edge = 512
11
+ min_edge_voxels = 2000
12
+ max_edge_voxels = 75000
13
+
14
+ # 2. Active Voxels (64分辨率) 的筛选阈值
15
+ # 请根据你的需求调整这两个数值
16
+ target_res_active = 128
17
+ min_active_voxels = 2000 # 举例:最少要有100个粗糙体素
18
+ max_active_voxels = 326780 # 举例:最多8000个粗糙体素
19
+
20
+ save_txt_path = f"/gemini/user/private/zhaotianhao/Triposf/MERGED_DATASET_filtered_{min_edge_voxels}-{max_edge_voxels}edge_{min_active_voxels}-{max_active_voxels}active.txt"
21
+
22
+ # === 单文件统计函数 ===
23
+ def check_voxel_counts(npz_path):
24
+ try:
25
+ # 打开 npz 文件
26
+ with np.load(npz_path) as data:
27
+ # 键名定义
28
+ key_edge = f"combined_voxels_{target_res_edge}"
29
+ key_active = f"active_voxels_{target_res_active}"
30
+
31
+ # 检查键是否存在
32
+ if key_edge not in data or key_active not in data:
33
+ return None
34
+
35
+ # 获取数量
36
+ count_edge = len(data[key_edge])
37
+ count_active = len(data[key_active])
38
+
39
+ # === 核心筛选逻辑 (同时满足两个条件) ===
40
+ is_edge_valid = min_edge_voxels <= count_edge <= max_edge_voxels
41
+ is_active_valid = min_active_voxels <= count_active <= max_active_voxels
42
+
43
+ if is_edge_valid and is_active_valid:
44
+ base_name = os.path.basename(npz_path)
45
+ # 处理文件名
46
+ if base_name.endswith("_precombined.npz"):
47
+ original_name = base_name.replace("_precombined.npz", "")
48
+ else:
49
+ original_name = os.path.splitext(base_name)[0]
50
+
51
+ return (original_name, count_edge, count_active)
52
+
53
+ except Exception:
54
+ return None
55
+ return None
56
+
57
+ # === 获取所有 NPZ 文件 ===
58
+ if not os.path.exists(cache_dir):
59
+ print(f"错误: 缓存目录不存在 {cache_dir}")
60
+ exit()
61
+
62
+ npz_files = [os.path.join(cache_dir, f) for f in os.listdir(cache_dir) if f.endswith(".npz")]
63
+ print(f"共发现 {len(npz_files)} 个缓存文件。开始并行过滤...")
64
+ print(f"筛选条件:")
65
+ print(f" - Edge (512): {min_edge_voxels} ~ {max_edge_voxels}")
66
+ print(f" - Active (64): {min_active_voxels} ~ {max_active_voxels}")
67
+
68
+ # === 并行过滤 ===
69
+ filtered_files = []
70
+ counts_edge = []
71
+ counts_active = []
72
+
73
+ with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
74
+ futures = {executor.submit(check_voxel_counts, path): path for path in npz_files}
75
+
76
+ for future in tqdm(as_completed(futures), total=len(futures), desc="Filtering"):
77
+ result = future.result()
78
+ if result is not None:
79
+ fname, c_edge, c_active = result
80
+ filtered_files.append(fname)
81
+ counts_edge.append(c_edge)
82
+ counts_active.append(c_active)
83
+
84
+ # === 保存结果 ===
85
+ os.makedirs(os.path.dirname(save_txt_path), exist_ok=True)
86
+ with open(save_txt_path, "w") as f:
87
+ for fname in filtered_files:
88
+ f.write(f"{fname}\n")
89
+
90
+ # === 打印统计信息 ===
91
+ print(f"\n✅ 筛选完成:")
92
+ print(f" 符合条件的文件数: {len(filtered_files)} / {len(npz_files)} (保留率: {len(filtered_files)/len(npz_files)*100:.2f}%)")
93
+
94
+ if counts_edge:
95
+ print(f"\n[统计 - Edge Voxels (512)]")
96
+ print(f" 最小值: {min(counts_edge)}")
97
+ print(f" 最大值: {max(counts_edge)}")
98
+ print(f" 平均值: {np.mean(counts_edge):.2f}")
99
+
100
+ if counts_active:
101
+ print(f"\n[统计 - Active Voxels (64)]")
102
+ print(f" 最小值: {min(counts_active)}")
103
+ print(f" 最大值: {max(counts_active)}")
104
+ print(f" 平均值: {np.mean(counts_active):.2f}")
105
+
106
+ print(f"\n 结果已保存到: {save_txt_path}")
generate_npz.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import yaml
4
+ import torch
5
+ import os
6
+ from torch.utils.data import DataLoader
7
+ from functools import partial
8
+
9
+ # Assuming your custom modules are in the same directory or in the Python path
10
+ # from dataset import VoxelVertexDataset_edge, collate_fn_edge
11
+ from dataset_triposf import VoxelVertexDataset_edge, collate_fn_pointnet
12
+
13
+ def inspect_batch(batch, batch_idx, device):
14
+ """
15
+ A detailed function to inspect and print information about a single batch.
16
+ """
17
+ print(f"\n{'='*20} Inspecting Batch {batch_idx} {'='*20}")
18
+
19
+ # if batch is None:
20
+ # print("Batch is None. Skipping.")
21
+ # return
22
+
23
+ # print("Batch contains the following keys:")
24
+ # for key in batch.keys():
25
+ # print(f" - {key}")
26
+
27
+ # print(f"{'='*58}")
28
+
29
+
30
+ def main():
31
+ """
32
+ Main function to load configuration, set up the dataset,
33
+ and process a few batches for inspection.
34
+ """
35
+ import argparse
36
+ parser = argparse.ArgumentParser(description="Process and inspect data from the VoxelVertexDataset.")
37
+ # parser.add_argument('config_path', type=str, help='Path to the configuration YAML file.')
38
+ parser.add_argument('--num_batches', type=int, default=3, help='Number of batches to inspect.')
39
+ args = parser.parse_args()
40
+
41
+ # 1. Load Configuration
42
+ # print(f"Loading configuration from: {args.config_path}")
43
+ # with open(args.config_path) as f:
44
+ # cfg = yaml.safe_load(f)
45
+
46
+ # 2. Initialize Device
47
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+ print(f"Using device: {device}")
49
+
50
+ # 3. Initialize Dataset
51
+ print("Initializing dataset...")
52
+ # dataset = VoxelVertexDataset_edge(
53
+ # root_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/final_data_decimate_2',
54
+ # base_resolution=512,
55
+ # min_resolution=64,
56
+ # cache_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/dataset_cache/final_data_decimate_60w_2',
57
+ # renders_dir=None,
58
+ # )
59
+ dataset = VoxelVertexDataset_edge(
60
+ root_dir='/root/mesh_split_200complex/mesh_split_200complex_train',
61
+ base_resolution=512,
62
+ min_resolution=64,
63
+ cache_dir='/root/Trisf/dataset_cache/objaverse_200_2000_filtered_final_8354files_512to512',
64
+ renders_dir=None,
65
+
66
+ filter_active_voxels=False,
67
+ cache_filter_path='',
68
+
69
+ active_voxel_res=512,
70
+
71
+ sample_type='dora',
72
+ )
73
+
74
+ # dataset = VoxelVertexDataset_edge(
75
+ # root_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/meshgpt_data/train/03001627',
76
+ # base_resolution=512,
77
+ # min_resolution=64,
78
+ # cache_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/dataset_cache/03001627',
79
+ # renders_dir=None,
80
+ # )
81
+ # dataset = VoxelVertexDataset_edge(
82
+ # root_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/meshgpt_data/train/03636649',
83
+ # base_resolution=512,
84
+ # min_resolution=64,
85
+ # cache_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/dataset_cache/03636649',
86
+ # renders_dir=None,
87
+ # )
88
+ # dataset = VoxelVertexDataset_edge(
89
+ # root_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/meshgpt_data/train/04379243',
90
+ # base_resolution=512,
91
+ # min_resolution=64,
92
+ # cache_dir='/HOME/paratera_xy/pxy1054/HDD_POOL/Triposf/dataset_cache/04379243',
93
+ # renders_dir=None,
94
+ # )
95
+ print(f"Dataset initialized with {len(dataset)} samples.")
96
+
97
+ # 4. Initialize DataLoader
98
+ # We don't need a DistributedSampler here, just a regular DataLoader.
99
+ print("Initializing DataLoader...")
100
+ dataloader = DataLoader(
101
+ dataset,
102
+ batch_size=1,
103
+ shuffle=False, # Shuffle for a random sample of batches
104
+ collate_fn=partial(collate_fn_pointnet,),
105
+ num_workers=24,
106
+ pin_memory=True,
107
+ )
108
+
109
+ # 5. Data Processing Loop
110
+ print(f"\nStarting data inspection loop for {args.num_batches} batches...")
111
+ for i, batch in enumerate(dataloader):
112
+ inspect_batch(batch, i, device)
113
+
114
+ print("\nData inspection complete.")
115
+
116
+
117
+ if __name__ == '__main__':
118
+ main()
mesh_augment.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import open3d as o3d
4
+ import trimesh
5
+ import random
6
+ from tqdm import tqdm
7
+ from pathlib import Path
8
+
9
+ def augment_obj_file(input_path, output_path, n_augmentations=5):
10
+ """
11
+ Augment an OBJ file with random transformations
12
+ Args:
13
+ input_path: Path to input OBJ file
14
+ output_path: Directory to save augmented files
15
+ n_augmentations: Number of augmented copies to create
16
+ """
17
+ # Create output directory if it doesn't exist
18
+ os.makedirs(output_path, exist_ok=True)
19
+
20
+ # Load the original mesh
21
+ mesh = trimesh.load(input_path)
22
+ original_name = Path(input_path).stem
23
+
24
+ for i in range(n_augmentations):
25
+ # Create a copy of the original mesh
26
+ augmented_mesh = mesh.copy()
27
+
28
+ # Random rotation (0-360 degrees around each axis)
29
+ angle_x = np.random.uniform(0, 2*np.pi)
30
+ angle_y = np.random.uniform(0, 2*np.pi)
31
+ angle_z = np.random.uniform(0, 2*np.pi)
32
+ rotation_matrix = trimesh.transformations.euler_matrix(angle_x, angle_y, angle_z)
33
+ augmented_mesh.apply_transform(rotation_matrix)
34
+
35
+ # Random scaling (0.8-1.2 range)
36
+ scale_factor = np.random.uniform(0.8, 1.2, size=3)
37
+ scale_matrix = np.eye(4)
38
+ scale_matrix[:3, :3] *= scale_factor
39
+ augmented_mesh.apply_transform(scale_matrix)
40
+
41
+ # Random translation (-0.1 to 0.1 range in each dimension)
42
+ translation = np.random.uniform(-0.1, 0.1, size=3)
43
+ translation_matrix = np.eye(4)
44
+ translation_matrix[:3, 3] = translation
45
+ augmented_mesh.apply_transform(translation_matrix)
46
+
47
+ # Save the augmented mesh
48
+ output_file = os.path.join(output_path, f"{original_name}_aug_{i}.obj")
49
+ augmented_mesh.export(output_file)
50
+
51
+ def augment_all_objs(source_dir, target_dir, n_augmentations=5):
52
+ """
53
+ Augment all OBJ files in a directory
54
+ Args:
55
+ source_dir: Directory containing original OBJ files
56
+ target_dir: Directory to save augmented files
57
+ n_augmentations: Number of augmented copies per file
58
+ """
59
+ # Get all OBJ files in source directory
60
+ obj_files = [f for f in os.listdir(source_dir) if f.endswith('.obj')]
61
+
62
+ print(f"Found {len(obj_files)} OBJ files to augment")
63
+ print(f"Will create {n_augmentations} augmented versions per file")
64
+
65
+ # Process each file
66
+ for obj_file in tqdm(obj_files, desc="Augmenting OBJ files"):
67
+ input_path = os.path.join(source_dir, obj_file)
68
+ augment_obj_file(input_path, target_dir, n_augmentations)
69
+
70
+ print(f"Finished! Augmented files saved to: {target_dir}")
71
+
72
+ if __name__ == "__main__":
73
+ # Configuration
74
+ N_AUGMENTATIONS = 10 # Number of augmented copies per file
75
+ SOURCE_DIR = "/root/shapenet_data/train_mesh_data_under_25kb_1000/train" # Directory with original OBJ files
76
+ TARGET_DIR = f"/root/shapenet_data/train_mesh_data_under_25kb_1000/train_{N_AUGMENTATIONS}augment" # Where to save augmented files
77
+
78
+ # Run augmentation
79
+ augment_all_objs(SOURCE_DIR, TARGET_DIR, N_AUGMENTATIONS)
metric.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import torch
4
+ import numpy as np
5
+ import warnings
6
+ import trimesh
7
+ from scipy.stats import entropy
8
+ from sklearn.neighbors import NearestNeighbors
9
+ from numpy.linalg import norm
10
+ from tqdm.auto import tqdm
11
+
12
+ # ==============================================================================
13
+ # 用户配置 (User Configuration)
14
+ # ==============================================================================
15
+ # --- 路径设置 ---
16
+ # !!重要提示!!
17
+ # 当您计算真实指标时, 请确保这两个路径指向不同的文件夹
18
+ # 这里为了方便测试, 设置为相同路径。当两个路径相同时:
19
+ # MMD->0, COV->1.0, 1-NNA->0.5, JSD->0
20
+ # 而 CD 和 HD 会是一个很小的值, 代表同一mesh两次不同采样的差异。
21
+
22
+ # GENERATED_MESH_DIR = "/root/mesh_split_200complex/mesh_split_200complex_test" # 存放生成的 .obj 文件的文件夹路径
23
+ # GT_MESH_DIR = "/root/mesh_split_200complex/mesh_split_200complex_test" # 存放真实的 .obj 文件的文件夹路径
24
+
25
+ GENERATED_MESH_DIR = "/root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs" # 存放生成的 .obj 文件的文件夹路径
26
+ GT_MESH_DIR = "/root/Trisf/abalation_post_processing/gt_mesh" # 存放真实的 .obj 文件的文件夹路径
27
+
28
+ # --- 采样和计算参数 ---
29
+ NUM_POINTS_PER_MESH = 2048 # 从每个mesh表面采样的点数
30
+ BATCH_SIZE = 32 # 计算指标时使用的批次大小,根据显存调整
31
+ JSD_RESOLUTION = 28 # JSD计算中体素网格的分辨率
32
+
33
+ # ==============================================================================
34
+ # 核心功能函数: Mesh处理
35
+ # ==============================================================================
36
+
37
+ def process_meshes_in_folder(folder_path, num_points):
38
+ """
39
+ 加载文件夹中所有的 .obj 文件, 将它们采样成点云, 并进行归一化。
40
+ """
41
+ # 按文件名排序以确保一一对应
42
+ mesh_files = sorted(glob.glob(os.path.join(folder_path, '*.obj')))
43
+ if not mesh_files:
44
+ raise FileNotFoundError(f"在文件夹 '{folder_path}' 中没有找到任何 .obj 文件。")
45
+
46
+ all_point_clouds = []
47
+ print(f"正在从 '{folder_path}' 处理 {len(mesh_files)} 个mesh...")
48
+
49
+ for mesh_path in tqdm(mesh_files, desc=f'处理 {os.path.basename(folder_path)}'):
50
+ try:
51
+ mesh = trimesh.load(mesh_path, process=False)
52
+
53
+ # 归一化: 移动到原点并缩放到单位球体内
54
+ center = mesh.bounds.mean(axis=0)
55
+ mesh.apply_translation(-center)
56
+ max_dist = np.max(np.linalg.norm(mesh.vertices, axis=1))
57
+ if max_dist > 0:
58
+ mesh.apply_scale(1.0 / max_dist)
59
+
60
+ points, _ = trimesh.sample.sample_surface(mesh, num_points)
61
+
62
+ if points.shape[0] != num_points:
63
+ # print(f"警告: {mesh_path} 采样点数 {points.shape[0]} != {num_points}, 进行重采样。")
64
+ indices = np.random.choice(points.shape[0], num_points, replace=True)
65
+ points = points[indices]
66
+
67
+ all_point_clouds.append(points)
68
+
69
+ except Exception as e:
70
+ print(f"错误:加载或处理文件 {mesh_path} 失败: {e}")
71
+
72
+ return np.array(all_point_clouds)
73
+
74
+ # ==============================================================================
75
+ # 评估指标代码 (来自 PointFlow 及新增)
76
+ # ==============================================================================
77
+
78
+ _EMD_NOT_IMPL_WARNED = False
79
+ def emd_approx(sample, ref):
80
+ global _EMD_NOT_IMPL_WARNED
81
+ emd = torch.zeros([sample.size(0)]).to(sample)
82
+ if not _EMD_NOT_IMPL_WARNED:
83
+ _EMD_NOT_IMPL_WARNED = True
84
+ print('\n\n[WARNING] EMD is not implemented. Setting to zero.')
85
+ return emd
86
+
87
+ def distChamfer(a, b):
88
+ x, y = a, b
89
+ bs, num_points, points_dim = x.size()
90
+ xx = torch.bmm(x, x.transpose(2, 1))
91
+ yy = torch.bmm(y, y.transpose(2, 1))
92
+ zz = torch.bmm(x, y.transpose(2, 1))
93
+ diag_ind = torch.arange(0, num_points, device=a.device).long()
94
+ rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
95
+ ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
96
+ P = (rx.transpose(2, 1) + ry - 2 * zz)
97
+ # P is batch_size x n_points x n_points matrix of squared distances
98
+ return P.min(1)[0], P.min(2)[0]
99
+
100
+ def compute_cd_hd(sample_pcs, ref_pcs, batch_size):
101
+ """
102
+ 计算平均成对的Chamfer Distance (CD) 和 Hausdorff Distance (HD)。
103
+ """
104
+ print("\n--- 开始计算 平均Chamfer和Hausdorff距离 ---")
105
+ N_sample = sample_pcs.shape[0]
106
+ N_ref = ref_pcs.shape[0]
107
+
108
+ assert N_sample == N_ref, f"用于成对度量计算的集合大小必须相等, 但得到 {N_sample} 和 {N_ref}"
109
+
110
+ cd_all = []
111
+ hd_all = []
112
+
113
+ iterator = range(0, N_sample, batch_size)
114
+ for b_start in tqdm(iterator, desc='计算 CD/HD'):
115
+ b_end = min(N_sample, b_start + batch_size)
116
+ sample_batch = sample_pcs[b_start:b_end]
117
+ ref_batch = ref_pcs[b_start:b_end]
118
+
119
+ # distChamfer返回的是平方距离
120
+ dist1_sq, dist2_sq = distChamfer(sample_batch, ref_batch)
121
+
122
+ # 计算 Chamfer Distance
123
+ cd_batch = dist1_sq.mean(dim=1) + dist2_sq.mean(dim=1)
124
+ cd_all.append(cd_batch)
125
+
126
+ # 计算 Hausdorff Distance
127
+ # HD = max(max(min_dist_1), max(min_dist_2))
128
+ # 我们需要对平方距离开方来得到真实距离
129
+ hd_batch = torch.max(dist1_sq.max(dim=1)[0], dist2_sq.max(dim=1)[0]).sqrt()
130
+ hd_all.append(hd_batch)
131
+
132
+ cd_all = torch.cat(cd_all)
133
+ hd_all = torch.cat(hd_all)
134
+
135
+ results = {
136
+ 'Chamfer-L2': cd_all.mean(),
137
+ 'Hausdorff': hd_all.mean(),
138
+ }
139
+ return results
140
+
141
+ def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, verbose=True):
142
+ N_sample = sample_pcs.shape[0]
143
+ N_ref = ref_pcs.shape[0]
144
+ all_cd = []
145
+ iterator = range(N_sample)
146
+ if verbose:
147
+ iterator = tqdm(iterator, desc='计算点云间距离')
148
+ for i in iterator:
149
+ sample_batch = sample_pcs[i]
150
+ cd_lst = []
151
+ sub_iterator = range(0, N_ref, batch_size)
152
+ for b_start in sub_iterator:
153
+ b_end = min(N_ref, b_start + batch_size)
154
+ ref_batch = ref_pcs[b_start:b_end]
155
+ batch_size_ref = ref_batch.size(0)
156
+ sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1).contiguous()
157
+ dl, dr = distChamfer(sample_batch_exp, ref_batch)
158
+ cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
159
+ cd_lst = torch.cat(cd_lst, dim=1)
160
+ all_cd.append(cd_lst)
161
+ all_cd = torch.cat(all_cd, dim=0)
162
+ # EMD is not implemented, so we return a dummy tensor for it
163
+ all_emd = torch.zeros_like(all_cd)
164
+ return all_cd, all_emd
165
+
166
+ def knn(Mxx, Mxy, Myy, k, sqrt=False):
167
+ n0, n1 = Mxx.size(0), Myy.size(0)
168
+ device = Mxx.device
169
+
170
+ ones_tensor = torch.ones(n0, device=device)
171
+ zeros_tensor = torch.zeros(n1, device=device)
172
+ label = torch.cat((ones_tensor, zeros_tensor))
173
+
174
+ M = torch.cat([torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.t(), Myy), 1)], 0)
175
+ if sqrt: M = M.abs().sqrt()
176
+
177
+ diag_inf = torch.diag(torch.full((n0 + n1,), float('inf'), device=device))
178
+ val, idx = (M + diag_inf).topk(k, 0, False)
179
+
180
+ count = torch.zeros(n0 + n1, device=device)
181
+ for i in range(k):
182
+ count.add_(label.index_select(0, idx[i]))
183
+
184
+ threshold = torch.full((n0 + n1,), float(k) / 2, device=device)
185
+ pred = (count >= threshold).float()
186
+
187
+ return {'acc': (label == pred).float().mean()}
188
+
189
+ def lgan_mmd_cov(all_dist):
190
+ N_sample, N_ref = all_dist.shape
191
+ min_val, min_idx = all_dist.min(dim=1) # For each sample, find closest ref
192
+ mmd_smp = min_val.mean() # MMD-smp
193
+
194
+ min_val_ref, _ = all_dist.min(dim=0) # For each ref, find closest sample
195
+ mmd = min_val_ref.mean() # MMD-ref
196
+
197
+ cov = min_idx.unique().numel() / float(N_ref)
198
+ cov = torch.tensor(cov, device=all_dist.device)
199
+
200
+ return {'lgan_mmd': mmd, 'lgan_cov': cov}
201
+
202
+ def compute_mmd_cov_1nna(sample_pcs, ref_pcs, batch_size):
203
+ results = {}
204
+ print("\n--- 开始计算 MMD-CD, COV-CD, 1-NNA-CD ---")
205
+
206
+ M_rs_cd, _ = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size) # ref vs sample
207
+
208
+ res_cd = lgan_mmd_cov(M_rs_cd.t()) # Transpose to get sample vs ref
209
+ results.update({f"{k}-CD": v for k, v in res_cd.items()})
210
+
211
+ M_rr_cd, _ = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
212
+ M_ss_cd, _ = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
213
+
214
+ one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1)
215
+ results.update({"1-NNA-CD": one_nn_cd_res['acc']})
216
+
217
+ return results
218
+
219
+ def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
220
+ grid = np.linspace(-0.5, 0.5, resolution)
221
+ x, y, z = np.meshgrid(grid, grid, grid, indexing='ij')
222
+ grid = np.stack([x, y, z], axis=-1).reshape(-1, 3)
223
+ if clip_sphere:
224
+ grid = grid[norm(grid, axis=1) <= 0.5]
225
+ return grid
226
+
227
+ def entropy_of_occupancy_grid(pclouds, grid_resolution):
228
+ grid_coords = unit_cube_grid_point_cloud(grid_resolution, True)
229
+ grid_counters = np.zeros(len(grid_coords))
230
+ nn = NearestNeighbors(n_neighbors=1).fit(grid_coords)
231
+
232
+ for pc in tqdm(pclouds, desc='计算占据网格'):
233
+ _, indices = nn.kneighbors(pc)
234
+ indices = np.unique(indices.squeeze())
235
+ grid_counters[indices] += 1
236
+ return grid_counters
237
+
238
+ def jensen_shannon_divergence(P, Q):
239
+ P_ = P / (P.sum() + 1e-9)
240
+ Q_ = Q / (Q.sum() + 1e-9)
241
+ M = 0.5 * (P_ + Q_)
242
+ return 0.5 * (entropy(P_, M, base=2) + entropy(Q_, M, base=2))
243
+
244
+ def compute_jsd(sample_pcs, ref_pcs, resolution):
245
+ print("\n--- 开始计算 JSD ---")
246
+ sample_grid_dist = entropy_of_occupancy_grid(sample_pcs, resolution)
247
+ ref_grid_dist = entropy_of_occupancy_grid(ref_pcs, resolution)
248
+ jsd = jensen_shannon_divergence(sample_grid_dist, ref_grid_dist)
249
+ return jsd
250
+
251
+ # ==============================================================================
252
+ # 主执行函数 (Main Execution)
253
+ # ==============================================================================
254
+ if __name__ == '__main__':
255
+ # 1. 加载并处理Meshes为点云 (Numpy arrays)
256
+ sample_pcs_np = process_meshes_in_folder(GENERATED_MESH_DIR, NUM_POINTS_PER_MESH)
257
+ ref_pcs_np = process_meshes_in_folder(GT_MESH_DIR, NUM_POINTS_PER_MESH)
258
+
259
+ print(f"\n加载完成: {sample_pcs_np.shape[0]} 个生成点云, {ref_pcs_np.shape[0]} 个真实点云。")
260
+ print(f"每个点云包含 {sample_pcs_np.shape[1]} 个点。")
261
+
262
+ # 2. 设置设备并转换数据为PyTorch Tensors
263
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
264
+ print(f"使用设备: {device}")
265
+
266
+ sample_pcs_torch = torch.from_numpy(sample_pcs_np).float().to(device)
267
+ ref_pcs_torch = torch.from_numpy(ref_pcs_np).float().to(device)
268
+
269
+ # 3. 计算分布度量: MMD, COV, 1-NNA (使用PyTorch)
270
+ metrics_results = compute_mmd_cov_1nna(sample_pcs_torch, ref_pcs_torch, BATCH_SIZE)
271
+
272
+ # 4. 计算成对几何度量: CD, HD (使用PyTorch)
273
+ cd_hd_results = compute_cd_hd(sample_pcs_torch, ref_pcs_torch, BATCH_SIZE)
274
+ metrics_results.update(cd_hd_results) # 合并结果
275
+
276
+ # 5. 计算JSD (使用Numpy)
277
+ jsd_result = compute_jsd(sample_pcs_np, ref_pcs_np, JSD_RESOLUTION)
278
+
279
+ # 6. 打印最终结果
280
+ print("\n==================================================")
281
+ print(" 评估结果")
282
+ print("==================================================")
283
+
284
+ print("\n--- 分布质量与多样性 (Distribution Metrics) ---")
285
+ # MMD: 越低越好 (质量)
286
+ print(f"{'lgan_mmd-CD':<12s}: {metrics_results['lgan_mmd-CD'].item():.6f} (↓ Lower is better)")
287
+ # COV: 越高越好 (多样性)
288
+ print(f"{'lgan_cov-CD':<12s}: {metrics_results['lgan_cov-CD'].item():.6f} (↑ Higher is better)")
289
+ # 1-NNA: 越接近0.5越好 (真实性)
290
+ print(f"{'1-NNA-CD':<12s}: {metrics_results['1-NNA-CD'].item():.6f} (→ Closer to 0.5 is better)")
291
+ # JSD: 越低越好 (分布相似性)
292
+ print(f"{'JSD':<12s}: {jsd_result:.6f} (↓ Lower is better)")
293
+
294
+ print("\n--- 平均几何保真度 (Average Geometric Fidelity) ---")
295
+ # CD: 越低越好
296
+ print(f"{'Chamfer-L2':<12s}: {metrics_results['Chamfer-L2'].item():.6f} (↓ Lower is better)")
297
+ # HD: 越低越好
298
+ print(f"{'Hausdorff':<12s}: {metrics_results['Hausdorff'].item():.6f} (↓ Lower is better)")
299
+
300
+ print("==================================================")
metric_cd.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+ import trimesh
6
+ from tqdm import tqdm
7
+
8
+ # =====================================================
9
+ # 🔹 Mesh归一化函数
10
+ # =====================================================
11
+ def normalize_to_unit_sphere(mesh: trimesh.Trimesh) -> trimesh.Trimesh:
12
+ """将mesh平移到原点并缩放到单位球内"""
13
+ vertices = mesh.vertices
14
+ centroid = vertices.mean(axis=0)
15
+ vertices = vertices - centroid
16
+ scale = np.max(np.linalg.norm(vertices, axis=1))
17
+ vertices = vertices / scale
18
+ mesh.vertices = vertices
19
+ return mesh
20
+
21
+ def normalize_to_unit_cube(mesh: trimesh.Trimesh) -> trimesh.Trimesh:
22
+ """将mesh平移并缩放到[-1,1]^3单位立方体内"""
23
+ bbox_min, bbox_max = mesh.bounds
24
+ center = (bbox_min + bbox_max) / 2
25
+ scale = (bbox_max - bbox_min).max() / 2
26
+ mesh.vertices = (mesh.vertices - center) / scale
27
+ return mesh
28
+
29
+ # =====================================================
30
+ # 🔹 点云采样函数 + 返回面片数
31
+ # =====================================================
32
+ def sample_points_from_mesh(mesh_path: str, num_points: int, normalize: str = "none"):
33
+ """
34
+ 从mesh文件采样点云,并可选归一化。
35
+ 返回: (points: Tensor, face_count: int)
36
+ """
37
+ try:
38
+ mesh = trimesh.load(mesh_path, force='mesh', process=False)
39
+ if normalize == "sphere":
40
+ mesh = normalize_to_unit_sphere(mesh)
41
+ elif normalize == "cube":
42
+ mesh = normalize_to_unit_cube(mesh)
43
+ points, _ = trimesh.sample.sample_surface(mesh, num_points)
44
+ face_count = len(mesh.faces)
45
+ return torch.from_numpy(points).float(), face_count
46
+ except Exception as e:
47
+ print(f"[-] 警告:加载或采样文件失败 {mesh_path}。错误: {e}")
48
+ return None, 0
49
+
50
+ # =====================================================
51
+ # 🔹 Chamfer Distance 计算函数
52
+ # =====================================================
53
+ def find_minimum_cd_batched(gen_pc: torch.Tensor, gt_pcs_batch: torch.Tensor):
54
+ """计算生成点云到一批GT点云的最小CD及对应索引"""
55
+ gen_pc_batch = gen_pc.unsqueeze(0).expand(gt_pcs_batch.size(0), -1, -1)
56
+ dist_matrix = torch.cdist(gen_pc_batch, gt_pcs_batch)
57
+ min_dist_gen_to_gt = dist_matrix.min(2).values.mean(1)
58
+ min_dist_gt_to_gen = dist_matrix.min(1).values.mean(1)
59
+ cd_scores_for_one_gen = min_dist_gen_to_gt + min_dist_gt_to_gen
60
+ min_cd, min_idx = cd_scores_for_one_gen.min(0)
61
+ return min_cd.item(), min_idx.item()
62
+
63
+ # =====================================================
64
+ # 🔹 主流程
65
+ # =====================================================
66
+ def main(args):
67
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
+ print(f"[*] 使用设备: {device}")
69
+ print(f"[*] 归一化模式: {args.normalize}")
70
+
71
+ # --- Step 1: 加载GT网格并采样 ---
72
+ print("[*] 正在预加载并采样所有GT网格...")
73
+ gt_files = sorted([f for f in os.listdir(args.gt_dir) if f.endswith(('.obj', '.ply', '.off'))])
74
+ if not gt_files:
75
+ print(f"[-] 错误: GT目录中未找到mesh文件: {args.gt_dir}")
76
+ return
77
+
78
+ gt_point_clouds, gt_faces_counts = [], []
79
+ for gt_filename in tqdm(gt_files, desc="预处理GT网格"):
80
+ gt_filepath = os.path.join(args.gt_dir, gt_filename)
81
+ pc, fnum = sample_points_from_mesh(gt_filepath, args.num_points, args.normalize)
82
+ if pc is not None:
83
+ gt_point_clouds.append(pc.to(device))
84
+ gt_faces_counts.append(fnum)
85
+
86
+ if not gt_point_clouds:
87
+ print("[-] 错误: 无法从任何GT文件采样点云。")
88
+ return
89
+
90
+ print(f"[*] 成功加载 {len(gt_point_clouds)} 个GT点云。")
91
+
92
+ # --- Step 2: 遍历生成的网格 ---
93
+ gen_files = sorted([f for f in os.listdir(args.generated_dir) if f.endswith(('.obj', '.ply', '.off'))])
94
+ if not gen_files:
95
+ print(f"[-] 错误: 生成目录中未找到mesh文件: {args.generated_dir}")
96
+ return
97
+
98
+ all_min_cd_scores = []
99
+ face_ratios = []
100
+ pred_faces_all = []
101
+ gt_faces_matched = []
102
+
103
+ for gen_filename in tqdm(gen_files, desc="评估生成的网格"):
104
+ gen_filepath = os.path.join(args.generated_dir, gen_filename)
105
+ gen_pc, gen_face_count = sample_points_from_mesh(gen_filepath, args.num_points, args.normalize)
106
+ if gen_pc is None:
107
+ continue
108
+
109
+ gen_pc = gen_pc.to(device)
110
+ batch_size = args.batch_size
111
+ min_cd_for_this_gen = float('inf')
112
+ matched_gt_idx = -1
113
+
114
+ for i in range(0, len(gt_point_clouds), batch_size):
115
+ gt_pcs_batch = torch.stack(gt_point_clouds[i:i + batch_size])
116
+ min_cd_in_batch, idx_in_batch = find_minimum_cd_batched(gen_pc, gt_pcs_batch)
117
+ if min_cd_in_batch < min_cd_for_this_gen:
118
+ min_cd_for_this_gen = min_cd_in_batch
119
+ matched_gt_idx = i + idx_in_batch
120
+
121
+ all_min_cd_scores.append(min_cd_for_this_gen)
122
+ if matched_gt_idx >= 0:
123
+ gt_face_count = gt_faces_counts[matched_gt_idx]
124
+ face_ratio = gen_face_count / gt_face_count if gt_face_count > 0 else 0
125
+ face_ratios.append(face_ratio)
126
+ pred_faces_all.append(gen_face_count)
127
+ gt_faces_matched.append(gt_face_count)
128
+ if not args.quiet:
129
+ print(f" -> {gen_filename}: 最小CD={min_cd_for_this_gen:.6f}, Pred面数={gen_face_count}, GT面数={gt_face_count}, 比值={face_ratio:.3f}")
130
+
131
+ # --- Step 3: 汇总 ---
132
+ if not all_min_cd_scores:
133
+ print("\n[-] 评估结束,但没有成功处理任何网格。")
134
+ else:
135
+ mean_min_cd = np.mean(all_min_cd_scores)
136
+ mean_face_ratio = np.mean(face_ratios) if face_ratios else 0
137
+ mean_pred_faces = np.mean(pred_faces_all) if pred_faces_all else 0
138
+ mean_gt_faces = np.mean(gt_faces_matched) if gt_faces_matched else 0
139
+
140
+ print("\n" + "="*70)
141
+ print(f"[*] 评估完成 (基于最小CD匹配)")
142
+ print(f"[*] 共评估 {len(all_min_cd_scores)} 个生成网格")
143
+ print(f"[*] 平均最小倒角距离 (Mean Min CD): {mean_min_cd:.6f}")
144
+ print(f"[*] 平均Pred面片数: {mean_pred_faces:.1f}")
145
+ print(f"[*] 平均GT面片数: {mean_gt_faces:.1f}")
146
+ print(f"[*] 平均面片比 (Pred/GT): {mean_face_ratio:.3f}")
147
+ print("="*70)
148
+
149
+
150
+ # =====================================================
151
+ # 🔹 命令行接口
152
+ # =====================================================
153
+ if __name__ == "__main__":
154
+ parser = argparse.ArgumentParser(description="评估生成mesh与GT集合的最小Chamfer Distance及面片数比")
155
+
156
+ parser.add_argument("--generated_dir", type=str, required=True, help="生成的mesh文件夹路径")
157
+ parser.add_argument("--gt_dir", type=str, required=True, help="GT网格文件夹路径")
158
+ parser.add_argument("--num_points", type=int, default=10000, help="每个mesh采样点数")
159
+ parser.add_argument("--batch_size", type=int, default=16, help="与多少个GT点云进行批处理比较")
160
+ parser.add_argument("--normalize", type=str, default="none", choices=["none", "sphere", "cube"], help="归一化模式: none | sphere | cube")
161
+ parser.add_argument("--quiet", action="store_true", help="静默模式,只输出最终平均CD")
162
+
163
+ args = parser.parse_args()
164
+ main(args)
165
+
166
+
167
+
168
+
169
+ '''
170
+ # 不归一化
171
+ python metric_cd.py \
172
+ --generated_dir /root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs \
173
+ --gt_dir /root/Trisf/abalation_post_processing/gt_mesh \
174
+ --num_points 4096 \
175
+ --normalize none
176
+
177
+ # 归一化到单位球
178
+ python metric_cd.py \
179
+ --generated_dir /root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs/0.8_1.5 \
180
+ --gt_dir /root/Trisf/abalation_post_processing/gt_mesh \
181
+ --num_points 4096 \
182
+ --normalize sphere
183
+
184
+ # 归一化到单位立方体
185
+ python metric_cd.py \
186
+ --generated_dir /root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs \
187
+ --gt_dir /root/Trisf/abalation_post_processing/gt_mesh \
188
+ --num_points 4096 \
189
+ --normalize cube
190
+ '''
query_point.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import einsum
4
+ import torch.nn.functional as F
5
+ from functools import partial
6
+ from timm.models.layers import DropPath
7
+ from einops import rearrange, repeat
8
+
9
+ # ---- PE: NeRF-style Position Encoding ----
10
+ class Embedder:
11
+ def __init__(self, **kwargs):
12
+ self.kwargs = kwargs
13
+ self.create_embedding_fn()
14
+
15
+ def create_embedding_fn(self):
16
+ embed_fns = []
17
+ d = self.kwargs['input_dims']
18
+ out_dim = 0
19
+ if self.kwargs['include_input']:
20
+ embed_fns.append(self.identity_fn)
21
+ out_dim += d
22
+
23
+ max_freq = self.kwargs['max_freq_log2']
24
+ N_freqs = self.kwargs['num_freqs']
25
+
26
+ if self.kwargs['log_sampling']:
27
+ freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
28
+ else:
29
+ freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
30
+
31
+ for freq in freq_bands:
32
+ for p_fn in self.kwargs['periodic_fns']:
33
+ embed_fns.append(partial(self.periodic_fn, p_fn=p_fn, freq=freq))
34
+ out_dim += d
35
+
36
+ self.embed_fns = embed_fns
37
+ self.out_dim = out_dim
38
+
39
+ def identity_fn(self, x):
40
+ return x
41
+
42
+ def periodic_fn(self, x, p_fn, freq):
43
+ return p_fn(x * freq)
44
+
45
+ def embed(self, inputs):
46
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
47
+
48
+ def get_embedder(multires, i=0):
49
+ if i == -1:
50
+ return nn.Identity(), 1
51
+
52
+ embed_kwargs = {
53
+ 'include_input': True,
54
+ 'input_dims': 1,
55
+ 'max_freq_log2': multires-1,
56
+ 'num_freqs': multires,
57
+ 'log_sampling': True,
58
+ 'periodic_fns': [torch.sin, torch.cos],
59
+ }
60
+
61
+ embedder_obj = Embedder(**embed_kwargs)
62
+ embed = embedder_obj.embed
63
+ return embed, embedder_obj.out_dim
64
+
65
+
66
+ class PE_NeRF(nn.Module):
67
+ def __init__(self, out_channels=512, multires=10):
68
+ super().__init__()
69
+
70
+ self.multires = multires
71
+ self.embed_fn, embed_dim_per_dim = get_embedder(multires) # per-dim embed
72
+ self.embed_dim = embed_dim_per_dim * 3 # since 3D: x, y, z
73
+
74
+ self.coor_embed = nn.Sequential(
75
+ nn.Linear(self.embed_dim, 256),
76
+ nn.GELU(),
77
+ nn.Linear(256, out_channels)
78
+ )
79
+
80
+ def forward(self, vertices: torch.Tensor) -> torch.Tensor:
81
+ """
82
+ Args:
83
+ vertices: [B, 3] or [N, 3], coordinates in [-0.5, 0.5]
84
+ Returns:
85
+ encoded: [B, out_channels * 3]
86
+ """
87
+ x_embed = self.embed_fn(vertices[..., 0:1]) # [N, D]
88
+ y_embed = self.embed_fn(vertices[..., 1:2])
89
+ z_embed = self.embed_fn(vertices[..., 2:3])
90
+
91
+ pos_enc = torch.cat([x_embed, y_embed, z_embed], dim=-1) # [N, D * 3]
92
+
93
+ return self.coor_embed(pos_enc)
94
+
95
+ def exists(val):
96
+ return val is not None
97
+
98
+ def default(val, d):
99
+ return val if exists(val) else d
100
+
101
+ # ---- Attention & FF blocks ----
102
+ class GEGLU(nn.Module):
103
+ def forward(self, x):
104
+ x, gate = x.chunk(2, dim=-1)
105
+ return x * F.gelu(gate)
106
+
107
+
108
+ class FeedForward(nn.Module):
109
+ def __init__(self, dim, mult=4):
110
+ super().__init__()
111
+ self.net = nn.Sequential(
112
+ nn.Linear(dim, dim * mult * 2),
113
+ GEGLU(),
114
+ nn.Linear(dim * mult, dim)
115
+ )
116
+
117
+ def forward(self, x):
118
+ return self.net(x)
119
+
120
+
121
+ class PreNorm(nn.Module):
122
+ def __init__(self, dim, fn, context_dim = None):
123
+ super().__init__()
124
+ self.fn = fn
125
+ self.norm = nn.LayerNorm(dim)
126
+ self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
127
+
128
+ def forward(self, x, **kwargs):
129
+ x = self.norm(x)
130
+
131
+ if exists(self.norm_context):
132
+ context = kwargs['context']
133
+ normed_context = self.norm_context(context)
134
+ kwargs.update(context = normed_context)
135
+
136
+ return self.fn(x, **kwargs)
137
+
138
+
139
+ class Attention(nn.Module):
140
+ def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0):
141
+ super().__init__()
142
+ inner_dim = dim_head * heads
143
+ context_dim = default(context_dim, query_dim)
144
+ self.scale = dim_head ** -0.5
145
+ self.heads = heads
146
+
147
+ self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
148
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
149
+ self.to_out = nn.Linear(inner_dim, query_dim)
150
+
151
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
152
+
153
+ def forward(self, x, context = None, mask = None):
154
+ h = self.heads
155
+
156
+ q = self.to_q(x)
157
+ context = default(context, x)
158
+ k, v = self.to_kv(context).chunk(2, dim = -1)
159
+
160
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
161
+
162
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
163
+
164
+ if exists(mask):
165
+ mask = rearrange(mask, 'b ... -> b (...)')
166
+ max_neg_value = -torch.finfo(sim.dtype).max
167
+ mask = repeat(mask, 'b j -> (b h) () j', h = h)
168
+ sim.masked_fill_(~mask, max_neg_value)
169
+
170
+ # attention, what we cannot get enough of
171
+ attn = sim.softmax(dim = -1)
172
+
173
+ out = einsum('b i j, b j d -> b i d', attn, v)
174
+ out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
175
+ return self.drop_path(self.to_out(out))
176
+
177
+ class QueryPointDecoder(nn.Module):
178
+ def __init__(self, query_dim=1536, context_dim=512, output_dim=1, depth=8,
179
+ using_nerf=True, quantize_bits=10, dim=512, heads=8, multires=10):
180
+ super().__init__()
181
+ self.using_nerf = using_nerf
182
+ self.depth = depth
183
+
184
+ if using_nerf:
185
+ self.pe = PE_NeRF(out_channels=query_dim, multires=multires)
186
+ else:
187
+ self.embedding_x = nn.Embedding(2**quantize_bits, query_dim // 3)
188
+ self.embedding_y = nn.Embedding(2**quantize_bits, query_dim // 3)
189
+ self.embedding_z = nn.Embedding(2**quantize_bits, query_dim // 3)
190
+ self.coord_proj = nn.Sequential(
191
+ nn.Linear(query_dim, query_dim * 4),
192
+ nn.GELU(),
193
+ nn.Linear(query_dim * 4, query_dim)
194
+ )
195
+
196
+ # self.context_proj = nn.Linear(context_dim, query_dim)
197
+ self.context_proj = nn.Linear(context_dim, dim)
198
+
199
+ self.pe_ctx = PE_NeRF(out_channels=dim, multires=multires)
200
+
201
+ self.context_self_attn_layers = nn.ModuleList([
202
+ nn.ModuleList([
203
+ PreNorm(dim, Attention(dim, dim_head=64, heads=heads)),
204
+ PreNorm(dim, FeedForward(dim))
205
+ ]) for _ in range(depth)
206
+ ])
207
+
208
+ self.cross_attn = PreNorm(dim,
209
+ Attention(dim, dim,
210
+ dim_head=dim, heads=1))
211
+ self.cross_ff = PreNorm(dim, FeedForward(dim))
212
+
213
+ self.to_outputs = nn.Linear(dim, output_dim)
214
+
215
+ def forward(self, query_points, context_feats, context_mask=None, voxels_coords=None,):
216
+ B, N, _ = query_points.shape
217
+
218
+ if self.using_nerf:
219
+ # print('query_points.min()', query_points.min())
220
+ # print('query_points.max()', query_points.max())
221
+ x = self.pe(query_points.view(-1, 3)).view(B, N, -1)
222
+ else:
223
+ embeddings = torch.cat([
224
+ self.embedding_x(query_points[..., 0]),
225
+ self.embedding_y(query_points[..., 1]),
226
+ self.embedding_z(query_points[..., 2]),
227
+ ], dim=-1)
228
+ x = self.coord_proj(embeddings)
229
+
230
+ context = self.context_proj(context_feats)
231
+
232
+ if voxels_coords is not None:
233
+ M = voxels_coords.shape[1]
234
+ normalized_coords = 2.0 * (voxels_coords.float() / 1024.) - 1.0
235
+ context += self.pe_ctx(normalized_coords.view(-1, 3)).view(B, M, -1)
236
+
237
+ attn_mask = context_mask[:, None, None, :] if context_mask is not None else None
238
+
239
+ for self_attn, ff in self.context_self_attn_layers:
240
+ context = self_attn(context, mask=attn_mask) + context
241
+ context = ff(context) + context
242
+
243
+ latents = self.cross_attn(x, context=context, mask=attn_mask)
244
+ latents = self.cross_ff(x) + latents
245
+
246
+ return self.to_outputs(latents).squeeze(-1)
247
+
248
+ if __name__ == '__main__':
249
+ torch.manual_seed(42)
250
+ model = QueryPointDecoder().cuda()
251
+ model.eval()
252
+
253
+ B, N, M = 2, 64, 20
254
+ query_pts = torch.rand(B, N, 3).cuda() - 0.5 # [-0.5, 0.5]
255
+ context_feats = torch.randn(B, M, 512).cuda()
256
+
257
+ with torch.no_grad():
258
+ logits = model(query_pts, context_feats)
259
+ print("Logits shape:", logits.shape) # [B, N, 1]
test_slat_flow_128to1024_pointnet.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import yaml
6
+ import time
7
+ from datetime import datetime
8
+ from torch.utils.data import DataLoader
9
+ from functools import partial
10
+ import torch.nn.functional as F
11
+ from torch.amp import GradScaler, autocast
12
+ from typing import *
13
+ from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig, Dinov2Model, AutoImageProcessor, Dinov2Config
14
+ import torch
15
+ import re
16
+ from utils import load_pretrained_woself
17
+
18
+ from dataset_triposf import VoxelVertexDataset_edge, collate_fn_pointnet
19
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
20
+ from vertex_encoder import VoxelFeatureEncoder_active_pointnet
21
+
22
+ from trellis.models.structured_latent_flow import SLatFlowModel
23
+ from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
24
+
25
+ from trellis.pipelines.samplers import FlowEulerSampler
26
+ from safetensors.torch import load_file
27
+ import open3d as o3d
28
+ from PIL import Image
29
+
30
+ from triposf.modules.sparse.basic import SparseTensor
31
+ from trellis.modules.sparse.basic import SparseTensor as SparseTensor_trellis
32
+
33
+ from triposf.modules.utils import DiagonalGaussianDistribution
34
+
35
+ from sklearn.decomposition import PCA
36
+ import trimesh
37
+ import torchvision.transforms as transforms
38
+
39
+ # --- Helper Functions ---
40
+ def save_colored_ply(points, colors, filename):
41
+ if len(points) == 0:
42
+ print(f"[Warning] No points to save for {filename}")
43
+ return
44
+ # Ensure colors are uint8
45
+ if colors.max() <= 1.0:
46
+ colors = (colors * 255).astype(np.uint8)
47
+ colors = colors.astype(np.uint8)
48
+
49
+ # Add Alpha if missing
50
+ if colors.shape[1] == 3:
51
+ colors = np.hstack([colors, np.full((len(colors), 1), 255, dtype=np.uint8)])
52
+
53
+ cloud = trimesh.PointCloud(points, colors=colors)
54
+ cloud.export(filename)
55
+ print(f"Saved colored point cloud to {filename}")
56
+
57
+ def normalize_to_rgb(features_3d):
58
+ min_vals = features_3d.min(axis=0)
59
+ max_vals = features_3d.max(axis=0)
60
+ range_vals = max_vals - min_vals
61
+ range_vals[range_vals == 0] = 1
62
+ normalized = (features_3d - min_vals) / range_vals
63
+ return (normalized * 255).astype(np.uint8)
64
+
65
+ class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
66
+ def __init__(self, *args, **kwargs):
67
+ super().__init__(*args, **kwargs)
68
+ self.cfg = kwargs.pop('cfg', None)
69
+ if self.cfg is None:
70
+ raise ValueError("Configuration dictionary 'cfg' must be provided.")
71
+
72
+ self.sampler = FlowEulerSampler(sigma_min=1.e-5)
73
+ self.device = torch.device("cuda")
74
+
75
+ # Based on PointNet Encoder setting
76
+ self.resolution = 128
77
+
78
+ self.condition_type = 'image'
79
+ self.is_cond = False
80
+
81
+ self.img_res = 518
82
+ self.feature_dim = self.cfg['model']['latent_dim']
83
+
84
+ self._init_components(
85
+ clip_model_path=self.cfg['training'].get('clip_model_path', None),
86
+ dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None),
87
+ vae_path=self.cfg['training']['vae_path'],
88
+ )
89
+
90
+ # Classifier head removed as it is not part of the Active Voxel pipeline
91
+
92
+ def _load_denoiser(self, denoiser_checkpoint_path):
93
+ path = denoiser_checkpoint_path
94
+ if not path or not os.path.isfile(path):
95
+ print("No valid checkpoint path provided for fine-tuning. Starting from scratch.")
96
+ return
97
+
98
+ print(f"Loading checkpoint from: {path}")
99
+ checkpoint = torch.load(path, map_location=self.device)
100
+
101
+ try:
102
+ denoiser_state_dict = checkpoint['denoiser']
103
+ # Handle DDP prefix
104
+ if next(iter(denoiser_state_dict)).startswith('module.'):
105
+ denoiser_state_dict = {k[7:]: v for k, v in denoiser_state_dict.items()}
106
+
107
+ self.denoiser.load_state_dict(denoiser_state_dict)
108
+ print("Denoiser weights loaded successfully.")
109
+ except KeyError:
110
+ print("[WARN] 'denoiser' key not found in checkpoint. Skipping.")
111
+ except Exception as e:
112
+ print(f"[ERROR] Failed to load denoiser state_dict: {e}")
113
+
114
+ def _init_components(self,
115
+ clip_model_path=None,
116
+ dinov2_model_path=None,
117
+ vae_path=None,
118
+ ):
119
+
120
+ # 1. Initialize PointNet Voxel Encoder (Matches Training)
121
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
122
+ in_channels=15,
123
+ hidden_dim=256,
124
+ out_channels=1024,
125
+ scatter_type='mean',
126
+ n_blocks=5,
127
+ resolution=128,
128
+ add_label=False,
129
+ ).to(self.device)
130
+
131
+ # 2. Initialize VAE
132
+ self.vae = VoxelVAE(
133
+ in_channels=self.cfg['model']['in_channels'],
134
+ latent_dim=self.cfg['model']['latent_dim'],
135
+ encoder_blocks=self.cfg['model']['encoder_blocks'],
136
+ decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
137
+ decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
138
+ num_heads=8,
139
+ num_head_channels=64,
140
+ mlp_ratio=4.0,
141
+ attn_mode="swin",
142
+ window_size=8,
143
+ pe_mode="ape",
144
+ use_fp16=False,
145
+ use_checkpoint=False,
146
+ qk_rms_norm=False,
147
+ using_subdivide=True,
148
+ using_attn=self.cfg['model']['using_attn'],
149
+ attn_first=self.cfg['model'].get('attn_first', True),
150
+ pred_direction=self.cfg['model'].get('pred_direction', False),
151
+ ).to(self.device)
152
+
153
+ # 3. Initialize Dataset with collate_fn_pointnet
154
+ self.dataset = VoxelVertexDataset_edge(
155
+ root_dir=self.cfg['dataset']['path'],
156
+ base_resolution=self.cfg['dataset']['base_resolution'],
157
+ min_resolution=self.cfg['dataset']['min_resolution'],
158
+ cache_dir=self.cfg['dataset']['cache_dir'],
159
+ renders_dir=self.cfg['dataset']['renders_dir'],
160
+
161
+ process_img=False,
162
+
163
+ active_voxel_res=128,
164
+ filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
165
+ cache_filter_path=self.cfg['dataset']['cache_filter_path'],
166
+ sample_type=self.cfg['dataset'].get('sample_type', 'dora'),
167
+ )
168
+
169
+ self.dataloader = DataLoader(
170
+ self.dataset,
171
+ batch_size=1,
172
+ shuffle=True,
173
+ collate_fn=partial(collate_fn_pointnet,), # Critical Change
174
+ num_workers=4,
175
+ pin_memory=True,
176
+ persistent_workers=True,
177
+ )
178
+
179
+ # 4. Load Pretrained Weights
180
+ # Assuming vae_path contains 'voxel_encoder' and 'vae'
181
+ print(f"Loading VAE/Encoder from {vae_path}")
182
+ ckpt = torch.load(vae_path, map_location='cpu')
183
+
184
+ # Load VAE
185
+ if 'vae' in ckpt:
186
+ self.vae.load_state_dict(ckpt['vae'], strict=False)
187
+ else:
188
+ self.vae.load_state_dict(ckpt) # Fallback
189
+
190
+ # Load Encoder
191
+ if 'voxel_encoder' in ckpt:
192
+ self.voxel_encoder.load_state_dict(ckpt['voxel_encoder'])
193
+ else:
194
+ print("[WARN] 'voxel_encoder' not found in checkpoint, random init (BAD for inference).")
195
+
196
+ self.voxel_encoder.eval()
197
+ self.vae.eval()
198
+
199
+ # 5. Initialize Conditioning Model
200
+ if self.condition_type == 'text':
201
+ self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
202
+ self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
203
+ elif self.condition_type == 'image':
204
+ model_name = 'dinov2_vitl14_reg'
205
+ local_repo_path = "/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main"
206
+ weights_path = "/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"
207
+
208
+ dinov2_model = torch.hub.load(
209
+ repo_or_dir=local_repo_path,
210
+ model=model_name,
211
+ source='local',
212
+ pretrained=False
213
+ )
214
+ self.condition_model = dinov2_model
215
+ self.condition_model.load_state_dict(torch.load(weights_path))
216
+
217
+ self.image_cond_model_transform = transforms.Compose([
218
+ transforms.ToTensor(),
219
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
220
+ ])
221
+ else:
222
+ raise ValueError(f"Unsupported condition type: {self.condition_type}")
223
+
224
+ self.condition_model.to(self.device).eval()
225
+
226
+ @torch.no_grad()
227
+ def encode_image(self, images) -> torch.Tensor:
228
+ if isinstance(images, torch.Tensor):
229
+ batch_tensor = images.to(self.device)
230
+ elif isinstance(images, list):
231
+ assert all(isinstance(i, Image.Image) for i in images)
232
+ image = [i.resize((518, 518), Image.LANCZOS) for i in images]
233
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
234
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
235
+ batch_tensor = torch.stack(image).to(self.device)
236
+ else:
237
+ raise ValueError(f"Unsupported type of image: {type(images)}")
238
+
239
+ if batch_tensor.shape[-2:] != (518, 518):
240
+ batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False)
241
+
242
+ features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
243
+ patchtokens = F.layer_norm(features, features.shape[-1:])
244
+ return patchtokens
245
+
246
+ def process_batch(self, batch):
247
+ preprocessed_images = batch['image']
248
+ cond_ = self.encode_image(preprocessed_images)
249
+ return cond_
250
+
251
+ def eval(self):
252
+ # Unconditional Setup
253
+ if self.is_cond == False:
254
+ if self.condition_type == 'text':
255
+ txt = ['']
256
+ encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
257
+ tokens = encoding['input_ids'].to(self.device)
258
+ with torch.no_grad():
259
+ cond_ = self.condition_model(input_ids=tokens).last_hidden_state
260
+ else:
261
+ blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
262
+ with torch.no_grad():
263
+ dummy_cond = self.encode_image([blank_img])
264
+ cond_ = torch.zeros_like(dummy_cond)
265
+ print(f"Generated unconditional image prompt (zero tensor) with shape: {cond_.shape}")
266
+
267
+ self.denoiser.eval()
268
+
269
+ # Load Denoiser Checkpoint
270
+ # Update this path to your ACTIVE VOXEL trained checkpoint
271
+ checkpoint_path = '/gemini/user/private/zhaotianhao/output_slat_flow_matching_active/ckpts/8780_complex_128to1024_rope/checkpoint_step90000_loss0_694290.pt'
272
+
273
+ self._load_denoiser(checkpoint_path)
274
+
275
+ filename = os.path.basename(checkpoint_path)
276
+ match = re.search(r'step(\d+)', filename)
277
+ step_str = match.group(1) if match else "eval"
278
+ save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_1000complex")
279
+ os.makedirs(save_dir, exist_ok=True)
280
+ print(f"Results will be saved to: {save_dir}")
281
+
282
+ for i, batch in enumerate(self.dataloader):
283
+ if i > 50: exit() # Visualize first 10
284
+
285
+ if self.is_cond and self.condition_type == 'image':
286
+ cond_ = self.process_batch(batch)
287
+
288
+ if cond_.shape[0] != 1:
289
+ cond_ = cond_.expand(batch['active_voxels_128'].shape[0], -1, -1).contiguous().to(self.device)
290
+ else:
291
+ cond_ = cond_.to(self.device)
292
+
293
+ # --- Data Retrieval (Matches collate_fn_pointnet) ---
294
+ point_cloud = batch['point_cloud_128'].to(self.device)
295
+ active_coords = batch['active_voxels_128'].to(self.device) # [N, 4]
296
+
297
+ with autocast(device_type='cuda', dtype=torch.bfloat16):
298
+ with torch.no_grad():
299
+ # 1. Encode Ground Truth Latents
300
+ active_voxel_feats = self.voxel_encoder(
301
+ p=point_cloud,
302
+ sparse_coords=active_coords,
303
+ res=128,
304
+ bbox_size=(-0.5, 0.5),
305
+ )
306
+
307
+ sparse_input = SparseTensor(
308
+ feats=active_voxel_feats,
309
+ coords=active_coords.int()
310
+ )
311
+
312
+ # Encode to get GT distribution
313
+ gt_latents, posterior = self.vae.encode(sparse_input)
314
+
315
+ print(f"Batch {i}: Active voxels: {active_coords.shape[0]}")
316
+
317
+ # 2. Generation / Sampling
318
+ # Generate noise on the SAME active coordinates
319
+ noise = SparseTensor_trellis(
320
+ coords=active_coords.int(),
321
+ feats=torch.randn(
322
+ active_coords.shape[0],
323
+ self.feature_dim,
324
+ device=self.device,
325
+ ),
326
+ )
327
+
328
+ sample_results = self.sampler.sample(
329
+ model=self.denoiser.float(),
330
+ noise=noise.to(self.device).float(),
331
+ cond=cond_.to(self.device).float(),
332
+ steps=50,
333
+ rescale_t=1.0,
334
+ verbose=True,
335
+ )
336
+
337
+ generated_sparse_tensor = sample_results.samples
338
+ generated_coords = generated_sparse_tensor.coords
339
+ generated_features = generated_sparse_tensor.feats
340
+
341
+ print('Gen features mean:', generated_features.mean().item(), 'std:', generated_features.std().item())
342
+ print('GT features mean:', gt_latents.feats.mean().item(), 'std:', gt_latents.feats.std().item())
343
+ print('MSE:', F.mse_loss(generated_features, gt_latents.feats).item())
344
+
345
+ # --- Visualization (PCA) ---
346
+ gt_feats_np = gt_latents.feats.detach().cpu().numpy()
347
+ gen_feats_np = generated_features.detach().cpu().numpy()
348
+ coords_np = active_coords[:, 1:4].detach().cpu().numpy() # x, y, z
349
+
350
+ print("Visualizing features using PCA...")
351
+ pca = PCA(n_components=3)
352
+
353
+ # Fit PCA on GT, transform both
354
+ pca.fit(gt_feats_np)
355
+ gt_feats_3d = pca.transform(gt_feats_np)
356
+ gen_feats_3d = pca.transform(gen_feats_np)
357
+
358
+ gt_colors = normalize_to_rgb(gt_feats_3d)
359
+ gen_colors = normalize_to_rgb(gen_feats_3d)
360
+
361
+ # Save PLYs
362
+ save_colored_ply(coords_np, gt_colors, os.path.join(save_dir, f"batch_{i}_gt_pca.ply"))
363
+ save_colored_ply(coords_np, gen_colors, os.path.join(save_dir, f"batch_{i}_gen_pca.ply"))
364
+
365
+ # Save Tensors for further analysis
366
+ torch.save(gt_latents, os.path.join(save_dir, f"gt_latent_{i}.pt"))
367
+
368
+ torch.save(batch, os.path.join(save_dir, f"gt_data_batch_{i}.pt"))
369
+ torch.save(sample_results.samples, os.path.join(save_dir, f"sample_latent_{i}.pt"))
370
+
371
+ if __name__ == '__main__':
372
+ torch.manual_seed(42)
373
+ config_path = "/gemini/user/private/zhaotianhao/Triposf/config_slat_flow_128to1024_pointnet_test.yaml"
374
+ with open(config_path) as f:
375
+ cfg = yaml.safe_load(f)
376
+
377
+ # Initialize Model on CPU first
378
+ diffusion_model = SLatFlowModel(
379
+ resolution=cfg['flow']['resolution'],
380
+ in_channels=cfg['flow']['in_channels'],
381
+ out_channels=cfg['flow']['out_channels'],
382
+ model_channels=cfg['flow']['model_channels'],
383
+ cond_channels=cfg['flow']['cond_channels'],
384
+ num_blocks=cfg['flow']['num_blocks'],
385
+ num_heads=cfg['flow']['num_heads'],
386
+ mlp_ratio=cfg['flow']['mlp_ratio'],
387
+ patch_size=cfg['flow']['patch_size'],
388
+ num_io_res_blocks=cfg['flow']['num_io_res_blocks'],
389
+ io_block_channels=cfg['flow']['io_block_channels'],
390
+ pe_mode=cfg['flow']['pe_mode'],
391
+ qk_rms_norm=cfg['flow']['qk_rms_norm'],
392
+ qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'],
393
+ use_fp16=cfg['flow'].get('use_fp16', False),
394
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
395
+
396
+ trainer = SLatFlowMatchingTrainer(
397
+ denoiser=diffusion_model,
398
+ t_schedule=cfg['t_schedule'],
399
+ sigma_min=cfg['sigma_min'],
400
+ cfg=cfg,
401
+ )
402
+
403
+ trainer.eval()
test_slat_flow_128to256_pointnet.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import yaml
6
+ import time
7
+ from datetime import datetime
8
+ from torch.utils.data import DataLoader
9
+ from functools import partial
10
+ import torch.nn.functional as F
11
+ from torch.amp import GradScaler, autocast
12
+ from typing import *
13
+ from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig, Dinov2Model, AutoImageProcessor, Dinov2Config
14
+ import torch
15
+ import re
16
+ from utils import load_pretrained_woself
17
+
18
+ from dataset_triposf import VoxelVertexDataset_edge, collate_fn_pointnet
19
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
20
+ from vertex_encoder import VoxelFeatureEncoder_active_pointnet
21
+
22
+ from trellis.models.structured_latent_flow import SLatFlowModel
23
+ from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
24
+
25
+ from trellis.pipelines.samplers import FlowEulerSampler
26
+ from safetensors.torch import load_file
27
+ import open3d as o3d
28
+ from PIL import Image
29
+
30
+ from triposf.modules.sparse.basic import SparseTensor
31
+ from trellis.modules.sparse.basic import SparseTensor as SparseTensor_trellis
32
+
33
+ from triposf.modules.utils import DiagonalGaussianDistribution
34
+
35
+ from sklearn.decomposition import PCA
36
+ import trimesh
37
+ import torchvision.transforms as transforms
38
+
39
+ # --- Helper Functions ---
40
+ def save_colored_ply(points, colors, filename):
41
+ if len(points) == 0:
42
+ print(f"[Warning] No points to save for {filename}")
43
+ return
44
+ # Ensure colors are uint8
45
+ if colors.max() <= 1.0:
46
+ colors = (colors * 255).astype(np.uint8)
47
+ colors = colors.astype(np.uint8)
48
+
49
+ # Add Alpha if missing
50
+ if colors.shape[1] == 3:
51
+ colors = np.hstack([colors, np.full((len(colors), 1), 255, dtype=np.uint8)])
52
+
53
+ cloud = trimesh.PointCloud(points, colors=colors)
54
+ cloud.export(filename)
55
+ print(f"Saved colored point cloud to {filename}")
56
+
57
+ def normalize_to_rgb(features_3d):
58
+ min_vals = features_3d.min(axis=0)
59
+ max_vals = features_3d.max(axis=0)
60
+ range_vals = max_vals - min_vals
61
+ range_vals[range_vals == 0] = 1
62
+ normalized = (features_3d - min_vals) / range_vals
63
+ return (normalized * 255).astype(np.uint8)
64
+
65
+ class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
66
+ def __init__(self, *args, **kwargs):
67
+ super().__init__(*args, **kwargs)
68
+ self.cfg = kwargs.pop('cfg', None)
69
+ if self.cfg is None:
70
+ raise ValueError("Configuration dictionary 'cfg' must be provided.")
71
+
72
+ self.sampler = FlowEulerSampler(sigma_min=1.e-5)
73
+ self.device = torch.device("cuda")
74
+
75
+ # Based on PointNet Encoder setting
76
+ self.resolution = 128
77
+
78
+ self.condition_type = 'image'
79
+ self.is_cond = False
80
+
81
+ self.img_res = 518
82
+ self.feature_dim = self.cfg['model']['latent_dim']
83
+
84
+ self._init_components(
85
+ clip_model_path=self.cfg['training'].get('clip_model_path', None),
86
+ dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None),
87
+ vae_path=self.cfg['training']['vae_path'],
88
+ )
89
+
90
+ # Classifier head removed as it is not part of the Active Voxel pipeline
91
+
92
+ def _load_denoiser(self, denoiser_checkpoint_path):
93
+ path = denoiser_checkpoint_path
94
+ if not path or not os.path.isfile(path):
95
+ print("No valid checkpoint path provided for fine-tuning. Starting from scratch.")
96
+ return
97
+
98
+ print(f"Loading checkpoint from: {path}")
99
+ checkpoint = torch.load(path, map_location=self.device)
100
+
101
+ try:
102
+ denoiser_state_dict = checkpoint['denoiser']
103
+ # Handle DDP prefix
104
+ if next(iter(denoiser_state_dict)).startswith('module.'):
105
+ denoiser_state_dict = {k[7:]: v for k, v in denoiser_state_dict.items()}
106
+
107
+ self.denoiser.load_state_dict(denoiser_state_dict)
108
+ print("Denoiser weights loaded successfully.")
109
+ except KeyError:
110
+ print("[WARN] 'denoiser' key not found in checkpoint. Skipping.")
111
+ except Exception as e:
112
+ print(f"[ERROR] Failed to load denoiser state_dict: {e}")
113
+
114
+ def _init_components(self,
115
+ clip_model_path=None,
116
+ dinov2_model_path=None,
117
+ vae_path=None,
118
+ ):
119
+
120
+ # 1. Initialize PointNet Voxel Encoder (Matches Training)
121
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
122
+ in_channels=15,
123
+ hidden_dim=256,
124
+ out_channels=1024,
125
+ scatter_type='mean',
126
+ n_blocks=5,
127
+ resolution=128,
128
+ add_label=False,
129
+ ).to(self.device)
130
+
131
+ # 2. Initialize VAE
132
+ self.vae = VoxelVAE(
133
+ in_channels=self.cfg['model']['in_channels'],
134
+ latent_dim=self.cfg['model']['latent_dim'],
135
+ encoder_blocks=self.cfg['model']['encoder_blocks'],
136
+ decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
137
+ decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
138
+ num_heads=8,
139
+ num_head_channels=64,
140
+ mlp_ratio=4.0,
141
+ attn_mode="swin",
142
+ window_size=8,
143
+ pe_mode="ape",
144
+ use_fp16=False,
145
+ use_checkpoint=False,
146
+ qk_rms_norm=False,
147
+ using_subdivide=True,
148
+ using_attn=self.cfg['model']['using_attn'],
149
+ attn_first=self.cfg['model'].get('attn_first', True),
150
+ pred_direction=self.cfg['model'].get('pred_direction', False),
151
+ ).to(self.device)
152
+
153
+ # 3. Initialize Dataset with collate_fn_pointnet
154
+ self.dataset = VoxelVertexDataset_edge(
155
+ root_dir=self.cfg['dataset']['path'],
156
+ base_resolution=self.cfg['dataset']['base_resolution'],
157
+ min_resolution=self.cfg['dataset']['min_resolution'],
158
+ cache_dir=self.cfg['dataset']['cache_dir'],
159
+ renders_dir=self.cfg['dataset']['renders_dir'],
160
+
161
+ process_img=False,
162
+
163
+ active_voxel_res=128,
164
+ filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
165
+ cache_filter_path=self.cfg['dataset']['cache_filter_path'],
166
+ sample_type=self.cfg['dataset'].get('sample_type', 'dora'),
167
+ )
168
+
169
+ self.dataloader = DataLoader(
170
+ self.dataset,
171
+ batch_size=1,
172
+ shuffle=True,
173
+ collate_fn=partial(collate_fn_pointnet,), # Critical Change
174
+ num_workers=4,
175
+ pin_memory=True,
176
+ persistent_workers=True,
177
+ )
178
+
179
+ # 4. Load Pretrained Weights
180
+ # Assuming vae_path contains 'voxel_encoder' and 'vae'
181
+ print(f"Loading VAE/Encoder from {vae_path}")
182
+ ckpt = torch.load(vae_path, map_location='cpu')
183
+
184
+ # Load VAE
185
+ if 'vae' in ckpt:
186
+ self.vae.load_state_dict(ckpt['vae'], strict=False)
187
+ else:
188
+ self.vae.load_state_dict(ckpt) # Fallback
189
+
190
+ # Load Encoder
191
+ if 'voxel_encoder' in ckpt:
192
+ self.voxel_encoder.load_state_dict(ckpt['voxel_encoder'])
193
+ else:
194
+ print("[WARN] 'voxel_encoder' not found in checkpoint, random init (BAD for inference).")
195
+
196
+ self.voxel_encoder.eval()
197
+ self.vae.eval()
198
+
199
+ # 5. Initialize Conditioning Model
200
+ if self.condition_type == 'text':
201
+ self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
202
+ self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
203
+ elif self.condition_type == 'image':
204
+ model_name = 'dinov2_vitl14_reg'
205
+ local_repo_path = "/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main"
206
+ weights_path = "/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"
207
+
208
+ dinov2_model = torch.hub.load(
209
+ repo_or_dir=local_repo_path,
210
+ model=model_name,
211
+ source='local',
212
+ pretrained=False
213
+ )
214
+ self.condition_model = dinov2_model
215
+ self.condition_model.load_state_dict(torch.load(weights_path))
216
+
217
+ self.image_cond_model_transform = transforms.Compose([
218
+ transforms.ToTensor(),
219
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
220
+ ])
221
+ else:
222
+ raise ValueError(f"Unsupported condition type: {self.condition_type}")
223
+
224
+ self.condition_model.to(self.device).eval()
225
+
226
+ @torch.no_grad()
227
+ def encode_image(self, images) -> torch.Tensor:
228
+ if isinstance(images, torch.Tensor):
229
+ batch_tensor = images.to(self.device)
230
+ elif isinstance(images, list):
231
+ assert all(isinstance(i, Image.Image) for i in images)
232
+ image = [i.resize((518, 518), Image.LANCZOS) for i in images]
233
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
234
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
235
+ batch_tensor = torch.stack(image).to(self.device)
236
+ else:
237
+ raise ValueError(f"Unsupported type of image: {type(images)}")
238
+
239
+ if batch_tensor.shape[-2:] != (518, 518):
240
+ batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False)
241
+
242
+ features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
243
+ patchtokens = F.layer_norm(features, features.shape[-1:])
244
+ return patchtokens
245
+
246
+ def process_batch(self, batch):
247
+ preprocessed_images = batch['image']
248
+ cond_ = self.encode_image(preprocessed_images)
249
+ return cond_
250
+
251
+ def eval(self):
252
+ # Unconditional Setup
253
+ if self.is_cond == False:
254
+ if self.condition_type == 'text':
255
+ txt = ['']
256
+ encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
257
+ tokens = encoding['input_ids'].to(self.device)
258
+ with torch.no_grad():
259
+ cond_ = self.condition_model(input_ids=tokens).last_hidden_state
260
+ else:
261
+ blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
262
+ with torch.no_grad():
263
+ dummy_cond = self.encode_image([blank_img])
264
+ cond_ = torch.zeros_like(dummy_cond)
265
+ print(f"Generated unconditional image prompt (zero tensor) with shape: {cond_.shape}")
266
+
267
+ self.denoiser.eval()
268
+
269
+ # Load Denoiser Checkpoint
270
+ # Update this path to your ACTIVE VOXEL trained checkpoint
271
+ checkpoint_path = '/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/checkpoint_step215000_loss0_332666.pt'
272
+
273
+ self._load_denoiser(checkpoint_path)
274
+
275
+ filename = os.path.basename(checkpoint_path)
276
+ match = re.search(r'step(\d+)', filename)
277
+ step_str = match.group(1) if match else "eval"
278
+ save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_1000complex")
279
+ os.makedirs(save_dir, exist_ok=True)
280
+ print(f"Results will be saved to: {save_dir}")
281
+
282
+ for i, batch in enumerate(self.dataloader):
283
+ if i > 50: exit() # Visualize first 10
284
+
285
+ if self.is_cond and self.condition_type == 'image':
286
+ cond_ = self.process_batch(batch)
287
+
288
+ if cond_.shape[0] != 1:
289
+ cond_ = cond_.expand(batch['active_voxels_128'].shape[0], -1, -1).contiguous().to(self.device)
290
+ else:
291
+ cond_ = cond_.to(self.device)
292
+
293
+ # --- Data Retrieval (Matches collate_fn_pointnet) ---
294
+ point_cloud = batch['point_cloud_128'].to(self.device)
295
+ active_coords = batch['active_voxels_128'].to(self.device) # [N, 4]
296
+
297
+ with autocast(device_type='cuda', dtype=torch.bfloat16):
298
+ with torch.no_grad():
299
+ # 1. Encode Ground Truth Latents
300
+ active_voxel_feats = self.voxel_encoder(
301
+ p=point_cloud,
302
+ sparse_coords=active_coords,
303
+ res=128,
304
+ bbox_size=(-0.5, 0.5),
305
+ )
306
+
307
+ sparse_input = SparseTensor(
308
+ feats=active_voxel_feats,
309
+ coords=active_coords.int()
310
+ )
311
+
312
+ # Encode to get GT distribution
313
+ gt_latents, posterior = self.vae.encode(sparse_input)
314
+
315
+ print(f"Batch {i}: Active voxels: {active_coords.shape[0]}")
316
+
317
+ # 2. Generation / Sampling
318
+ # Generate noise on the SAME active coordinates
319
+ noise = SparseTensor_trellis(
320
+ coords=active_coords.int(),
321
+ feats=torch.randn(
322
+ active_coords.shape[0],
323
+ self.feature_dim,
324
+ device=self.device,
325
+ ),
326
+ )
327
+
328
+ sample_results = self.sampler.sample(
329
+ model=self.denoiser.float(),
330
+ noise=noise.to(self.device).float(),
331
+ cond=cond_.to(self.device).float(),
332
+ steps=50,
333
+ rescale_t=1.0,
334
+ verbose=True,
335
+ )
336
+
337
+ generated_sparse_tensor = sample_results.samples
338
+ generated_coords = generated_sparse_tensor.coords
339
+ generated_features = generated_sparse_tensor.feats
340
+
341
+ print('Gen features mean:', generated_features.mean().item(), 'std:', generated_features.std().item())
342
+ print('GT features mean:', gt_latents.feats.mean().item(), 'std:', gt_latents.feats.std().item())
343
+ print('MSE:', F.mse_loss(generated_features, gt_latents.feats).item())
344
+
345
+ # --- Visualization (PCA) ---
346
+ gt_feats_np = gt_latents.feats.detach().cpu().numpy()
347
+ gen_feats_np = generated_features.detach().cpu().numpy()
348
+ coords_np = active_coords[:, 1:4].detach().cpu().numpy() # x, y, z
349
+
350
+ print("Visualizing features using PCA...")
351
+ pca = PCA(n_components=3)
352
+
353
+ # Fit PCA on GT, transform both
354
+ pca.fit(gt_feats_np)
355
+ gt_feats_3d = pca.transform(gt_feats_np)
356
+ gen_feats_3d = pca.transform(gen_feats_np)
357
+
358
+ gt_colors = normalize_to_rgb(gt_feats_3d)
359
+ gen_colors = normalize_to_rgb(gen_feats_3d)
360
+
361
+ # Save PLYs
362
+ save_colored_ply(coords_np, gt_colors, os.path.join(save_dir, f"batch_{i}_gt_pca.ply"))
363
+ save_colored_ply(coords_np, gen_colors, os.path.join(save_dir, f"batch_{i}_gen_pca.ply"))
364
+
365
+ # Save Tensors for further analysis
366
+ torch.save(gt_latents, os.path.join(save_dir, f"gt_latent_{i}.pt"))
367
+
368
+ torch.save(batch, os.path.join(save_dir, f"gt_data_batch_{i}.pt"))
369
+ torch.save(sample_results.samples, os.path.join(save_dir, f"sample_latent_{i}.pt"))
370
+
371
+ if __name__ == '__main__':
372
+ torch.manual_seed(42)
373
+ config_path = "/gemini/user/private/zhaotianhao/Triposf/config_slat_flow_128to256_pointnet_test.yaml"
374
+ with open(config_path) as f:
375
+ cfg = yaml.safe_load(f)
376
+
377
+ # Initialize Model on CPU first
378
+ diffusion_model = SLatFlowModel(
379
+ resolution=cfg['flow']['resolution'],
380
+ in_channels=cfg['flow']['in_channels'],
381
+ out_channels=cfg['flow']['out_channels'],
382
+ model_channels=cfg['flow']['model_channels'],
383
+ cond_channels=cfg['flow']['cond_channels'],
384
+ num_blocks=cfg['flow']['num_blocks'],
385
+ num_heads=cfg['flow']['num_heads'],
386
+ mlp_ratio=cfg['flow']['mlp_ratio'],
387
+ patch_size=cfg['flow']['patch_size'],
388
+ num_io_res_blocks=cfg['flow']['num_io_res_blocks'],
389
+ io_block_channels=cfg['flow']['io_block_channels'],
390
+ pe_mode=cfg['flow']['pe_mode'],
391
+ qk_rms_norm=cfg['flow']['qk_rms_norm'],
392
+ qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'],
393
+ use_fp16=cfg['flow'].get('use_fp16', False),
394
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
395
+
396
+ trainer = SLatFlowMatchingTrainer(
397
+ denoiser=diffusion_model,
398
+ t_schedule=cfg['t_schedule'],
399
+ sigma_min=cfg['sigma_min'],
400
+ cfg=cfg,
401
+ )
402
+
403
+ trainer.eval()
test_slat_vae_128to1024_pointnet.py ADDED
The diff for this file is too large to render. See raw diff
 
test_slat_vae_128to1024_pointnet_vae.py ADDED
The diff for this file is too large to render. See raw diff
 
test_slat_vae_128to1024_pointnet_vae_addhead.py ADDED
The diff for this file is too large to render. See raw diff
 
test_slat_vae_128to1024_pointnet_vae_head.py ADDED
@@ -0,0 +1,1339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ import numpy as np
5
+ import random
6
+ from tqdm import tqdm
7
+ from collections import defaultdict
8
+ import plotly.graph_objects as go
9
+ from plotly.subplots import make_subplots
10
+ from torch.utils.data import DataLoader, Subset
11
+ from triposf.modules.sparse.basic import SparseTensor
12
+
13
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
14
+
15
+ from vertex_encoder import VoxelFeatureEncoder_edge, VoxelFeatureEncoder_vtx, VoxelFeatureEncoder_active, VoxelFeatureEncoder_active_pointnet, ConnectionHead
16
+ from utils import load_pretrained_woself
17
+
18
+ from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
19
+
20
+
21
+ from functools import partial
22
+ import itertools
23
+ from typing import List, Tuple, Set
24
+ from collections import OrderedDict
25
+ from scipy.spatial import cKDTree
26
+ from sklearn.neighbors import KDTree
27
+
28
+ import trimesh
29
+
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import time
33
+
34
+ from sklearn.decomposition import PCA
35
+ import matplotlib.pyplot as plt
36
+
37
+ import networkx as nx
38
+
39
+ def predict_mesh_connectivity(
40
+ connection_head,
41
+ vtx_feats,
42
+ vtx_coords,
43
+ batch_size=10000,
44
+ threshold=0.5,
45
+ k_neighbors=64, # 限制每个点只检测最近的 K 个邻居,设为 -1 则全连接检测
46
+ device='cuda'
47
+ ):
48
+ """
49
+ Args:
50
+ connection_head: 训练好的 MLP 模型
51
+ vtx_feats: [N, C] 顶点特征
52
+ vtx_coords: [N, 3] 顶点坐标 (用于 KNN 筛选候选边)
53
+ batch_size: MLP 推理的 batch size
54
+ threshold: 判定连接的概率阈值
55
+ k_neighbors: K-NN 数量。如果是 None 或 -1,则检测所有 N*(N-1)/2 对。
56
+ """
57
+ num_verts = vtx_feats.shape[0]
58
+ if num_verts < 3:
59
+ return [], [] # 无法构成三角形
60
+
61
+ connection_head.eval()
62
+
63
+ # --- 1. 生成候选边 (Candidate Edges) ---
64
+ if k_neighbors is not None and k_neighbors > 0 and k_neighbors < num_verts:
65
+ # 策略 A: 局部 KNN (推荐)
66
+ # 计算距离矩阵可能会 OOM,使用分块或 KDTree/Faiss,这里用 PyTorch 的 cdist 分块简化版
67
+ # 或者直接暴力 cdist 如果 N < 10000
68
+
69
+ # 为了简单且高效,这里演示简单的 cdist (注意显存)
70
+ # 如果 N 很大 (>5000),建议使用 faiss 或 scipy.spatial.cKDTree
71
+ dist_mat = torch.cdist(vtx_coords.float(), vtx_coords.float()) # [N, N]
72
+
73
+ # 取 topk (smallest distance),排除自己
74
+ # values: [N, K], indices: [N, K]
75
+ _, indices = torch.topk(dist_mat, k=k_neighbors + 1, dim=1, largest=False)
76
+ neighbor_indices = indices[:, 1:] # 去掉第一列(自己)
77
+
78
+ # 构建 source, target 索引
79
+ src = torch.arange(num_verts, device=device).unsqueeze(1).repeat(1, k_neighbors).flatten()
80
+ dst = neighbor_indices.flatten()
81
+
82
+ # 此时得到的边是双向的 (u->v 和 v->u 可能都存在),为了效率可以去重
83
+ # 但为了利用你的 symmetric MLP,保留双向或者只保留 u < v 均可
84
+ # 这里为了简单,我们生成 u < v 的 mask
85
+ mask = src < dst
86
+ u_indices = src[mask]
87
+ v_indices = dst[mask]
88
+
89
+ else:
90
+ # 策略 B: 全连接 (O(N^2)) - 仅当 N 较小时使用
91
+ u_indices, v_indices = torch.triu_indices(num_verts, num_verts, offset=1, device=device)
92
+
93
+ # --- 2. 批量推理 ---
94
+ all_probs = []
95
+ num_candidates = u_indices.shape[0]
96
+
97
+ with torch.no_grad():
98
+ for i in range(0, num_candidates, batch_size):
99
+ end = min(i + batch_size, num_candidates)
100
+ batch_u = u_indices[i:end]
101
+ batch_v = v_indices[i:end]
102
+
103
+ feat_u = vtx_feats[batch_u]
104
+ feat_v = vtx_feats[batch_v]
105
+
106
+ # Symmetric Forward (和你训练时保持一致)
107
+ # A -> B
108
+ input_uv = torch.cat([feat_u, feat_v], dim=-1)
109
+ logits_uv = connection_head(input_uv)
110
+
111
+ # B -> A
112
+ input_vu = torch.cat([feat_v, feat_u], dim=-1)
113
+ logits_vu = connection_head(input_vu)
114
+
115
+ # Sum logits
116
+ logits = (logits_uv + logits_vu) / 2.
117
+ probs = torch.sigmoid(logits)
118
+ all_probs.append(probs)
119
+
120
+ all_probs = torch.cat(all_probs).squeeze() # [M]
121
+
122
+ # --- 3. 筛选连接边 ---
123
+ connected_mask = all_probs > threshold
124
+ final_u = u_indices[connected_mask].cpu().numpy()
125
+ final_v = v_indices[connected_mask].cpu().numpy()
126
+
127
+ edges = np.stack([final_u, final_v], axis=1) # [E, 2]
128
+
129
+ return edges
130
+
131
+ def build_triangles_from_edges(edges, num_verts):
132
+ """
133
+ 从边列表构建三角形。
134
+ 寻找图中所有的 3-Cliques (三元环)。
135
+ 这在图论中是一个经典问题,可以使用 networkx 库。
136
+ """
137
+ if len(edges) == 0:
138
+ return np.empty((0, 3), dtype=int)
139
+
140
+ G = nx.Graph()
141
+ G.add_nodes_from(range(num_verts))
142
+ G.add_edges_from(edges)
143
+
144
+ # 寻找所有的 3-cliques (三角形)
145
+ # enumerate_all_cliques 返回所有大小的 clique,我们需要过滤大小为 3 的
146
+ # 或者使用 nx.triangles ? 不,那个只返回数量
147
+ # 使用 nx.enumerate_all_cliques 效率可能较低,对于稀疏图还可以
148
+
149
+ # 更快的方法:迭代每条边 (u, v),查找 u 和 v 的公共邻居 w
150
+ triangles = []
151
+ adj = [set(G.neighbors(n)) for n in range(num_verts)]
152
+
153
+ # 为了避免重复 (u, v, w), (v, w, u)... 我们可以强制 u < v < w
154
+ # 既然 edges 已经是 u < v (如果我们之前做了 triu),则只需要找 w > v 且 w in adj[u]
155
+
156
+ # 优化算法:
157
+ for u, v in edges:
158
+ if u > v: u, v = v, u # 确保有序
159
+
160
+ # 找公共邻居
161
+ common = adj[u].intersection(adj[v])
162
+ for w in common:
163
+ if w > v: # 强制顺序 u < v < w 防止重复
164
+ triangles.append([u, v, w])
165
+
166
+ return np.array(triangles)
167
+
168
+
169
+ def downsample_voxels(
170
+ voxels: torch.Tensor,
171
+ input_resolution: int,
172
+ output_resolution: int
173
+ ) -> torch.Tensor:
174
+ if input_resolution % output_resolution != 0:
175
+ raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
176
+ f"by output_resolution ({output_resolution}).")
177
+
178
+ factor = input_resolution // output_resolution
179
+
180
+ downsampled_voxels = voxels.clone().to(torch.long)
181
+
182
+ downsampled_voxels[:, 1:] //= factor
183
+
184
+ unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
185
+ return unique_downsampled_voxels
186
+
187
+ def visualize_colored_points_ply(coords, vectors, filename):
188
+ """
189
+ 可视化点云,并用向量方向的颜色来表示,保存为 PLY 文件。
190
+
191
+ Args:
192
+ coords (torch.Tensor or np.ndarray): 3D坐标,形状为 (N, 3)。
193
+ vectors (torch.Tensor or np.ndarray): 方向向量,形状为 (N, 3)。
194
+ filename (str): 保存输出文件的名称,必须是 .ply 格式。
195
+ """
196
+ # 确保输入是 numpy 数组
197
+ if isinstance(coords, torch.Tensor):
198
+ coords = coords.detach().cpu().numpy()
199
+ if isinstance(vectors, torch.Tensor):
200
+ vectors = vectors.detach().cpu().to(torch.float32).numpy()
201
+
202
+ # 检查输入数据是否为空,防止崩溃
203
+ if coords.size == 0 or vectors.size == 0:
204
+ print(f"警告:输入数据为空,未生成 {filename} 文件。")
205
+ return
206
+
207
+ # 将向量分量从 [-1, 1] 映射到 [0, 255]
208
+ # np.clip 用于将数值限制在 -1 和 1 之间,防止颜色溢出
209
+ # (vectors + 1) 将范围从 [-1, 1] 移动到 [0, 2]
210
+ # * 127.5 将范围从 [0, 2] 缩放到 [0, 255]
211
+ colors = np.clip((vectors + 1) * 127.5, 0, 255).astype(np.uint8)
212
+
213
+ # 创建一个点云对象,并传入颜色信息
214
+ # trimesh.PointCloud 能够自动处理带颜色的点
215
+ points = trimesh.points.PointCloud(coords, colors=colors)
216
+ # 导出为 PLY 文件
217
+ points.export(filename, file_type='ply')
218
+ print(f"可视化文件已成功保存为: {filename}")
219
+
220
+
221
+ def compute_vertex_matching(pred_coords, gt_coords, threshold=1.0):
222
+ """
223
+ 使用 KDTree 最近邻 + 贪心匹配算法计算顶点匹配 (欧式距离)
224
+
225
+ 参数:
226
+ pred_coords: 预测坐标 (Tensor)
227
+ gt_coords: 真实坐标 (Tensor)
228
+ threshold: 匹配误差阈值 (默认1.0)
229
+
230
+ 返回:
231
+ matches: 匹配成功的顶点数量
232
+ match_rate: 匹配率 (基于真实顶点数)
233
+ pred_total: 预测顶点总数
234
+ gt_total: 真实顶点总数
235
+ """
236
+ # 转换为整数坐标并去重
237
+ print('len(pred_coords)', len(pred_coords))
238
+
239
+ pred_array = np.unique(pred_coords.detach().to(torch.float32).cpu().numpy(), axis=0)
240
+ gt_array = np.unique(gt_coords.detach().cpu().to(torch.float32).numpy(), axis=0)
241
+ print('len(pred_array)', len(pred_array))
242
+ pred_total = len(pred_array)
243
+ gt_total = len(gt_array)
244
+
245
+ # 如果没有点,直接返回
246
+ if pred_total == 0 or gt_total == 0:
247
+ return 0, 0.0, pred_total, gt_total
248
+
249
+ # 建立 KDTree(以 gt 为基准)
250
+ tree = KDTree(gt_array)
251
+
252
+ # 查找预测点到最近的 gt 点
253
+ dist, indices = tree.query(pred_array, k=1)
254
+ dist = dist.squeeze()
255
+ indices = indices.squeeze()
256
+
257
+ # 贪心去重:确保 1 对 1
258
+ matches = 0
259
+ used_gt = set()
260
+ for d, idx in zip(dist, indices):
261
+ if d <= threshold and idx not in used_gt:
262
+ matches += 1
263
+ used_gt.add(idx)
264
+
265
+ match_rate = matches / max(gt_total, 1)
266
+
267
+ return matches, match_rate, pred_total, gt_total
268
+
269
+
270
+ def flatten_coords_3d(coords_3d: torch.Tensor):
271
+ coords_3d_long = coords_3d #.long()
272
+
273
+ base_x = 1024
274
+ base_y = 1024 * 1024
275
+ base_z = 1024 * 1024 * 1024
276
+
277
+ flat_coords = coords_3d_long[:, 0] * base_z + \
278
+ coords_3d_long[:, 1] * base_y + \
279
+ coords_3d_long[:, 2] * base_x
280
+ return flat_coords
281
+
282
+ class Tester:
283
+ def __init__(self, ckpt_path, config_path=None, dataset_path=None):
284
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
285
+ self.ckpt_path = ckpt_path
286
+
287
+ self.config = self._load_config(config_path)
288
+ self.dataset_path = dataset_path # or self.config['dataset']['path']
289
+ checkpoint = torch.load(self.ckpt_path, map_location='cpu')
290
+ self.epoch = checkpoint.get('epoch', 0)
291
+
292
+ self._init_models()
293
+ self._init_dataset()
294
+
295
+ self.result_dir = os.path.join(os.path.dirname(ckpt_path), "evaluation_results")
296
+ os.makedirs(self.result_dir, exist_ok=True)
297
+
298
+ dataset_name_clean = os.path.basename(self.dataset_path).replace('.npz', '').replace('.npy', '')
299
+ self.output_voxel_dir = os.path.join(os.path.dirname(ckpt_path),
300
+ f"epoch_{self.epoch}_{dataset_name_clean}_voxels_0_gs")
301
+ os.makedirs(self.output_voxel_dir, exist_ok=True)
302
+
303
+ self.output_obj_dir = os.path.join(os.path.dirname(ckpt_path),
304
+ f"epoch_{self.epoch}_{dataset_name_clean}_obj_0_gs")
305
+ os.makedirs(self.output_obj_dir, exist_ok=True)
306
+
307
+ def _save_voxel_ply(self, coords: torch.Tensor, labels: torch.Tensor, filename: str):
308
+ if coords.numel() == 0:
309
+ return
310
+
311
+ coords_np = coords.cpu().to(torch.float32).numpy()
312
+ labels_np = labels.cpu().to(torch.float32).numpy()
313
+
314
+ colors = np.zeros((coords_np.shape[0], 3), dtype=np.uint8)
315
+ colors[labels_np == 0] = [255, 0, 0]
316
+ colors[labels_np == 1] = [0, 0, 255]
317
+
318
+ try:
319
+ import trimesh
320
+ point_cloud = trimesh.PointCloud(vertices=coords_np, colors=colors)
321
+ ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
322
+ point_cloud.export(ply_path)
323
+ except ImportError:
324
+ ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
325
+ with open(ply_path, 'w') as f:
326
+ f.write("ply\n")
327
+ f.write("format ascii 1.0\n")
328
+ f.write(f"element vertex {coords_np.shape[0]}\n")
329
+ f.write("property float x\n")
330
+ f.write("property float y\n")
331
+ f.write("property float z\n")
332
+ f.write("property uchar red\n")
333
+ f.write("property uchar green\n")
334
+ f.write("property uchar blue\n")
335
+ f.write("end_header\n")
336
+ for i in range(coords_np.shape[0]):
337
+ f.write(f"{coords_np[i,0]} {coords_np[i,1]} {coords_np[i,2]} {colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
338
+
339
+ def _load_config(self, config_path=None):
340
+ if config_path and os.path.exists(config_path):
341
+ with open(config_path) as f:
342
+ return yaml.safe_load(f)
343
+
344
+ ckpt_dir = os.path.dirname(self.ckpt_path)
345
+ possible_configs = [
346
+ os.path.join(ckpt_dir, "config.yaml"),
347
+ os.path.join(os.path.dirname(ckpt_dir), "config.yaml")
348
+ ]
349
+
350
+ for config_file in possible_configs:
351
+ if os.path.exists(config_file):
352
+ with open(config_file) as f:
353
+ print(f"Loaded config from: {config_file}")
354
+ return yaml.safe_load(f)
355
+
356
+ checkpoint = torch.load(self.ckpt_path, map_location='cpu')
357
+ if 'config' in checkpoint:
358
+ print("Loaded config from checkpoint")
359
+ return checkpoint['config']
360
+
361
+ raise FileNotFoundError("Could not find config_edge.yaml in checkpoint directory or parent, and config not saved in checkpoint.")
362
+
363
+ def _init_models(self):
364
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
365
+ in_channels=15,
366
+ hidden_dim=256,
367
+ out_channels=1024,
368
+ scatter_type='mean',
369
+ n_blocks=5,
370
+ resolution=128,
371
+
372
+ ).to(self.device)
373
+
374
+ self.connection_head = ConnectionHead(
375
+ channels=32 * 2,
376
+ out_channels=1,
377
+ mlp_ratio=16,
378
+ ).to(self.device)
379
+
380
+ self.vae = VoxelVAE( # abalation: VoxelVAE_1volume_dilation
381
+ in_channels=self.config['model']['in_channels'],
382
+ latent_dim=self.config['model']['latent_dim'],
383
+ encoder_blocks=self.config['model']['encoder_blocks'],
384
+ # decoder_blocks=self.config['model']['decoder_blocks'],
385
+ decoder_blocks_vtx=self.config['model']['decoder_blocks_vtx'],
386
+ decoder_blocks_edge=self.config['model']['decoder_blocks_edge'],
387
+ num_heads=8,
388
+ num_head_channels=64,
389
+ mlp_ratio=4.0,
390
+ attn_mode="swin",
391
+ window_size=8,
392
+ pe_mode="ape",
393
+ use_fp16=False,
394
+ use_checkpoint=False,
395
+ qk_rms_norm=False,
396
+ using_subdivide=True,
397
+ using_attn=self.config['model']['using_attn'],
398
+ attn_first=self.config['model'].get('attn_first', True),
399
+ pred_direction=self.config['model'].get('pred_direction', False),
400
+ ).to(self.device)
401
+
402
+ load_pretrained_woself(
403
+ checkpoint_path=self.ckpt_path,
404
+ voxel_encoder=self.voxel_encoder,
405
+ connection_head=self.connection_head,
406
+ vae=self.vae,
407
+ )
408
+ # --- 【新增】在这里添加权重检查逻辑 ---
409
+ print(f"--- 正在检查权重文件中的 NaN/Inf 值... ---")
410
+ has_nan_inf = False
411
+ if self._check_weights_for_nan_inf(self.vae, "VoxelVAE"):
412
+ has_nan_inf = True
413
+
414
+ if self._check_weights_for_nan_inf(self.voxel_encoder, "Vertex Encoder"):
415
+ has_nan_inf = True
416
+
417
+ if self._check_weights_for_nan_inf(self.connection_head, "Connection Head"):
418
+ has_nan_inf = True
419
+
420
+ if not has_nan_inf:
421
+ print("--- 权重检查通过。未发现 NaN/Inf 值。 ---")
422
+ else:
423
+ # 如果发现坏值,直接抛出异常,因为评估无法继续
424
+ raise ValueError(f"在检查点 '{self.ckpt_path}' 中发现了 NaN 或 Inf 值。请检查导致训练不稳定的权重文件。")
425
+ # --- 检查逻辑结束 ---
426
+
427
+ self.vae.eval()
428
+ self.voxel_encoder.eval()
429
+ self.connection_head.eval()
430
+
431
+ def _init_dataset(self):
432
+ self.dataset = VoxelVertexDataset_edge(
433
+ root_dir=self.dataset_path,
434
+ base_resolution=self.config['dataset']['base_resolution'],
435
+ min_resolution=self.config['dataset']['min_resolution'],
436
+ cache_dir='/gemini/user/private/zhaotianhao/dataset_cache/test_15c_dora_128to1024',
437
+ # cache_dir=self.config['dataset']['cache_dir'],
438
+ renders_dir=self.config['dataset']['renders_dir'],
439
+
440
+ # filter_active_voxels=self.config['dataset']['filter_active_voxels'],
441
+ filter_active_voxels=False,
442
+ cache_filter_path=self.config['dataset']['cache_filter_path'],
443
+
444
+ sample_type=self.config['dataset']['sample_type'],
445
+ active_voxel_res=128,
446
+ pc_sample_number=819200,
447
+
448
+ )
449
+
450
+ self.dataloader = DataLoader(
451
+ self.dataset,
452
+ batch_size=1,
453
+ shuffle=False,
454
+ collate_fn=partial(collate_fn_pointnet),
455
+ num_workers=0,
456
+ pin_memory=True,
457
+ # prefetch_factor=4,
458
+ )
459
+
460
+ def _check_weights_for_nan_inf(self, model: torch.nn.Module, model_name: str) -> bool:
461
+ """
462
+ 检查模型的所有参数中是否存在 NaN 或 Inf 值。
463
+
464
+ Args:
465
+ model (torch.nn.Module): 要检查的模型。
466
+ model_name (str): 模型的名称,用于打印日志。
467
+
468
+ Returns:
469
+ bool: 如果找到 NaN 或 Inf,则返回 True,否则返回 False。
470
+ """
471
+ found_issue = False
472
+ for name, param in model.named_parameters():
473
+ if torch.isnan(param.data).any():
474
+ print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 NaN 值!")
475
+ found_issue = True
476
+ if torch.isinf(param.data).any():
477
+ print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 Inf 值!")
478
+ found_issue = True
479
+ return found_issue
480
+
481
+
482
+ def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
483
+ """
484
+ 修改后的函数,确保一对一匹配,并优先匹配最近的点对。
485
+ """
486
+ pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
487
+ gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
488
+
489
+ pred_total = len(pred_array)
490
+ gt_total = len(gt_array)
491
+
492
+ if pred_total == 0 or gt_total == 0:
493
+ return {
494
+ 'recall': 0.0,
495
+ 'precision': 0.0,
496
+ 'f1': 0.0,
497
+ 'matches': 0,
498
+ 'pred_count': pred_total,
499
+ 'gt_count': gt_total
500
+ }
501
+
502
+ # 依然在预测点上构建KD-Tree,为每个真实点查找最近的预测点
503
+ tree = cKDTree(pred_array)
504
+ dists, pred_idxs = tree.query(gt_array, k=1)
505
+
506
+ # --- 核心修改部分 ---
507
+
508
+ # 1. 创建一个列表,包含 (距离, 真实点索引, 预测点索引)
509
+ # 这样我们就可以按距离对所有可能的匹配进行排序
510
+ possible_matches = []
511
+ for gt_idx, (dist, pred_idx) in enumerate(zip(dists, pred_idxs)):
512
+ if dist <= threshold:
513
+ possible_matches.append((dist, gt_idx, pred_idx))
514
+
515
+ # 2. 按距离从小到大排序(贪心策略)
516
+ possible_matches.sort(key=lambda x: x[0])
517
+
518
+ matches = 0
519
+ # 使用集合来跟踪已经使用过的预测点和真实点,确保一对一匹配
520
+ used_pred_indices = set()
521
+ used_gt_indices = set() # 虽然当前逻辑下gt不会重复,但加上更严谨
522
+
523
+ # 3. 遍历排序后的可能匹配,进行一对一分配
524
+ for dist, gt_idx, pred_idx in possible_matches:
525
+ # 如果这个预测点和这个真实点都还没有被使用过
526
+ if pred_idx not in used_pred_indices and gt_idx not in used_gt_indices:
527
+ matches += 1
528
+ used_pred_indices.add(pred_idx)
529
+ used_gt_indices.add(gt_idx)
530
+
531
+ # --- 修改结束 ---
532
+
533
+ # matches 现在是真正的 True Positives 数量,它绝不会超过 pred_total 或 gt_total
534
+ recall = matches / gt_total if gt_total > 0 else 0.0
535
+ precision = matches / pred_total if pred_total > 0 else 0.0
536
+
537
+ # 计算F1时,使用标准的 Precision 和 Recall 定义
538
+ if (precision + recall) == 0:
539
+ f1 = 0.0
540
+ else:
541
+ f1 = 2 * (precision * recall) / (precision + recall)
542
+
543
+ return {
544
+ 'recall': recall,
545
+ 'precision': precision,
546
+ 'f1': f1,
547
+ 'matches': matches,
548
+ 'pred_count': pred_total,
549
+ 'gt_count': gt_total
550
+ }
551
+
552
+ def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
553
+ """
554
+ 一个折衷的顶点指标计算方案。
555
+ 它沿用“为每个真实点寻找最近预测点”的逻辑,
556
+ 但通过修正计算方式,确保Precision和F1值不会超过1.0。
557
+ """
558
+ # 假设 pred_coords 和 gt_coords 是 PyTorch 张量
559
+ pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
560
+ gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
561
+
562
+ pred_total = len(pred_array)
563
+ gt_total = len(gt_array)
564
+
565
+ if pred_total == 0 or gt_total == 0:
566
+ return {
567
+ 'recall': 0.0,
568
+ 'precision': 0.0,
569
+ 'f1': 0.0,
570
+ 'matches': 0,
571
+ 'pred_count': pred_total,
572
+ 'gt_count': gt_total
573
+ }
574
+
575
+ # 在预测点上构建KD-Tree,为每个真实点查找最近的预测点
576
+ tree = cKDTree(pred_array)
577
+ dists, _ = tree.query(gt_array, k=1) # 我们在这里其实不需要 pred 的索引
578
+
579
+ # 1. 计算从 gt 角度出发的匹配数 (True Positives for Recall)
580
+ # 这和您的第一个函数完全一样。
581
+ # 这个值代表了“有多少个真实点被成功找到了”。
582
+ matches_from_gt = np.sum(dists <= threshold)
583
+
584
+ # 2. 计算 Recall (召回率)
585
+ # 召回率的分母是真实点的总数,所以这里的计算是合理的。
586
+ recall = matches_from_gt / gt_total if gt_total > 0 else 0.0
587
+
588
+ # 3. 计算 Precision (精确率) - ✅ 这是核心修正点
589
+ # 精确率的分母是预测点的总数。
590
+ # 分子(True Positives)不能超过预测点的总数。
591
+ # 因此,我们取 matches_from_gt 和 pred_total 中的较小值。
592
+ # 这解决了 Precision > 1 的问题。
593
+ tp_for_precision = min(matches_from_gt, pred_total)
594
+ precision = tp_for_precision / pred_total if pred_total > 0 else 0.0
595
+
596
+ # 4. 使用标准的F1分数公式
597
+ # 您原来的 F1 公式 `2 * matches / (pred + gt)` 是 L1-Score,
598
+ # 更常用的是基于 Precision 和 Recall 的调和平均数。
599
+ if (precision + recall) == 0:
600
+ f1 = 0.0
601
+ else:
602
+ f1 = 2 * (precision * recall) / (precision + recall)
603
+
604
+ return {
605
+ 'recall': recall,
606
+ 'precision': precision,
607
+ 'f1': f1,
608
+ 'matches': matches_from_gt, # 仍然报告原始的匹配数,便于观察
609
+ 'pred_count': pred_total,
610
+ 'gt_count': gt_total
611
+ }
612
+
613
+ def _compute_chamfer_distance(self, p1: torch.Tensor, p2: torch.Tensor, one_sided: bool = False):
614
+ if len(p1) == 0 or len(p2) == 0:
615
+ return float('nan')
616
+
617
+ dist_p1_p2 = torch.min(torch.cdist(p1, p2), dim=1)[0].mean()
618
+
619
+ if one_sided:
620
+ return dist_p1_p2.item()
621
+ else:
622
+ dist_p2_p1 = torch.min(torch.cdist(p2, p1), dim=1)[0].mean()
623
+ return (dist_p1_p2 + dist_p2_p1).item() / 2
624
+
625
+ def visualize_latent_space_pca(self, sample_idx: int):
626
+ """
627
+ Encodes a sample, performs PCA on its latent features, and saves a
628
+ colored PLY file for visualization.
629
+
630
+ The position of each point in the PLY file corresponds to the spatial
631
+ location in the latent grid.
632
+
633
+ The color of each point represents the first three principal components
634
+ of its feature vector.
635
+ """
636
+ print(f"--- Starting Latent Space PCA Visualization for Sample {sample_idx} ---")
637
+ self.vae.eval()
638
+
639
+ try:
640
+ # 1. Get the latent representation for the sample
641
+ latent = self._get_latent_for_sample(sample_idx)
642
+ except ValueError as e:
643
+ print(f"Error: {e}")
644
+ return
645
+
646
+ latent_coords = latent.coords.detach().cpu().numpy()
647
+ latent_feats = latent.feats.detach().cpu().numpy()
648
+
649
+ if latent_feats.shape[0] < 3:
650
+ print(f"Warning: Not enough latent points ({latent_feats.shape[0]}) to perform PCA. Skipping.")
651
+ return
652
+
653
+ print(f"--> Performing PCA on {latent_feats.shape[0]} latent vectors of dimension {latent_feats.shape[1]}...")
654
+
655
+ # 2. Perform PCA to reduce feature dimensions to 3
656
+ pca = PCA(n_components=3)
657
+ pca_features = pca.fit_transform(latent_feats)
658
+
659
+ print(f" Explained variance ratio by 3 components: {pca.explained_variance_ratio_}")
660
+ print(f" Total explained variance: {np.sum(pca.explained_variance_ratio_):.4f}")
661
+
662
+ # 3. Normalize the PCA components to be used as RGB colors [0, 255]
663
+ # We normalize each component independently to maximize color contrast
664
+ normalized_colors = np.zeros_like(pca_features)
665
+ for i in range(3):
666
+ min_val = pca_features[:, i].min()
667
+ max_val = pca_features[:, i].max()
668
+ if max_val - min_val > 1e-6:
669
+ normalized_colors[:, i] = (pca_features[:, i] - min_val) / (max_val - min_val)
670
+ else:
671
+ normalized_colors[:, i] = 0.5 # Handle case of constant value
672
+
673
+ colors_uint8 = (normalized_colors * 255).astype(np.uint8)
674
+
675
+ # 4. Prepare spatial coordinates for the point cloud
676
+ # latent_coords is (batch_idx, x, y, z), we want the xyz part
677
+ spatial_coords = latent_coords[:, 1:]
678
+
679
+ # 5. Create and save the colored PLY file
680
+ try:
681
+ # Create a Trimesh PointCloud object
682
+ point_cloud = trimesh.points.PointCloud(vertices=spatial_coords, colors=colors_uint8)
683
+
684
+ # Define the output filename
685
+ filename = f"sample_{sample_idx}_latent_pca.ply"
686
+ ply_path = os.path.join(self.output_voxel_dir, filename)
687
+
688
+ # Export the file
689
+ point_cloud.export(ply_path)
690
+ print(f"--> Successfully saved PCA visualization to: {ply_path}")
691
+
692
+ except Exception as e:
693
+ print(f"Error during Trimesh export: {e}")
694
+ print("Please ensure 'trimesh' is installed correctly.")
695
+
696
+ def _get_latent_for_sample(self, sample_idx: int) -> SparseTensor:
697
+ """
698
+ Encodes a single sample and returns its latent representation.
699
+ """
700
+ print(f"--> Encoding sample {sample_idx} to get its latent vector...")
701
+ # Get data for the specified sample
702
+ batch_data = self.dataset[sample_idx]
703
+ if batch_data is None:
704
+ raise ValueError(f"Sample at index {sample_idx} could not be loaded.")
705
+
706
+ # Use the collate function to form a batch
707
+ batch_data = collate_fn_pointnet([batch_data])
708
+
709
+ with torch.no_grad():
710
+ # 1. Get input data and move to device
711
+ gt_vertex_voxels_1024 = batch_data['gt_vertex_voxels_1024'].to(self.device)
712
+ combined_voxels_1024 = batch_data['combined_voxels_1024'].to(self.device)
713
+ active_coords = batch_data['active_voxels_128'].to(self.device)
714
+ point_cloud = batch_data['point_cloud_128'].to(self.device)
715
+
716
+ vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
717
+ edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
718
+
719
+
720
+ active_voxel_feats = self.voxel_encoder(
721
+ p=point_cloud,
722
+ sparse_coords=active_coords,
723
+ res=128,
724
+ bbox_size=(-0.5, 0.5),
725
+
726
+ # voxel_label=active_labels,
727
+ )
728
+
729
+ sparse_input = SparseTensor(
730
+ feats=active_voxel_feats,
731
+ coords=active_coords.int()
732
+ )
733
+
734
+ # 2. Encode to get the latent representation
735
+ latent_128, posterior = self.vae.encode(sparse_input, sample_posterior=True,)
736
+ print(f" Latent for sample {sample_idx} obtained. Shape: {latent_128.feats.shape}")
737
+ return latent_128
738
+
739
+
740
+ def evaluate(self, num_samples=None, visualize=False, chamfer_threshold=0.9, threshold=1.):
741
+ total_samples = len(self.dataset)
742
+ eval_samples = min(num_samples or total_samples, total_samples)
743
+ # sample_indices = random.sample(range(total_samples), eval_samples) if num_samples else range(total_samples)
744
+ sample_indices = range(eval_samples)
745
+
746
+ eval_dataset = Subset(self.dataset, sample_indices)
747
+ eval_loader = DataLoader(
748
+ eval_dataset,
749
+ batch_size=1,
750
+ shuffle=False,
751
+ collate_fn=partial(collate_fn_pointnet),
752
+ num_workers=self.config['training']['num_workers'],
753
+ pin_memory=True,
754
+ )
755
+
756
+ per_sample_metrics = {
757
+ 'vertex': {res: [] for res in [128, 256, 512, 1024]},
758
+ 'edge': {res: [] for res in [128, 256, 512, 1024]},
759
+ 'sample_names': []
760
+ }
761
+ avg_metrics = {
762
+ 'vertex': {res: defaultdict(list) for res in [128, 256, 512, 1024]},
763
+ 'edge': {res: defaultdict(list) for res in [128, 256, 512, 1024]},
764
+ }
765
+
766
+ self.vae.eval()
767
+
768
+ for batch_idx, batch_data in enumerate(tqdm(eval_loader, desc="Evaluating")):
769
+ if batch_data is None:
770
+ continue
771
+ sample_idx = sample_indices[batch_idx]
772
+ sample_name = f'sample_{sample_idx}'
773
+ per_sample_metrics['sample_names'].append(sample_name)
774
+
775
+ # batch_save_path = f"/root/Trisf/output_slat_flow_matching_active/ckpts/8w/195000_sample_active_vis_42seed_trellis_generate/gt_data_batch_{batch_idx}.pt"
776
+ # if not os.path.exists(batch_save_path):
777
+ # print(f"Warning: Saved batch file not found: {batch_save_path}")
778
+ # continue
779
+ # batch_data = torch.load(batch_save_path, map_location=self.device)
780
+
781
+ with torch.no_grad():
782
+ # 1. Get input data
783
+ combined_voxels_1024 = batch_data['combined_voxels_1024'].to(self.device)
784
+ combined_voxel_labels_1024 = batch_data['combined_voxel_labels_1024'].to(self.device)
785
+ gt_combined_endpoints_1024 = batch_data['gt_combined_endpoints_1024'].to(self.device)
786
+ gt_combined_errors_1024 = batch_data['gt_combined_errors_1024'].to(self.device)
787
+
788
+ edge_mask = (combined_voxel_labels_1024 == 1)
789
+
790
+ gt_edge_endpoints_1024 = gt_combined_endpoints_1024[edge_mask].to(self.device)
791
+ gt_edge_errors_1024 = gt_combined_errors_1024[edge_mask].to(self.device)
792
+ gt_edge_voxels_1024 = combined_voxels_1024[edge_mask].to(self.device)
793
+
794
+ p1 = gt_edge_endpoints_1024[:, 1:4].float()
795
+ p2 = gt_edge_endpoints_1024[:, 4:7].float()
796
+
797
+ mask = ( (p1[:,0] < p2[:,0]) |
798
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
799
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
800
+
801
+ pA = torch.where(mask[:, None], p1, p2) # smaller one
802
+ pB = torch.where(mask[:, None], p2, p1) # larger one
803
+
804
+ d = pB - pA
805
+ dir_gt = F.normalize(d, dim=-1, eps=1e-6)
806
+
807
+ gt_vertex_voxels_1024 = batch_data['gt_vertex_voxels_1024'].to(self.device).int()
808
+
809
+ vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
810
+
811
+ edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
812
+ edge_512 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=512)
813
+ edge_256 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=256)
814
+ edge_1024 = combined_voxels_1024
815
+
816
+ print('vtx_128.shape', vtx_128.shape)
817
+ print('edge_128.shape', edge_128.shape)
818
+
819
+ gt_edge_voxels_list = [
820
+ edge_128,
821
+ edge_256,
822
+ edge_512,
823
+ edge_1024,
824
+ ]
825
+
826
+ active_coords = batch_data['active_voxels_128'].to(self.device)
827
+ point_cloud = batch_data['point_cloud_128'].to(self.device)
828
+
829
+
830
+
831
+ active_voxel_feats = self.voxel_encoder(
832
+ p=point_cloud,
833
+ sparse_coords=active_coords,
834
+ res=128,
835
+ bbox_size=(-0.5, 0.5),
836
+
837
+ # voxel_label=active_labels,
838
+ )
839
+
840
+ sparse_input = SparseTensor(
841
+ feats=active_voxel_feats,
842
+ coords=active_coords.int()
843
+ )
844
+
845
+ latent_128, posterior = self.vae.encode(sparse_input)
846
+
847
+
848
+ # load_path = f'/root/Trisf/output_slat_flow_matching_active/ckpts/8w/195000_sample_active_vis_42seed_trellis_generate/sample_latent_{batch_idx}.pt'
849
+ # latent_128 = torch.load(load_path, map_location=self.device)
850
+
851
+ print('latent_128.feats.mean()', latent_128.feats.mean(), 'latent_128.feats.std()', latent_128.feats.std())
852
+ print('posterior.mean', posterior.mean.mean(), 'posterior.std', posterior.std.mean(), 'posterior.var', posterior.var.mean())
853
+ print('latent_128.coords.shape', latent_128.coords.shape)
854
+
855
+ # latent_128 = torch.load(f"/root/Trisf/output_slat_flow_matching/ckpts/1100_chair_sample/110000step_sample/sample_results_samples_{batch_idx}.pt", map_location=self.device)
856
+ latent_128 = SparseTensor(
857
+ coords=latent_128.coords,
858
+ feats=latent_128.feats + 0. * torch.randn_like(latent_128.feats),
859
+ )
860
+
861
+ # self.output_voxel_dir = os.path.dirname(load_path)
862
+ # self.output_obj_dir = os.path.dirname(load_path)
863
+
864
+ # 7. Decoding with separate vertex and edge processing
865
+ decoded_results = self.vae.decode(
866
+ latent_128,
867
+ gt_vertex_voxels_list=[],
868
+ gt_edge_voxels_list=[],
869
+ training=False,
870
+
871
+ inference_threshold=0.5,
872
+ vis_last_layer=False,
873
+ )
874
+
875
+ error = 0 #decoded_results[-1]['edge']['predicted_offset_feats']
876
+
877
+ if self.config['model'].get('pred_direction', False):
878
+ pred_dir = decoded_results[-1]['edge']['predicted_direction_feats']
879
+ zero_mask = (pred_dir == 0).all(dim=1) # [N],True 表示这一行全为0
880
+ num_zeros = zero_mask.sum().item()
881
+ print("Number of zero vectors:", num_zeros)
882
+
883
+ pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
884
+ print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
885
+ print('pred_dir.shape', pred_dir.shape)
886
+ if pred_edge_coords_3d.shape[-1] == 4:
887
+ pred_edge_coords_3d = pred_edge_coords_3d[:, 1:]
888
+ # visualize_directions(pred_edge_coords_3d, pred_dir, sample_ratio=0.02)
889
+ save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction.ply")
890
+ # visualize_colored_points_ply(pred_edge_coords_3d - error / 2. + 0.5, pred_dir, save_pth)
891
+ visualize_colored_points_ply(pred_edge_coords_3d, pred_dir, save_pth)
892
+
893
+ save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction_gt.ply")
894
+ # visualize_colored_points_ply((gt_edge_voxels_1024[:, 1:] - gt_edge_errors_1024[:, 1:] + 0.5), dir_gt, save_pth)
895
+ visualize_colored_points_ply((gt_edge_voxels_1024[:, 1:]), dir_gt, save_pth)
896
+
897
+
898
+ pred_vtx_coords_3d = decoded_results[-1]['vertex']['coords']
899
+ pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
900
+
901
+ pred_edge_coords_np = np.round(pred_edge_coords_3d.cpu().numpy()).astype(int)
902
+
903
+ gt_vertex_voxels_1024 = batch_data['gt_vertex_voxels_1024'][:, 1:].to(self.device)
904
+ gt_edge_voxels_1024 = batch_data['gt_edge_voxels_1024'][:, 1:].to(self.device)
905
+ gt_edge_coords_np = np.round(gt_edge_voxels_1024.cpu().numpy()).astype(int)
906
+
907
+
908
+ # Calculate metrics and save results
909
+ matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_vtx_coords_3d, gt_vertex_voxels_1024, threshold=threshold,)
910
+ print(f"\n----- Resolution {1024} vtx -----")
911
+ print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
912
+ print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
913
+
914
+ self._save_voxel_ply(pred_vtx_coords_3d / 1024., torch.zeros(len(pred_vtx_coords_3d)), f"{sample_name}_pred_vtx")
915
+ self._save_voxel_ply((pred_edge_coords_3d) / 1024, torch.zeros(len(pred_edge_coords_3d)), f"{sample_name}_pred_edge")
916
+
917
+ self._save_voxel_ply(gt_vertex_voxels_1024 / 1024, torch.zeros(len(gt_vertex_voxels_1024)), f"{sample_name}_gt_vertex")
918
+ self._save_voxel_ply((combined_voxels_1024[:, 1:]) / 1024., torch.zeros(len(gt_combined_errors_1024)), f"{sample_name}_gt_edge")
919
+
920
+
921
+ # Calculate vertex-specific metrics
922
+ matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_edge_coords_3d, combined_voxels_1024[:, 1:], threshold=threshold,)
923
+ print(f"\n----- Resolution {1024} edge -----")
924
+ print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
925
+ print('gt_edge_voxels_1024.shape', gt_edge_voxels_1024.shape)
926
+
927
+ print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
928
+ print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
929
+
930
+ pred_vertex_coords_np = np.round(pred_vtx_coords_3d.cpu().numpy()).astype(int)
931
+ pred_edges = []
932
+ gt_vertex_coords_np = np.round(gt_vertex_voxels_1024.cpu().numpy()).astype(int)
933
+ if visualize:
934
+ if pred_vtx_coords_3d.shape[-1] == 4:
935
+ pred_vtx_coords_float = pred_vtx_coords_3d[:, 1:].float()
936
+ else:
937
+ pred_vtx_coords_float = pred_vtx_coords_3d.float()
938
+
939
+ pred_vtx_feats = decoded_results[-1]['vertex']['feats']
940
+
941
+ # ==========================================
942
+ # Link Prediction & Mesh Generation
943
+ # ==========================================
944
+ print("Predicting connectivity...")
945
+
946
+ # 1. 预测边
947
+ # 注意:K_neighbors 的设置。如果是物体,64 足够了。
948
+ # 如果点非常稀疏,可能需要更大。
949
+ pred_edges = predict_mesh_connectivity(
950
+ connection_head=self.connection_head, # 或者是 self.connection_head,取决于你在哪里定义的
951
+ vtx_feats=pred_vtx_feats,
952
+ vtx_coords=pred_vtx_coords_float,
953
+ batch_size=4096,
954
+ threshold=0.5,
955
+ k_neighbors=None,
956
+ device=self.device
957
+ )
958
+ print(f"Predicted {len(pred_edges)} edges.")
959
+
960
+ # 2. 构建三角形
961
+ num_verts = pred_vtx_coords_float.shape[0]
962
+ pred_faces = build_triangles_from_edges(pred_edges, num_verts)
963
+ print(f"Constructed {len(pred_faces)} triangles.")
964
+
965
+ # 3. 保存 OBJ
966
+ import trimesh
967
+
968
+ # 坐标归一化/还原 (根据你的需求,这里假设你是 0-1024 的体素坐标)
969
+ # 如果想保存为归一化坐标:
970
+ mesh_verts = pred_vtx_coords_float.cpu().numpy() / 1024.0
971
+
972
+ # 如果有 error offset,记得加上!
973
+ # 你之前的代码好像没有对 vertex 加 offset,只对 edge 加了
974
+ # 如果 vertex 也有 offset (如 dual contouring),在这里加上
975
+
976
+ # 移动到中心 (可选)
977
+ mesh_verts = mesh_verts - 0.5
978
+
979
+ mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
980
+
981
+ # 过滤孤立点 (可选)
982
+ # mesh.remove_unreferenced_vertices()
983
+
984
+ output_obj_path = os.path.join(self.output_voxel_dir, f"{sample_name}_recon.obj")
985
+ mesh.export(output_obj_path)
986
+ print(f"Saved mesh to {output_obj_path}")
987
+
988
+ # 保存边线 (用于 Debug)
989
+ # 有时候三角形很难形成,只看边也很有用
990
+ edges_path = os.path.join(self.output_voxel_dir, f"{sample_name}_edges.ply")
991
+ # self._visualize_vertices(pred_edge_coords_np, gt_edge_coords_np, f"{sample_name}_edge_comparison")
992
+
993
+
994
+ # Process results at different resolutions
995
+ for i, res in enumerate([128, 256, 512, 1024]):
996
+ if i >= len(decoded_results):
997
+ continue
998
+
999
+ gt_key = f'gt_vertex_voxels_{res}'
1000
+ if gt_key not in batch_data:
1001
+ continue
1002
+ if i == 0:
1003
+ pred_coords_res = decoded_results[i]['vtx_sp'].coords[:, 1:].float()
1004
+ gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
1005
+ else:
1006
+ pred_coords_res = decoded_results[i]['vertex']['coords'].float()
1007
+ gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
1008
+
1009
+
1010
+ v_metrics = self._compute_vertex_metrics(pred_coords_res, gt_coords_res, threshold=threshold)
1011
+
1012
+ per_sample_metrics['vertex'][res].append({
1013
+ 'recall': v_metrics['recall'],
1014
+ 'precision': v_metrics['precision'],
1015
+ 'f1': v_metrics['f1'],
1016
+ 'num_pred': len(pred_coords_res),
1017
+ 'num_gt': len(gt_coords_res)
1018
+ })
1019
+
1020
+ avg_metrics['vertex'][res]['recall'].append(v_metrics['recall'])
1021
+ avg_metrics['vertex'][res]['precision'].append(v_metrics['precision'])
1022
+ avg_metrics['vertex'][res]['f1'].append(v_metrics['f1'])
1023
+
1024
+ gt_edge_key = f'gt_edge_voxels_{res}'
1025
+ if gt_edge_key not in batch_data:
1026
+ continue
1027
+
1028
+ if i == 0:
1029
+ pred_edge_coords_res = decoded_results[i]['edge_sp'].coords[:, 1:].float()
1030
+ # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1031
+ idx = i
1032
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1033
+ elif i == 3:
1034
+ idx = i
1035
+ #################################
1036
+ # pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() - error / 2. + 0.5
1037
+ # # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1038
+ # gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) - gt_combined_errors_1024[:, 1:].to(self.device) + 0.5
1039
+
1040
+
1041
+ pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
1042
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1043
+
1044
+ # self._save_voxel_ply(gt_edge_voxels_list[idx][:, 1:].float().to(self.device) / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res_wooffset")
1045
+ # self._save_voxel_ply(decoded_results[i]['edge']['coords'].float() / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res_wooffset")
1046
+
1047
+ else:
1048
+ idx = i
1049
+ pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
1050
+ # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1051
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1052
+
1053
+ # self._save_voxel_ply(gt_edge_coords_res / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res")
1054
+ # self._save_voxel_ply(pred_edge_coords_res / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res")
1055
+
1056
+ e_metrics = self._compute_vertex_metrics(pred_edge_coords_res, gt_edge_coords_res, threshold=threshold)
1057
+
1058
+ per_sample_metrics['edge'][res].append({
1059
+ 'recall': e_metrics['recall'],
1060
+ 'precision': e_metrics['precision'],
1061
+ 'f1': e_metrics['f1'],
1062
+ 'num_pred': len(pred_edge_coords_res),
1063
+ 'num_gt': len(gt_edge_coords_res)
1064
+ })
1065
+
1066
+ avg_metrics['edge'][res]['recall'].append(e_metrics['recall'])
1067
+ avg_metrics['edge'][res]['precision'].append(e_metrics['precision'])
1068
+ avg_metrics['edge'][res]['f1'].append(e_metrics['f1'])
1069
+
1070
+ avg_metrics_processed = {}
1071
+ for category, res_dict in avg_metrics.items():
1072
+ avg_metrics_processed[category] = {}
1073
+ for resolution, metric_dict in res_dict.items():
1074
+ avg_metrics_processed[category][resolution] = {
1075
+ metric_name: np.mean(values) if values else float('nan')
1076
+ for metric_name, values in metric_dict.items()
1077
+ }
1078
+
1079
+ result_data = {
1080
+ 'config': self.config,
1081
+ 'checkpoint': self.ckpt_path,
1082
+ 'dataset': self.dataset_path,
1083
+ 'num_samples': eval_samples,
1084
+ 'per_sample_metrics': per_sample_metrics,
1085
+ 'avg_metrics': avg_metrics_processed
1086
+ }
1087
+
1088
+ results_file_path = os.path.join(self.result_dir, f"evaluation_results_epoch{self.epoch}.yaml")
1089
+ with open(results_file_path, 'w') as f:
1090
+ yaml.dump(result_data, f, default_flow_style=False)
1091
+
1092
+ return result_data
1093
+
1094
+ def _generate_line_voxels(
1095
+ self,
1096
+ p1: torch.Tensor,
1097
+ p2: torch.Tensor
1098
+ ) -> Tuple[
1099
+ List[Tuple[int, int, int]],
1100
+ List[Tuple[torch.Tensor, torch.Tensor]],
1101
+ List[np.ndarray]
1102
+ ]:
1103
+ """
1104
+ Improved version using better sampling strategy
1105
+ """
1106
+ p1_np = p1 #.cpu().numpy()
1107
+ p2_np = p2 #.cpu().numpy()
1108
+ voxel_dict = OrderedDict()
1109
+
1110
+ # Use proper 3D line voxelization algorithm
1111
+ def bresenham_3d(p1, p2):
1112
+ """3D Bresenham's line algorithm"""
1113
+ x1, y1, z1 = np.round(p1).astype(int)
1114
+ x2, y2, z2 = np.round(p2).astype(int)
1115
+
1116
+ points = []
1117
+ dx = abs(x2 - x1)
1118
+ dy = abs(y2 - y1)
1119
+ dz = abs(z2 - z1)
1120
+
1121
+ xs = 1 if x2 > x1 else -1
1122
+ ys = 1 if y2 > y1 else -1
1123
+ zs = 1 if z2 > z1 else -1
1124
+
1125
+ # Driving axis is X
1126
+ if dx >= dy and dx >= dz:
1127
+ err_1 = 2 * dy - dx
1128
+ err_2 = 2 * dz - dx
1129
+ for i in range(dx + 1):
1130
+ points.append((x1, y1, z1))
1131
+ if err_1 > 0:
1132
+ y1 += ys
1133
+ err_1 -= 2 * dx
1134
+ if err_2 > 0:
1135
+ z1 += zs
1136
+ err_2 -= 2 * dx
1137
+ err_1 += 2 * dy
1138
+ err_2 += 2 * dz
1139
+ x1 += xs
1140
+
1141
+ # Driving axis is Y
1142
+ elif dy >= dx and dy >= dz:
1143
+ err_1 = 2 * dx - dy
1144
+ err_2 = 2 * dz - dy
1145
+ for i in range(dy + 1):
1146
+ points.append((x1, y1, z1))
1147
+ if err_1 > 0:
1148
+ x1 += xs
1149
+ err_1 -= 2 * dy
1150
+ if err_2 > 0:
1151
+ z1 += zs
1152
+ err_2 -= 2 * dy
1153
+ err_1 += 2 * dx
1154
+ err_2 += 2 * dz
1155
+ y1 += ys
1156
+
1157
+ # Driving axis is Z
1158
+ else:
1159
+ err_1 = 2 * dx - dz
1160
+ err_2 = 2 * dy - dz
1161
+ for i in range(dz + 1):
1162
+ points.append((x1, y1, z1))
1163
+ if err_1 > 0:
1164
+ x1 += xs
1165
+ err_1 -= 2 * dz
1166
+ if err_2 > 0:
1167
+ y1 += ys
1168
+ err_2 -= 2 * dz
1169
+ err_1 += 2 * dx
1170
+ err_2 += 2 * dy
1171
+ z1 += zs
1172
+
1173
+ return points
1174
+
1175
+ # Get all voxels using Bresenham algorithm
1176
+ voxel_coords = bresenham_3d(p1_np, p2_np)
1177
+
1178
+ # Add all voxels to dictionary
1179
+ for coord in voxel_coords:
1180
+ voxel_dict[tuple(coord)] = (p1, p2)
1181
+
1182
+ voxel_coords = list(voxel_dict.keys())
1183
+ endpoint_pairs = list(voxel_dict.values())
1184
+
1185
+ # --- compute error vectors ---
1186
+ error_vectors = []
1187
+ diff = p2_np - p1_np
1188
+ d_norm_sq = np.dot(diff, diff)
1189
+
1190
+ for v in voxel_coords:
1191
+ v_center = np.array(v, dtype=float) + 0.5
1192
+ if d_norm_sq == 0: # degenerate line
1193
+ closest = p1_np
1194
+ else:
1195
+ t = np.dot(v_center - p1_np, diff) / d_norm_sq
1196
+ t = np.clip(t, 0.0, 1.0)
1197
+ closest = p1_np + t * diff
1198
+ error_vectors.append(v_center - closest)
1199
+
1200
+ return voxel_coords, endpoint_pairs, error_vectors
1201
+
1202
+
1203
+ # 使用示例
1204
+ def set_seed(seed: int):
1205
+ random.seed(seed)
1206
+ np.random.seed(seed)
1207
+ torch.manual_seed(seed)
1208
+ if torch.cuda.is_available():
1209
+ torch.cuda.manual_seed(seed)
1210
+ torch.cuda.manual_seed_all(seed)
1211
+ torch.backends.cudnn.deterministic = True
1212
+ torch.backends.cudnn.benchmark = False
1213
+
1214
+ def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
1215
+ set_seed(42)
1216
+ tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
1217
+ result_data = tester.evaluate(num_samples=NUM_SAMPLES, visualize=VISUALIZE, chamfer_threshold=CHAMFER_EDGE_THRESHOLD, threshold=THRESHOLD)
1218
+
1219
+ # 生成文件名
1220
+ epoch_str = os.path.basename(ckpt_path).split('_')[1].split('.')[0]
1221
+ dataset_name = os.path.basename(os.path.normpath(dataset_path))
1222
+
1223
+ # 保存简版报告(TXT)
1224
+ summary_path = os.path.join(eval_dir, f"epoch{epoch_str}_{dataset_name}_summary_threshold{THRESHOLD}_one2one.txt")
1225
+ with open(summary_path, 'w') as f:
1226
+ # 头部信息
1227
+ f.write(f"Checkpoint: {os.path.basename(ckpt_path)}\n")
1228
+ f.write(f"Dataset: {dataset_name}\n")
1229
+ f.write(f"Evaluation Samples: {result_data['num_samples']}\n\n")
1230
+
1231
+ # 平均指标
1232
+ f.write("=== Average Metrics ===\n")
1233
+ for category, data in result_data['avg_metrics'].items():
1234
+ if isinstance(data, dict): # 处理多分辨率情况
1235
+ f.write(f"\n{category.upper()}:\n")
1236
+ for res, metrics in data.items():
1237
+ f.write(f" Resolution {res}:\n")
1238
+ for k, v in metrics.items():
1239
+ # 确保值是数字类型后再格式化
1240
+ if isinstance(v, (int, float)):
1241
+ f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
1242
+ else:
1243
+ f.write(f" {str(k).ljust(15)}: {str(v)}\n")
1244
+ else: # 处理非多分辨率情况
1245
+ f.write(f"\n{category.upper()}:\n")
1246
+ for k, v in data.items():
1247
+ if isinstance(v, (int, float)):
1248
+ f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
1249
+ else:
1250
+ f.write(f" {str(k).ljust(15)}: {str(v)}\n")
1251
+
1252
+ # 样本级详细统计
1253
+ f.write("\n\n=== Detailed Per-Sample Metrics ===\n")
1254
+ for name, vertex_metrics, edge_metrics in zip(
1255
+ result_data['per_sample_metrics']['sample_names'],
1256
+ zip(*[result_data['per_sample_metrics']['vertex'][res] for res in [128, 256, 512, 1024]]),
1257
+ zip(*[result_data['per_sample_metrics']['edge'][res] for res in [128, 256, 512, 1024]])
1258
+ ):
1259
+ # 样本标题
1260
+ f.write(f"\n◆ Sample: {name}\n")
1261
+
1262
+ # 顶点指标
1263
+ f.write(f"[Vertex Prediction]\n")
1264
+ f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
1265
+ for res, metrics in zip([128, 256, 512, 1024], vertex_metrics):
1266
+ f.write(f" {str(res).ljust(10)} "
1267
+ f"{metrics['recall']:.4f} "
1268
+ f"{metrics['precision']:.4f} "
1269
+ f"{metrics['f1']:.4f} "
1270
+ f"{metrics['num_pred']}/{metrics['num_gt']}\n")
1271
+
1272
+ # Edge指标
1273
+ f.write(f"[Edge Prediction]\n")
1274
+ f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
1275
+ for res, metrics in zip([128, 256, 512, 1024], edge_metrics):
1276
+ f.write(f" {str(res).ljust(10)} "
1277
+ f"{metrics['recall']:.4f} "
1278
+ f"{metrics['precision']:.4f} "
1279
+ f"{metrics['f1']:.4f} "
1280
+ f"{metrics['num_pred']}/{metrics['num_gt']}\n")
1281
+
1282
+ f.write("-"*60 + "\n")
1283
+
1284
+ print(f"Saved summary to: {summary_path}")
1285
+ return result_data
1286
+
1287
+
1288
+ if __name__ == '__main__':
1289
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
1290
+ evaluate_all_checkpoints = True # 设置为 True 启用范围过滤
1291
+ EPOCH_START = 14
1292
+ EPOCH_END = 460
1293
+ CHAMFER_EDGE_THRESHOLD=0.5
1294
+ NUM_SAMPLES=50
1295
+ VISUALIZE=True
1296
+ THRESHOLD=1.5
1297
+ VISUAL_FIELD=False
1298
+
1299
+ ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to1024_dir_sorted_dora_head_small_right/checkpoint_epoch14_batch5216_loss0.2745.pt'
1300
+ dataset_path = '/gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/test'
1301
+ dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized'
1302
+
1303
+ if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
1304
+ RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
1305
+ else:
1306
+ RENDERS_DIR = ''
1307
+
1308
+
1309
+ ckpt_dir = os.path.dirname(ckpt_path)
1310
+ eval_dir = os.path.join(ckpt_dir, "evaluate")
1311
+ os.makedirs(eval_dir, exist_ok=True)
1312
+
1313
+ if False:
1314
+ for i in range(NUM_SAMPLES):
1315
+ print("--- Starting Latent Space PCA Visualization ---")
1316
+ tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
1317
+ tester.visualize_latent_space_pca(sample_idx=i)
1318
+ print("--- PCA Visualization Finished ---")
1319
+
1320
+ if not evaluate_all_checkpoints:
1321
+ evaluate_checkpoint(ckpt_path, dataset_path, eval_dir)
1322
+ else:
1323
+ pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.endswith('.pt')])
1324
+
1325
+ filtered_pt_files = []
1326
+ for f in pt_files:
1327
+ try:
1328
+ parts = f.split('_')
1329
+ epoch_str = parts[1].replace('epoch', '')
1330
+ epoch = int(epoch_str)
1331
+ if EPOCH_START <= epoch <= EPOCH_END:
1332
+ filtered_pt_files.append(f)
1333
+ except Exception as e:
1334
+ print(f"Warning: Could not parse epoch from {f}: {e}")
1335
+ continue
1336
+
1337
+ for pt_file in filtered_pt_files:
1338
+ full_ckpt_path = os.path.join(ckpt_dir, pt_file)
1339
+ evaluate_checkpoint(full_ckpt_path, dataset_path, eval_dir)
test_slat_vae_128to1024_pointnet_vae_head_woca.py ADDED
The diff for this file is too large to render. See raw diff
 
test_slat_vae_128to256_pointnet_vae_head.py ADDED
@@ -0,0 +1,1349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ import numpy as np
5
+ import random
6
+ from tqdm import tqdm
7
+ from collections import defaultdict
8
+ import plotly.graph_objects as go
9
+ from plotly.subplots import make_subplots
10
+ from torch.utils.data import DataLoader, Subset
11
+ from triposf.modules.sparse.basic import SparseTensor
12
+
13
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
14
+
15
+ from vertex_encoder import VoxelFeatureEncoder_edge, VoxelFeatureEncoder_vtx, VoxelFeatureEncoder_active, VoxelFeatureEncoder_active_pointnet, ConnectionHead
16
+ from utils import load_pretrained_woself
17
+
18
+ from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
19
+
20
+
21
+ from functools import partial
22
+ import itertools
23
+ from typing import List, Tuple, Set
24
+ from collections import OrderedDict
25
+ from scipy.spatial import cKDTree
26
+ from sklearn.neighbors import KDTree
27
+
28
+ import trimesh
29
+
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import time
33
+
34
+ from sklearn.decomposition import PCA
35
+ import matplotlib.pyplot as plt
36
+
37
+ import networkx as nx
38
+
39
+ def predict_mesh_connectivity(
40
+ connection_head,
41
+ vtx_feats,
42
+ vtx_coords,
43
+ batch_size=10000,
44
+ threshold=0.5,
45
+ k_neighbors=64, # 限制每个点只检测最近的 K 个邻居,设为 -1 则全连接检测
46
+ device='cuda'
47
+ ):
48
+ """
49
+ Args:
50
+ connection_head: 训练好的 MLP 模型
51
+ vtx_feats: [N, C] 顶点特征
52
+ vtx_coords: [N, 3] 顶点坐标 (用于 KNN 筛选候选边)
53
+ batch_size: MLP 推理的 batch size
54
+ threshold: 判定连接的概率阈值
55
+ k_neighbors: K-NN 数量。如果是 None 或 -1,则检测所有 N*(N-1)/2 对。
56
+ """
57
+ num_verts = vtx_feats.shape[0]
58
+ if num_verts < 3:
59
+ return [], [] # 无法构成三角形
60
+
61
+ connection_head.eval()
62
+
63
+ # --- 1. 生成候选边 (Candidate Edges) ---
64
+ if k_neighbors is not None and k_neighbors > 0 and k_neighbors < num_verts:
65
+ # 策略 A: 局部 KNN (推荐)
66
+ # 计算距离矩阵可能会 OOM,使用分块或 KDTree/Faiss,这里用 PyTorch 的 cdist 分块简化版
67
+ # 或者直接暴力 cdist 如果 N < 10000
68
+
69
+ # 为了简单且高效,这里演示简单的 cdist (注意显存)
70
+ # 如果 N 很大 (>5000),建议使用 faiss 或 scipy.spatial.cKDTree
71
+ dist_mat = torch.cdist(vtx_coords.float(), vtx_coords.float()) # [N, N]
72
+
73
+ # 取 topk (smallest distance),排除自己
74
+ # values: [N, K], indices: [N, K]
75
+ _, indices = torch.topk(dist_mat, k=k_neighbors + 1, dim=1, largest=False)
76
+ neighbor_indices = indices[:, 1:] # 去掉第一列(自己)
77
+
78
+ # 构建 source, target 索引
79
+ src = torch.arange(num_verts, device=device).unsqueeze(1).repeat(1, k_neighbors).flatten()
80
+ dst = neighbor_indices.flatten()
81
+
82
+ # 此时得到的边是双向的 (u->v 和 v->u 可能都存在),为了效率可以去重
83
+ # 但为了利用你的 symmetric MLP,保留双向或者只保留 u < v 均可
84
+ # 这里为了简单,我们生成 u < v 的 mask
85
+ mask = src < dst
86
+ u_indices = src[mask]
87
+ v_indices = dst[mask]
88
+
89
+ else:
90
+ # 策略 B: 全连接 (O(N^2)) - 仅当 N 较小时使用
91
+ u_indices, v_indices = torch.triu_indices(num_verts, num_verts, offset=1, device=device)
92
+
93
+ # --- 2. 批量推理 ---
94
+ all_probs = []
95
+ num_candidates = u_indices.shape[0]
96
+
97
+ with torch.no_grad():
98
+ for i in range(0, num_candidates, batch_size):
99
+ end = min(i + batch_size, num_candidates)
100
+ batch_u = u_indices[i:end]
101
+ batch_v = v_indices[i:end]
102
+
103
+ feat_u = vtx_feats[batch_u]
104
+ feat_v = vtx_feats[batch_v]
105
+
106
+ # Symmetric Forward (和你训练时保持一致)
107
+ # A -> B
108
+ input_uv = torch.cat([feat_u, feat_v], dim=-1)
109
+ logits_uv = connection_head(input_uv)
110
+
111
+ # B -> A
112
+ input_vu = torch.cat([feat_v, feat_u], dim=-1)
113
+ logits_vu = connection_head(input_vu)
114
+
115
+ # Sum logits
116
+ logits = logits_uv + logits_vu
117
+ probs = torch.sigmoid(logits)
118
+ all_probs.append(probs)
119
+
120
+ all_probs = torch.cat(all_probs).squeeze() # [M]
121
+
122
+ # --- 3. 筛选连接边 ---
123
+ connected_mask = all_probs > threshold
124
+ final_u = u_indices[connected_mask].cpu().numpy()
125
+ final_v = v_indices[connected_mask].cpu().numpy()
126
+
127
+ edges = np.stack([final_u, final_v], axis=1) # [E, 2]
128
+
129
+ return edges
130
+
131
+ def build_triangles_from_edges(edges, num_verts):
132
+ """
133
+ 从边列表构建三角形。
134
+ 寻找图中所有的 3-Cliques (三元环)。
135
+ 这在图论中是一个经典问题,可以使用 networkx 库。
136
+ """
137
+ if len(edges) == 0:
138
+ return np.empty((0, 3), dtype=int)
139
+
140
+ G = nx.Graph()
141
+ G.add_nodes_from(range(num_verts))
142
+ G.add_edges_from(edges)
143
+
144
+ # 寻找所有的 3-cliques (三角形)
145
+ # enumerate_all_cliques 返回所有大小的 clique,我们需要过滤大小为 3 的
146
+ # 或者使用 nx.triangles ? 不,那个只返回数量
147
+ # 使用 nx.enumerate_all_cliques 效率可能较低,对于稀疏图还可以
148
+
149
+ # 更快的方法:迭代每条边 (u, v),查找 u 和 v 的公共邻居 w
150
+ triangles = []
151
+ adj = [set(G.neighbors(n)) for n in range(num_verts)]
152
+
153
+ # 为了避免重复 (u, v, w), (v, w, u)... 我们可以强制 u < v < w
154
+ # 既然 edges 已经是 u < v (如果我们之前做了 triu),则只需要找 w > v 且 w in adj[u]
155
+
156
+ # 优化算法:
157
+ for u, v in edges:
158
+ if u > v: u, v = v, u # 确保有序
159
+
160
+ # 找公共邻居
161
+ common = adj[u].intersection(adj[v])
162
+ for w in common:
163
+ if w > v: # 强制顺序 u < v < w 防止重复
164
+ triangles.append([u, v, w])
165
+
166
+ return np.array(triangles)
167
+
168
+ def downsample_voxels(
169
+ voxels: torch.Tensor,
170
+ input_resolution: int,
171
+ output_resolution: int
172
+ ) -> torch.Tensor:
173
+ if input_resolution % output_resolution != 0:
174
+ raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
175
+ f"by output_resolution ({output_resolution}).")
176
+
177
+ factor = input_resolution // output_resolution
178
+
179
+ downsampled_voxels = voxels.clone().to(torch.long)
180
+
181
+ downsampled_voxels[:, 1:] //= factor
182
+
183
+ unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
184
+ return unique_downsampled_voxels
185
+
186
+ def visualize_colored_points_ply(coords, vectors, filename):
187
+ """
188
+ 可视化点云,并用向量方向的颜色来表示,保存为 PLY 文件。
189
+
190
+ Args:
191
+ coords (torch.Tensor or np.ndarray): 3D坐标,形状为 (N, 3)。
192
+ vectors (torch.Tensor or np.ndarray): 方向向量,形状为 (N, 3)。
193
+ filename (str): 保存输出文件的名称,必须是 .ply 格式。
194
+ """
195
+ # 确保输入是 numpy 数组
196
+ if isinstance(coords, torch.Tensor):
197
+ coords = coords.detach().cpu().numpy()
198
+ if isinstance(vectors, torch.Tensor):
199
+ vectors = vectors.detach().cpu().to(torch.float32).numpy()
200
+
201
+ # 检查输入数据是否为空,防止崩溃
202
+ if coords.size == 0 or vectors.size == 0:
203
+ print(f"警告:输入数据为空,未生成 {filename} 文件。")
204
+ return
205
+
206
+ # 将向量分量从 [-1, 1] 映射到 [0, 255]
207
+ # np.clip 用于将数值限制在 -1 和 1 之间,防止颜色溢出
208
+ # (vectors + 1) 将范围从 [-1, 1] 移动到 [0, 2]
209
+ # * 127.5 将范围从 [0, 2] 缩放到 [0, 255]
210
+ colors = np.clip((vectors + 1) * 127.5, 0, 255).astype(np.uint8)
211
+
212
+ # 创建一个点云对象,并传入颜色信息
213
+ # trimesh.PointCloud 能够自动处理带颜色的点
214
+ points = trimesh.points.PointCloud(coords, colors=colors)
215
+ # 导出为 PLY 文件
216
+ points.export(filename, file_type='ply')
217
+ print(f"可视化文件已成功保存为: {filename}")
218
+
219
+
220
+ def compute_vertex_matching(pred_coords, gt_coords, threshold=1.0):
221
+ """
222
+ 使用 KDTree 最近邻 + 贪心匹配算法计算顶点匹配 (欧式距离)
223
+
224
+ 参数:
225
+ pred_coords: 预测坐标 (Tensor)
226
+ gt_coords: 真实坐标 (Tensor)
227
+ threshold: 匹配误差阈值 (默认1.0)
228
+
229
+ 返回:
230
+ matches: 匹配成功的顶点数量
231
+ match_rate: 匹配率 (基于真实顶点数)
232
+ pred_total: 预测顶点总数
233
+ gt_total: 真实顶点总数
234
+ """
235
+ # 转换为整数坐标并去重
236
+ print('len(pred_coords)', len(pred_coords))
237
+
238
+ pred_array = np.unique(pred_coords.detach().to(torch.float32).cpu().numpy(), axis=0)
239
+ gt_array = np.unique(gt_coords.detach().cpu().to(torch.float32).numpy(), axis=0)
240
+ print('len(pred_array)', len(pred_array))
241
+ pred_total = len(pred_array)
242
+ gt_total = len(gt_array)
243
+
244
+ # 如果没有点,直接返回
245
+ if pred_total == 0 or gt_total == 0:
246
+ return 0, 0.0, pred_total, gt_total
247
+
248
+ # 建立 KDTree(以 gt 为基准)
249
+ tree = KDTree(gt_array)
250
+
251
+ # 查找预测点到最近的 gt 点
252
+ dist, indices = tree.query(pred_array, k=1)
253
+ dist = dist.squeeze()
254
+ indices = indices.squeeze()
255
+
256
+ # 贪心去重:确保 1 对 1
257
+ matches = 0
258
+ used_gt = set()
259
+ for d, idx in zip(dist, indices):
260
+ if d <= threshold and idx not in used_gt:
261
+ matches += 1
262
+ used_gt.add(idx)
263
+
264
+ match_rate = matches / max(gt_total, 1)
265
+
266
+ return matches, match_rate, pred_total, gt_total
267
+
268
+ def flatten_coords_4d(coords_4d: torch.Tensor):
269
+ coords_4d_long = coords_4d.long()
270
+
271
+ base_x = 256
272
+ base_y = 256 * 256
273
+ base_z = 256 * 256 * 256
274
+
275
+ flat_coords = coords_4d_long[:, 0] * base_z + \
276
+ coords_4d_long[:, 1] * base_y + \
277
+ coords_4d_long[:, 2] * base_x + \
278
+ coords_4d_long[:, 3]
279
+ return flat_coords
280
+
281
+ def flatten_coords_3d(coords_3d: torch.Tensor):
282
+ coords_3d_long = coords_3d #.long()
283
+
284
+ base_x = 256
285
+ base_y = 256 * 256
286
+ base_z = 256 * 256 * 256
287
+
288
+ flat_coords = coords_3d_long[:, 0] * base_z + \
289
+ coords_3d_long[:, 1] * base_y + \
290
+ coords_3d_long[:, 2] * base_x
291
+ return flat_coords
292
+
293
+ class Tester:
294
+ def __init__(self, ckpt_path, config_path=None, dataset_path=None):
295
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
296
+ self.ckpt_path = ckpt_path
297
+
298
+ self.config = self._load_config(config_path)
299
+ self.dataset_path = dataset_path # or self.config['dataset']['path']
300
+ checkpoint = torch.load(self.ckpt_path, map_location='cpu')
301
+ self.epoch = checkpoint.get('epoch', 0)
302
+
303
+ self._init_models()
304
+ self._init_dataset()
305
+
306
+ self.result_dir = os.path.join(os.path.dirname(ckpt_path), "evaluation_results")
307
+ os.makedirs(self.result_dir, exist_ok=True)
308
+
309
+ dataset_name_clean = os.path.basename(self.dataset_path).replace('.npz', '').replace('.npy', '')
310
+ self.output_voxel_dir = os.path.join(os.path.dirname(ckpt_path),
311
+ f"epoch_{self.epoch}_{dataset_name_clean}_voxels_0_gs")
312
+ os.makedirs(self.output_voxel_dir, exist_ok=True)
313
+
314
+ self.output_obj_dir = os.path.join(os.path.dirname(ckpt_path),
315
+ f"epoch_{self.epoch}_{dataset_name_clean}_obj_0_gs")
316
+ os.makedirs(self.output_obj_dir, exist_ok=True)
317
+
318
+
319
+ def _save_voxel_ply(self, coords: torch.Tensor, labels: torch.Tensor, filename: str):
320
+ if coords.numel() == 0:
321
+ return
322
+
323
+ coords_np = coords.cpu().to(torch.float32).numpy()
324
+ labels_np = labels.cpu().to(torch.float32).numpy()
325
+
326
+ colors = np.zeros((coords_np.shape[0], 3), dtype=np.uint8)
327
+ colors[labels_np == 0] = [255, 0, 0]
328
+ colors[labels_np == 1] = [0, 0, 255]
329
+
330
+ try:
331
+ import trimesh
332
+ point_cloud = trimesh.PointCloud(vertices=coords_np, colors=colors)
333
+ ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
334
+ point_cloud.export(ply_path)
335
+ except ImportError:
336
+ ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
337
+ with open(ply_path, 'w') as f:
338
+ f.write("ply\n")
339
+ f.write("format ascii 1.0\n")
340
+ f.write(f"element vertex {coords_np.shape[0]}\n")
341
+ f.write("property float x\n")
342
+ f.write("property float y\n")
343
+ f.write("property float z\n")
344
+ f.write("property uchar red\n")
345
+ f.write("property uchar green\n")
346
+ f.write("property uchar blue\n")
347
+ f.write("end_header\n")
348
+ for i in range(coords_np.shape[0]):
349
+ f.write(f"{coords_np[i,0]} {coords_np[i,1]} {coords_np[i,2]} {colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
350
+
351
+ def _load_config(self, config_path=None):
352
+ if config_path and os.path.exists(config_path):
353
+ with open(config_path) as f:
354
+ return yaml.safe_load(f)
355
+
356
+ ckpt_dir = os.path.dirname(self.ckpt_path)
357
+ possible_configs = [
358
+ os.path.join(ckpt_dir, "config.yaml"),
359
+ os.path.join(os.path.dirname(ckpt_dir), "config.yaml")
360
+ ]
361
+
362
+ for config_file in possible_configs:
363
+ if os.path.exists(config_file):
364
+ with open(config_file) as f:
365
+ print(f"Loaded config from: {config_file}")
366
+ return yaml.safe_load(f)
367
+
368
+ checkpoint = torch.load(self.ckpt_path, map_location='cpu')
369
+ if 'config' in checkpoint:
370
+ print("Loaded config from checkpoint")
371
+ return checkpoint['config']
372
+
373
+ raise FileNotFoundError("Could not find config_edge.yaml in checkpoint directory or parent, and config not saved in checkpoint.")
374
+
375
+ def _init_models(self):
376
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
377
+ in_channels=15,
378
+ hidden_dim=256,
379
+ out_channels=1024,
380
+ scatter_type='mean',
381
+ n_blocks=5,
382
+ resolution=128,
383
+
384
+ ).to(self.device)
385
+
386
+ self.connection_head = ConnectionHead(
387
+ channels=128 * 2,
388
+ out_channels=1,
389
+ mlp_ratio=4,
390
+ ).to(self.device)
391
+
392
+ self.vae = VoxelVAE( # abalation: VoxelVAE_1volume_dilation
393
+ in_channels=self.config['model']['in_channels'],
394
+ latent_dim=self.config['model']['latent_dim'],
395
+ encoder_blocks=self.config['model']['encoder_blocks'],
396
+ # decoder_blocks=self.config['model']['decoder_blocks'],
397
+ decoder_blocks_vtx=self.config['model']['decoder_blocks_vtx'],
398
+ decoder_blocks_edge=self.config['model']['decoder_blocks_edge'],
399
+ num_heads=8,
400
+ num_head_channels=64,
401
+ mlp_ratio=4.0,
402
+ attn_mode="swin",
403
+ window_size=8,
404
+ pe_mode="ape",
405
+ use_fp16=False,
406
+ use_checkpoint=False,
407
+ qk_rms_norm=False,
408
+ using_subdivide=True,
409
+ using_attn=self.config['model']['using_attn'],
410
+ attn_first=self.config['model'].get('attn_first', True),
411
+ pred_direction=self.config['model'].get('pred_direction', False),
412
+ ).to(self.device)
413
+
414
+ load_pretrained_woself(
415
+ checkpoint_path=self.ckpt_path,
416
+ voxel_encoder=self.voxel_encoder,
417
+ connection_head=self.connection_head,
418
+ vae=self.vae,
419
+ )
420
+ # --- 【新增】在这里添加权重检查逻辑 ---
421
+ print(f"--- 正在检查权重文件中的 NaN/Inf 值... ---")
422
+ has_nan_inf = False
423
+ if self._check_weights_for_nan_inf(self.vae, "VoxelVAE"):
424
+ has_nan_inf = True
425
+
426
+ if self._check_weights_for_nan_inf(self.voxel_encoder, "Vertex Encoder"):
427
+ has_nan_inf = True
428
+
429
+ if self._check_weights_for_nan_inf(self.connection_head, "Connection Head"):
430
+ has_nan_inf = True
431
+
432
+ if not has_nan_inf:
433
+ print("--- 权重检查通过。未发现 NaN/Inf 值。 ---")
434
+ else:
435
+ # 如果发现坏值,直接抛出异常,因为评估无法继续
436
+ raise ValueError(f"在检查点 '{self.ckpt_path}' 中发现了 NaN 或 Inf 值。请检查导致训练不稳定的权重文件。")
437
+ # --- 检查逻辑结束 ---
438
+
439
+ self.vae.eval()
440
+ self.voxel_encoder.eval()
441
+ self.connection_head.eval()
442
+
443
+ def _init_dataset(self):
444
+ self.dataset = VoxelVertexDataset_edge(
445
+ root_dir=self.dataset_path,
446
+ base_resolution=self.config['dataset']['base_resolution'],
447
+ min_resolution=self.config['dataset']['min_resolution'],
448
+ # cache_dir='./dataset_cache/test_15c_dora',
449
+ cache_dir=self.config['dataset']['cache_dir'],
450
+ renders_dir=self.config['dataset']['renders_dir'],
451
+
452
+ # filter_active_voxels=self.config['dataset']['filter_active_voxels'],
453
+ filter_active_voxels=False,
454
+ cache_filter_path=self.config['dataset']['cache_filter_path'],
455
+
456
+ sample_type=self.config['dataset']['sample_type'],
457
+ active_voxel_res=128,
458
+ pc_sample_number=819200,
459
+
460
+ )
461
+
462
+ self.dataloader = DataLoader(
463
+ self.dataset,
464
+ batch_size=1,
465
+ shuffle=False,
466
+ collate_fn=partial(collate_fn_pointnet),
467
+ num_workers=0,
468
+ pin_memory=True,
469
+ # prefetch_factor=4,
470
+ )
471
+
472
+ def _check_weights_for_nan_inf(self, model: torch.nn.Module, model_name: str) -> bool:
473
+ """
474
+ 检查模型的所有参数中是否存在 NaN 或 Inf 值。
475
+
476
+ Args:
477
+ model (torch.nn.Module): 要检查的模型。
478
+ model_name (str): 模型的名称,用于打印日志。
479
+
480
+ Returns:
481
+ bool: 如果找到 NaN 或 Inf,则返回 True,否则返回 False。
482
+ """
483
+ found_issue = False
484
+ for name, param in model.named_parameters():
485
+ if torch.isnan(param.data).any():
486
+ print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 NaN 值!")
487
+ found_issue = True
488
+ if torch.isinf(param.data).any():
489
+ print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 Inf 值!")
490
+ found_issue = True
491
+ return found_issue
492
+
493
+
494
+ def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
495
+ """
496
+ 修改后的函数,确保一对一匹配,并优先匹配最近的点对。
497
+ """
498
+ pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
499
+ gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
500
+
501
+ pred_total = len(pred_array)
502
+ gt_total = len(gt_array)
503
+
504
+ if pred_total == 0 or gt_total == 0:
505
+ return {
506
+ 'recall': 0.0,
507
+ 'precision': 0.0,
508
+ 'f1': 0.0,
509
+ 'matches': 0,
510
+ 'pred_count': pred_total,
511
+ 'gt_count': gt_total
512
+ }
513
+
514
+ # 依然在预测点上构建KD-Tree,为每个真实点查找最近的预测点
515
+ tree = cKDTree(pred_array)
516
+ dists, pred_idxs = tree.query(gt_array, k=1)
517
+
518
+ # --- 核心修改部分 ---
519
+
520
+ # 1. 创建一个列表,包含 (距离, 真实点索引, 预测点索引)
521
+ # 这样我们就可以按距离对所有可能的匹配进行排序
522
+ possible_matches = []
523
+ for gt_idx, (dist, pred_idx) in enumerate(zip(dists, pred_idxs)):
524
+ if dist <= threshold:
525
+ possible_matches.append((dist, gt_idx, pred_idx))
526
+
527
+ # 2. 按距离从小到大排序(贪心策略)
528
+ possible_matches.sort(key=lambda x: x[0])
529
+
530
+ matches = 0
531
+ # 使用集合来跟踪已经使用过的预测点和真实点,确保一对一匹配
532
+ used_pred_indices = set()
533
+ used_gt_indices = set() # 虽然当前逻辑下gt不会重复,但加上更严谨
534
+
535
+ # 3. 遍历排序后的可能匹配,进行一对一分配
536
+ for dist, gt_idx, pred_idx in possible_matches:
537
+ # 如果这个预测点和这个真实点都还没有被使用过
538
+ if pred_idx not in used_pred_indices and gt_idx not in used_gt_indices:
539
+ matches += 1
540
+ used_pred_indices.add(pred_idx)
541
+ used_gt_indices.add(gt_idx)
542
+
543
+ # --- 修改结束 ---
544
+
545
+ # matches 现在是真正的 True Positives 数量,它绝不会超过 pred_total 或 gt_total
546
+ recall = matches / gt_total if gt_total > 0 else 0.0
547
+ precision = matches / pred_total if pred_total > 0 else 0.0
548
+
549
+ # 计算F1时,使用标准的 Precision 和 Recall 定义
550
+ if (precision + recall) == 0:
551
+ f1 = 0.0
552
+ else:
553
+ f1 = 2 * (precision * recall) / (precision + recall)
554
+
555
+ return {
556
+ 'recall': recall,
557
+ 'precision': precision,
558
+ 'f1': f1,
559
+ 'matches': matches,
560
+ 'pred_count': pred_total,
561
+ 'gt_count': gt_total
562
+ }
563
+
564
+ def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
565
+ """
566
+ 一个折衷的顶点指标计算方案。
567
+ 它沿用“为每个真实点寻找最近预测点”的逻辑,
568
+ 但通过修正计算方式,确保Precision和F1值不会超过1.0。
569
+ """
570
+ # 假设 pred_coords 和 gt_coords 是 PyTorch 张量
571
+ pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
572
+ gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
573
+
574
+ pred_total = len(pred_array)
575
+ gt_total = len(gt_array)
576
+
577
+ if pred_total == 0 or gt_total == 0:
578
+ return {
579
+ 'recall': 0.0,
580
+ 'precision': 0.0,
581
+ 'f1': 0.0,
582
+ 'matches': 0,
583
+ 'pred_count': pred_total,
584
+ 'gt_count': gt_total
585
+ }
586
+
587
+ # 在预测点上构建KD-Tree,为每个真实点查找最近的预测点
588
+ tree = cKDTree(pred_array)
589
+ dists, _ = tree.query(gt_array, k=1) # 我们在这里其实不需要 pred 的索引
590
+
591
+ # 1. 计算从 gt 角度出发的匹配数 (True Positives for Recall)
592
+ # 这和您的第一个函数完全一样。
593
+ # 这个值代表了“有多少个真实点被成功找到了”。
594
+ matches_from_gt = np.sum(dists <= threshold)
595
+
596
+ # 2. 计算 Recall (召回率)
597
+ # 召回率的分母是真实点的总数,所以这里的计算是合理的。
598
+ recall = matches_from_gt / gt_total if gt_total > 0 else 0.0
599
+
600
+ # 3. 计算 Precision (精确率) - ✅ 这是核心修正点
601
+ # 精确率的分母是预测点的总数。
602
+ # 分子(True Positives)不能超过预测点的总数。
603
+ # 因此,我们取 matches_from_gt 和 pred_total 中的较小值。
604
+ # 这解决了 Precision > 1 的问题。
605
+ tp_for_precision = min(matches_from_gt, pred_total)
606
+ precision = tp_for_precision / pred_total if pred_total > 0 else 0.0
607
+
608
+ # 4. 使用标准的F1分数公式
609
+ # 您原来的 F1 公式 `2 * matches / (pred + gt)` 是 L1-Score,
610
+ # 更常用的是基于 Precision 和 Recall 的调和平均数。
611
+ if (precision + recall) == 0:
612
+ f1 = 0.0
613
+ else:
614
+ f1 = 2 * (precision * recall) / (precision + recall)
615
+
616
+ return {
617
+ 'recall': recall,
618
+ 'precision': precision,
619
+ 'f1': f1,
620
+ 'matches': matches_from_gt, # 仍然报告原始的匹配数,便于观察
621
+ 'pred_count': pred_total,
622
+ 'gt_count': gt_total
623
+ }
624
+
625
+ def _compute_chamfer_distance(self, p1: torch.Tensor, p2: torch.Tensor, one_sided: bool = False):
626
+ if len(p1) == 0 or len(p2) == 0:
627
+ return float('nan')
628
+
629
+ dist_p1_p2 = torch.min(torch.cdist(p1, p2), dim=1)[0].mean()
630
+
631
+ if one_sided:
632
+ return dist_p1_p2.item()
633
+ else:
634
+ dist_p2_p1 = torch.min(torch.cdist(p2, p1), dim=1)[0].mean()
635
+ return (dist_p1_p2 + dist_p2_p1).item() / 2
636
+
637
+ def visualize_latent_space_pca(self, sample_idx: int):
638
+ """
639
+ Encodes a sample, performs PCA on its latent features, and saves a
640
+ colored PLY file for visualization.
641
+
642
+ The position of each point in the PLY file corresponds to the spatial
643
+ location in the latent grid.
644
+
645
+ The color of each point represents the first three principal components
646
+ of its feature vector.
647
+ """
648
+ print(f"--- Starting Latent Space PCA Visualization for Sample {sample_idx} ---")
649
+ self.vae.eval()
650
+
651
+ try:
652
+ # 1. Get the latent representation for the sample
653
+ latent = self._get_latent_for_sample(sample_idx)
654
+ except ValueError as e:
655
+ print(f"Error: {e}")
656
+ return
657
+
658
+ latent_coords = latent.coords.detach().cpu().numpy()
659
+ latent_feats = latent.feats.detach().cpu().numpy()
660
+
661
+ if latent_feats.shape[0] < 3:
662
+ print(f"Warning: Not enough latent points ({latent_feats.shape[0]}) to perform PCA. Skipping.")
663
+ return
664
+
665
+ print(f"--> Performing PCA on {latent_feats.shape[0]} latent vectors of dimension {latent_feats.shape[1]}...")
666
+
667
+ # 2. Perform PCA to reduce feature dimensions to 3
668
+ pca = PCA(n_components=3)
669
+ pca_features = pca.fit_transform(latent_feats)
670
+
671
+ print(f" Explained variance ratio by 3 components: {pca.explained_variance_ratio_}")
672
+ print(f" Total explained variance: {np.sum(pca.explained_variance_ratio_):.4f}")
673
+
674
+ # 3. Normalize the PCA components to be used as RGB colors [0, 255]
675
+ # We normalize each component independently to maximize color contrast
676
+ normalized_colors = np.zeros_like(pca_features)
677
+ for i in range(3):
678
+ min_val = pca_features[:, i].min()
679
+ max_val = pca_features[:, i].max()
680
+ if max_val - min_val > 1e-6:
681
+ normalized_colors[:, i] = (pca_features[:, i] - min_val) / (max_val - min_val)
682
+ else:
683
+ normalized_colors[:, i] = 0.5 # Handle case of constant value
684
+
685
+ colors_uint8 = (normalized_colors * 255).astype(np.uint8)
686
+
687
+ # 4. Prepare spatial coordinates for the point cloud
688
+ # latent_coords is (batch_idx, x, y, z), we want the xyz part
689
+ spatial_coords = latent_coords[:, 1:]
690
+
691
+ # 5. Create and save the colored PLY file
692
+ try:
693
+ # Create a Trimesh PointCloud object
694
+ point_cloud = trimesh.points.PointCloud(vertices=spatial_coords, colors=colors_uint8)
695
+
696
+ # Define the output filename
697
+ filename = f"sample_{sample_idx}_latent_pca.ply"
698
+ ply_path = os.path.join(self.output_voxel_dir, filename)
699
+
700
+ # Export the file
701
+ point_cloud.export(ply_path)
702
+ print(f"--> Successfully saved PCA visualization to: {ply_path}")
703
+
704
+ except Exception as e:
705
+ print(f"Error during Trimesh export: {e}")
706
+ print("Please ensure 'trimesh' is installed correctly.")
707
+
708
+ def _get_latent_for_sample(self, sample_idx: int) -> SparseTensor:
709
+ """
710
+ Encodes a single sample and returns its latent representation.
711
+ """
712
+ print(f"--> Encoding sample {sample_idx} to get its latent vector...")
713
+ # Get data for the specified sample
714
+ batch_data = self.dataset[sample_idx]
715
+ if batch_data is None:
716
+ raise ValueError(f"Sample at index {sample_idx} could not be loaded.")
717
+
718
+ # Use the collate function to form a batch
719
+ batch_data = collate_fn_pointnet([batch_data])
720
+
721
+ with torch.no_grad():
722
+ # 1. Get input data and move to device
723
+ gt_vertex_voxels_256 = batch_data['gt_vertex_voxels_256'].to(self.device)
724
+ combined_voxels_256 = batch_data['combined_voxels_256'].to(self.device)
725
+ active_coords = batch_data['active_voxels_128'].to(self.device)
726
+ point_cloud = batch_data['point_cloud_128'].to(self.device)
727
+
728
+ vtx_128 = downsample_voxels(gt_vertex_voxels_256, input_resolution=256, output_resolution=128)
729
+ edge_128 = downsample_voxels(combined_voxels_256, input_resolution=256, output_resolution=128)
730
+
731
+
732
+ active_voxel_feats = self.voxel_encoder(
733
+ p=point_cloud,
734
+ sparse_coords=active_coords,
735
+ res=128,
736
+ bbox_size=(-0.5, 0.5),
737
+
738
+ # voxel_label=active_labels,
739
+ )
740
+
741
+ sparse_input = SparseTensor(
742
+ feats=active_voxel_feats,
743
+ coords=active_coords.int()
744
+ )
745
+
746
+ # 2. Encode to get the latent representation
747
+ latent_128, posterior = self.vae.encode(sparse_input, sample_posterior=True,)
748
+ print(f" Latent for sample {sample_idx} obtained. Shape: {latent_128.feats.shape}")
749
+ return latent_128
750
+
751
+
752
+ def evaluate(self, num_samples=None, visualize=False, chamfer_threshold=0.9, threshold=1.):
753
+ total_samples = len(self.dataset)
754
+ eval_samples = min(num_samples or total_samples, total_samples)
755
+ # sample_indices = random.sample(range(total_samples), eval_samples) if num_samples else range(total_samples)
756
+ sample_indices = range(eval_samples)
757
+
758
+ eval_dataset = Subset(self.dataset, sample_indices)
759
+ eval_loader = DataLoader(
760
+ eval_dataset,
761
+ batch_size=1,
762
+ shuffle=False,
763
+ collate_fn=partial(collate_fn_pointnet),
764
+ num_workers=self.config['training']['num_workers'],
765
+ pin_memory=True,
766
+ )
767
+
768
+ per_sample_metrics = {
769
+ 'vertex': {res: [] for res in [128, 256]},
770
+ 'edge': {res: [] for res in [128, 256]},
771
+ 'sample_names': []
772
+ }
773
+ avg_metrics = {
774
+ 'vertex': {res: defaultdict(list) for res in [128, 256]},
775
+ 'edge': {res: defaultdict(list) for res in [128, 256]},
776
+ }
777
+
778
+ self.vae.eval()
779
+
780
+ for batch_idx, batch_data in enumerate(tqdm(eval_loader, desc="Evaluating")):
781
+ if batch_data is None:
782
+ continue
783
+ sample_idx = sample_indices[batch_idx]
784
+ sample_name = f'sample_{sample_idx}'
785
+ per_sample_metrics['sample_names'].append(sample_name)
786
+
787
+ # batch_save_path = f"/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/215000_sample_active_vis_42seed_1000complex/gt_data_batch_{batch_idx}.pt"
788
+ # if not os.path.exists(batch_save_path):
789
+ # print(f"Warning: Saved batch file not found: {batch_save_path}")
790
+ # continue
791
+ # batch_data = torch.load(batch_save_path, map_location=self.device)
792
+
793
+ with torch.no_grad():
794
+ # 1. Get input data
795
+ combined_voxels_256 = batch_data['combined_voxels_256'].to(self.device)
796
+ combined_voxel_labels_256 = batch_data['combined_voxel_labels_256'].to(self.device)
797
+ gt_combined_endpoints_256 = batch_data['gt_combined_endpoints_256'].to(self.device)
798
+ gt_combined_errors_256 = batch_data['gt_combined_errors_256'].to(self.device)
799
+
800
+ edge_mask = (combined_voxel_labels_256 == 1)
801
+
802
+ gt_edge_endpoints_256 = gt_combined_endpoints_256[edge_mask].to(self.device)
803
+ gt_edge_errors_256 = gt_combined_errors_256[edge_mask].to(self.device)
804
+ gt_edge_voxels_256 = combined_voxels_256[edge_mask].to(self.device)
805
+
806
+ p1 = gt_edge_endpoints_256[:, 1:4].float()
807
+ p2 = gt_edge_endpoints_256[:, 4:7].float()
808
+
809
+ mask = ( (p1[:,0] < p2[:,0]) |
810
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
811
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
812
+
813
+ pA = torch.where(mask[:, None], p1, p2) # smaller one
814
+ pB = torch.where(mask[:, None], p2, p1) # larger one
815
+
816
+ d = pB - pA
817
+ dir_gt = F.normalize(d, dim=-1, eps=1e-6)
818
+
819
+ gt_vertex_voxels_256 = batch_data['gt_vertex_voxels_256'].to(self.device).int()
820
+
821
+ vtx_128 = downsample_voxels(gt_vertex_voxels_256, input_resolution=256, output_resolution=128)
822
+
823
+ edge_128 = downsample_voxels(combined_voxels_256, input_resolution=256, output_resolution=128)
824
+ edge_256 = combined_voxels_256
825
+
826
+ print('vtx_128.shape', vtx_128.shape)
827
+ print('edge_128.shape', edge_128.shape)
828
+
829
+ gt_edge_voxels_list = [
830
+ edge_128,
831
+ edge_256,
832
+ ]
833
+
834
+ active_coords = batch_data['active_voxels_128'].to(self.device)
835
+ point_cloud = batch_data['point_cloud_128'].to(self.device)
836
+
837
+
838
+
839
+ active_voxel_feats = self.voxel_encoder(
840
+ p=point_cloud,
841
+ sparse_coords=active_coords,
842
+ res=128,
843
+ bbox_size=(-0.5, 0.5),
844
+
845
+ # voxel_label=active_labels,
846
+ )
847
+
848
+ sparse_input = SparseTensor(
849
+ feats=active_voxel_feats,
850
+ coords=active_coords.int()
851
+ )
852
+
853
+ latent_128, posterior = self.vae.encode(sparse_input)
854
+
855
+
856
+ # load_path = f'/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/215000_sample_active_vis_42seed_1000complex/sample_latent_{batch_idx}.pt'
857
+ # latent_128 = torch.load(load_path, map_location=self.device)
858
+
859
+ print('latent_128.feats.mean()', latent_128.feats.mean(), 'latent_128.feats.std()', latent_128.feats.std())
860
+ print('posterior.mean', posterior.mean.mean(), 'posterior.std', posterior.std.mean(), 'posterior.var', posterior.var.mean())
861
+ print('latent_128.coords.shape', latent_128.coords.shape)
862
+
863
+ # latent_128 = torch.load(f"/root/Trisf/output_slat_flow_matching/ckpts/1100_chair_sample/110000step_sample/sample_results_samples_{batch_idx}.pt", map_location=self.device)
864
+ latent_128 = SparseTensor(
865
+ coords=latent_128.coords,
866
+ feats=latent_128.feats + 0. * torch.randn_like(latent_128.feats),
867
+ )
868
+
869
+ # self.output_voxel_dir = os.path.dirname(load_path)
870
+ # self.output_obj_dir = os.path.dirname(load_path)
871
+
872
+ # 7. Decoding with separate vertex and edge processing
873
+ decoded_results = self.vae.decode(
874
+ latent_128,
875
+ gt_vertex_voxels_list=[],
876
+ gt_edge_voxels_list=[],
877
+ training=False,
878
+
879
+ inference_threshold=0.5,
880
+ vis_last_layer=False,
881
+ )
882
+
883
+ error = 0 # decoded_results[-1]['edge']['predicted_offset_feats']
884
+
885
+ if self.config['model'].get('pred_direction', False):
886
+ pred_dir = decoded_results[-1]['edge']['predicted_direction_feats']
887
+ zero_mask = (pred_dir == 0).all(dim=1) # [N],True 表示这一行全为0
888
+ num_zeros = zero_mask.sum().item()
889
+ print("Number of zero vectors:", num_zeros)
890
+
891
+ pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
892
+ print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
893
+ print('pred_dir.shape', pred_dir.shape)
894
+ if pred_edge_coords_3d.shape[-1] == 4:
895
+ pred_edge_coords_3d = pred_edge_coords_3d[:, 1:]
896
+ # visualize_directions(pred_edge_coords_3d, pred_dir, sample_ratio=0.02)
897
+ save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction.ply")
898
+ # visualize_colored_points_ply(pred_edge_coords_3d - error / 2. + 0.5, pred_dir, save_pth)
899
+ visualize_colored_points_ply(pred_edge_coords_3d, pred_dir, save_pth)
900
+
901
+ save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction_gt.ply")
902
+ # visualize_colored_points_ply((gt_edge_voxels_256[:, 1:] - gt_edge_errors_256[:, 1:] + 0.5), dir_gt, save_pth)
903
+ visualize_colored_points_ply((gt_edge_voxels_256[:, 1:]), dir_gt, save_pth)
904
+
905
+
906
+ pred_vtx_coords_3d = decoded_results[-1]['vertex']['coords']
907
+ pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
908
+
909
+ pred_edge_coords_np = np.round(pred_edge_coords_3d.cpu().numpy()).astype(int)
910
+
911
+ gt_vertex_voxels_256 = batch_data['gt_vertex_voxels_256'][:, 1:].to(self.device)
912
+ gt_edge_voxels_256 = batch_data['gt_edge_voxels_256'][:, 1:].to(self.device)
913
+ gt_edge_coords_np = np.round(gt_edge_voxels_256.cpu().numpy()).astype(int)
914
+
915
+
916
+ # Calculate metrics and save results
917
+ matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_vtx_coords_3d, gt_vertex_voxels_256, threshold=threshold,)
918
+ print(f"\n----- Resolution {256} vtx -----")
919
+ print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
920
+ print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
921
+
922
+ self._save_voxel_ply(pred_vtx_coords_3d / 256., torch.zeros(len(pred_vtx_coords_3d)), f"{sample_name}_pred_vtx")
923
+ self._save_voxel_ply((pred_edge_coords_3d - error / 2. + 0.5) / 256, torch.zeros(len(pred_edge_coords_3d)), f"{sample_name}_pred_edge")
924
+
925
+ self._save_voxel_ply(gt_vertex_voxels_256 / 256, torch.zeros(len(gt_vertex_voxels_256)), f"{sample_name}_gt_vertex")
926
+ self._save_voxel_ply((combined_voxels_256[:, 1:] - gt_combined_errors_256[:, 1:] + 0.5) / 256., torch.zeros(len(gt_combined_errors_256)), f"{sample_name}_gt_edge")
927
+
928
+
929
+ # Calculate vertex-specific metrics
930
+ matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_edge_coords_3d, combined_voxels_256[:, 1:], threshold=threshold,)
931
+ print(f"\n----- Resolution {256} edge -----")
932
+ print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
933
+ print('gt_edge_voxels_256.shape', gt_edge_voxels_256.shape)
934
+
935
+ print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
936
+ print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
937
+
938
+ pred_vertex_coords_np = np.round(pred_vtx_coords_3d.cpu().numpy()).astype(int)
939
+ pred_edges = []
940
+ gt_vertex_coords_np = np.round(gt_vertex_voxels_256.cpu().numpy()).astype(int)
941
+ if visualize:
942
+ if pred_vtx_coords_3d.shape[-1] == 4:
943
+ pred_vtx_coords_float = pred_vtx_coords_3d[:, 1:].float()
944
+ else:
945
+ pred_vtx_coords_float = pred_vtx_coords_3d.float()
946
+
947
+ pred_vtx_feats = decoded_results[-1]['vertex']['feats']
948
+
949
+ # ==========================================
950
+ # Link Prediction & Mesh Generation
951
+ # ==========================================
952
+ print("Predicting connectivity...")
953
+
954
+ # 1. 预测边
955
+ # 注意:K_neighbors 的设置。如果是物体,64 足够了。
956
+ # 如果点非常稀疏,可能需要更大。
957
+ pred_edges = predict_mesh_connectivity(
958
+ connection_head=self.connection_head, # 或者是 self.connection_head,取决于你在哪里定义的
959
+ vtx_feats=pred_vtx_feats,
960
+ vtx_coords=pred_vtx_coords_float,
961
+ batch_size=4096,
962
+ threshold=0.5,
963
+ k_neighbors=None,
964
+ device=self.device
965
+ )
966
+ print(f"Predicted {len(pred_edges)} edges.")
967
+
968
+ # 2. 构建三角形
969
+ num_verts = pred_vtx_coords_float.shape[0]
970
+ pred_faces = build_triangles_from_edges(pred_edges, num_verts)
971
+ print(f"Constructed {len(pred_faces)} triangles.")
972
+
973
+ # 3. 保存 OBJ
974
+ import trimesh
975
+
976
+ # 坐标归一化/还原 (根据你的需求,这里假设你是 0-256 的体素坐标)
977
+ # 如果想保存为归一化坐标:
978
+ mesh_verts = pred_vtx_coords_float.cpu().numpy() / 256.0
979
+
980
+ # 如果有 error offset,记得加上!
981
+ # 你之前的代码好像没有对 vertex 加 offset,只对 edge 加了
982
+ # 如果 vertex 也有 offset (如 dual contouring),在这里加上
983
+
984
+ # 移动到中心 (可选)
985
+ mesh_verts = mesh_verts - 0.5
986
+
987
+ mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
988
+
989
+ # 过滤孤立点 (可选)
990
+ # mesh.remove_unreferenced_vertices()
991
+
992
+ output_obj_path = os.path.join(self.output_voxel_dir, f"{sample_name}_recon.obj")
993
+ mesh.export(output_obj_path)
994
+ print(f"Saved mesh to {output_obj_path}")
995
+
996
+ # 保存边线 (用于 Debug)
997
+ # 有时候三角形很难形成,只看边也很有用
998
+ edges_path = os.path.join(self.output_voxel_dir, f"{sample_name}_edges.ply")
999
+ # self._visualize_vertices(pred_edge_coords_np, gt_edge_coords_np, f"{sample_name}_edge_comparison")
1000
+
1001
+
1002
+ # Process results at different resolutions
1003
+ for i, res in enumerate([128, 256]):
1004
+ if i >= len(decoded_results):
1005
+ continue
1006
+
1007
+ gt_key = f'gt_vertex_voxels_{res}'
1008
+ if gt_key not in batch_data:
1009
+ continue
1010
+ if i == 0:
1011
+ pred_coords_res = decoded_results[i]['vtx_sp'].coords[:, 1:].float()
1012
+ gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
1013
+ else:
1014
+ pred_coords_res = decoded_results[i]['vertex']['coords'].float()
1015
+ gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
1016
+
1017
+
1018
+ v_metrics = self._compute_vertex_metrics(pred_coords_res, gt_coords_res, threshold=threshold)
1019
+
1020
+ per_sample_metrics['vertex'][res].append({
1021
+ 'recall': v_metrics['recall'],
1022
+ 'precision': v_metrics['precision'],
1023
+ 'f1': v_metrics['f1'],
1024
+ 'num_pred': len(pred_coords_res),
1025
+ 'num_gt': len(gt_coords_res)
1026
+ })
1027
+
1028
+ avg_metrics['vertex'][res]['recall'].append(v_metrics['recall'])
1029
+ avg_metrics['vertex'][res]['precision'].append(v_metrics['precision'])
1030
+ avg_metrics['vertex'][res]['f1'].append(v_metrics['f1'])
1031
+
1032
+ gt_edge_key = f'gt_edge_voxels_{res}'
1033
+ if gt_edge_key not in batch_data:
1034
+ continue
1035
+
1036
+ if i == 0:
1037
+ pred_edge_coords_res = decoded_results[i]['edge_sp'].coords[:, 1:].float()
1038
+ # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1039
+ idx = i
1040
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1041
+ elif i == 1:
1042
+ idx = i
1043
+ #################################
1044
+ # pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() - error / 2. + 0.5
1045
+ # # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1046
+ # gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) - gt_combined_errors_256[:, 1:].to(self.device) + 0.5
1047
+
1048
+
1049
+ pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
1050
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1051
+
1052
+ # self._save_voxel_ply(gt_edge_voxels_list[idx][:, 1:].float().to(self.device) / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res_wooffset")
1053
+ # self._save_voxel_ply(decoded_results[i]['edge']['coords'].float() / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res_wooffset")
1054
+
1055
+ else:
1056
+ idx = i
1057
+ pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
1058
+ # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1059
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1060
+
1061
+ # self._save_voxel_ply(gt_edge_coords_res / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res")
1062
+ # self._save_voxel_ply(pred_edge_coords_res / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res")
1063
+
1064
+ e_metrics = self._compute_vertex_metrics(pred_edge_coords_res, gt_edge_coords_res, threshold=threshold)
1065
+
1066
+ per_sample_metrics['edge'][res].append({
1067
+ 'recall': e_metrics['recall'],
1068
+ 'precision': e_metrics['precision'],
1069
+ 'f1': e_metrics['f1'],
1070
+ 'num_pred': len(pred_edge_coords_res),
1071
+ 'num_gt': len(gt_edge_coords_res)
1072
+ })
1073
+
1074
+ avg_metrics['edge'][res]['recall'].append(e_metrics['recall'])
1075
+ avg_metrics['edge'][res]['precision'].append(e_metrics['precision'])
1076
+ avg_metrics['edge'][res]['f1'].append(e_metrics['f1'])
1077
+
1078
+ avg_metrics_processed = {}
1079
+ for category, res_dict in avg_metrics.items():
1080
+ avg_metrics_processed[category] = {}
1081
+ for resolution, metric_dict in res_dict.items():
1082
+ avg_metrics_processed[category][resolution] = {
1083
+ metric_name: np.mean(values) if values else float('nan')
1084
+ for metric_name, values in metric_dict.items()
1085
+ }
1086
+
1087
+ result_data = {
1088
+ 'config': self.config,
1089
+ 'checkpoint': self.ckpt_path,
1090
+ 'dataset': self.dataset_path,
1091
+ 'num_samples': eval_samples,
1092
+ 'per_sample_metrics': per_sample_metrics,
1093
+ 'avg_metrics': avg_metrics_processed
1094
+ }
1095
+
1096
+ results_file_path = os.path.join(self.result_dir, f"evaluation_results_epoch{self.epoch}.yaml")
1097
+ with open(results_file_path, 'w') as f:
1098
+ yaml.dump(result_data, f, default_flow_style=False)
1099
+
1100
+ return result_data
1101
+
1102
+
1103
+
1104
+ def _generate_line_voxels(
1105
+ self,
1106
+ p1: torch.Tensor,
1107
+ p2: torch.Tensor
1108
+ ) -> Tuple[
1109
+ List[Tuple[int, int, int]],
1110
+ List[Tuple[torch.Tensor, torch.Tensor]],
1111
+ List[np.ndarray]
1112
+ ]:
1113
+ """
1114
+ Improved version using better sampling strategy
1115
+ """
1116
+ p1_np = p1 #.cpu().numpy()
1117
+ p2_np = p2 #.cpu().numpy()
1118
+ voxel_dict = OrderedDict()
1119
+
1120
+ # Use proper 3D line voxelization algorithm
1121
+ def bresenham_3d(p1, p2):
1122
+ """3D Bresenham's line algorithm"""
1123
+ x1, y1, z1 = np.round(p1).astype(int)
1124
+ x2, y2, z2 = np.round(p2).astype(int)
1125
+
1126
+ points = []
1127
+ dx = abs(x2 - x1)
1128
+ dy = abs(y2 - y1)
1129
+ dz = abs(z2 - z1)
1130
+
1131
+ xs = 1 if x2 > x1 else -1
1132
+ ys = 1 if y2 > y1 else -1
1133
+ zs = 1 if z2 > z1 else -1
1134
+
1135
+ # Driving axis is X
1136
+ if dx >= dy and dx >= dz:
1137
+ err_1 = 2 * dy - dx
1138
+ err_2 = 2 * dz - dx
1139
+ for i in range(dx + 1):
1140
+ points.append((x1, y1, z1))
1141
+ if err_1 > 0:
1142
+ y1 += ys
1143
+ err_1 -= 2 * dx
1144
+ if err_2 > 0:
1145
+ z1 += zs
1146
+ err_2 -= 2 * dx
1147
+ err_1 += 2 * dy
1148
+ err_2 += 2 * dz
1149
+ x1 += xs
1150
+
1151
+ # Driving axis is Y
1152
+ elif dy >= dx and dy >= dz:
1153
+ err_1 = 2 * dx - dy
1154
+ err_2 = 2 * dz - dy
1155
+ for i in range(dy + 1):
1156
+ points.append((x1, y1, z1))
1157
+ if err_1 > 0:
1158
+ x1 += xs
1159
+ err_1 -= 2 * dy
1160
+ if err_2 > 0:
1161
+ z1 += zs
1162
+ err_2 -= 2 * dy
1163
+ err_1 += 2 * dx
1164
+ err_2 += 2 * dz
1165
+ y1 += ys
1166
+
1167
+ # Driving axis is Z
1168
+ else:
1169
+ err_1 = 2 * dx - dz
1170
+ err_2 = 2 * dy - dz
1171
+ for i in range(dz + 1):
1172
+ points.append((x1, y1, z1))
1173
+ if err_1 > 0:
1174
+ x1 += xs
1175
+ err_1 -= 2 * dz
1176
+ if err_2 > 0:
1177
+ y1 += ys
1178
+ err_2 -= 2 * dz
1179
+ err_1 += 2 * dx
1180
+ err_2 += 2 * dy
1181
+ z1 += zs
1182
+
1183
+ return points
1184
+
1185
+ # Get all voxels using Bresenham algorithm
1186
+ voxel_coords = bresenham_3d(p1_np, p2_np)
1187
+
1188
+ # Add all voxels to dictionary
1189
+ for coord in voxel_coords:
1190
+ voxel_dict[tuple(coord)] = (p1, p2)
1191
+
1192
+ voxel_coords = list(voxel_dict.keys())
1193
+ endpoint_pairs = list(voxel_dict.values())
1194
+
1195
+ # --- compute error vectors ---
1196
+ error_vectors = []
1197
+ diff = p2_np - p1_np
1198
+ d_norm_sq = np.dot(diff, diff)
1199
+
1200
+ for v in voxel_coords:
1201
+ v_center = np.array(v, dtype=float) + 0.5
1202
+ if d_norm_sq == 0: # degenerate line
1203
+ closest = p1_np
1204
+ else:
1205
+ t = np.dot(v_center - p1_np, diff) / d_norm_sq
1206
+ t = np.clip(t, 0.0, 1.0)
1207
+ closest = p1_np + t * diff
1208
+ error_vectors.append(v_center - closest)
1209
+
1210
+ return voxel_coords, endpoint_pairs, error_vectors
1211
+
1212
+
1213
+
1214
+ # 使用示例
1215
+ def set_seed(seed: int):
1216
+ random.seed(seed)
1217
+ np.random.seed(seed)
1218
+ torch.manual_seed(seed)
1219
+ if torch.cuda.is_available():
1220
+ torch.cuda.manual_seed(seed)
1221
+ torch.cuda.manual_seed_all(seed)
1222
+ torch.backends.cudnn.deterministic = True
1223
+ torch.backends.cudnn.benchmark = False
1224
+
1225
+ def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
1226
+ set_seed(42)
1227
+ tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
1228
+ result_data = tester.evaluate(num_samples=NUM_SAMPLES, visualize=VISUALIZE, chamfer_threshold=CHAMFER_EDGE_THRESHOLD, threshold=THRESHOLD)
1229
+
1230
+ # 生成文件名
1231
+ epoch_str = os.path.basename(ckpt_path).split('_')[1].split('.')[0]
1232
+ dataset_name = os.path.basename(os.path.normpath(dataset_path))
1233
+
1234
+ # 保存简版报告(TXT)
1235
+ summary_path = os.path.join(eval_dir, f"epoch{epoch_str}_{dataset_name}_summary_threshold{THRESHOLD}_one2one.txt")
1236
+ with open(summary_path, 'w') as f:
1237
+ # 头部信息
1238
+ f.write(f"Checkpoint: {os.path.basename(ckpt_path)}\n")
1239
+ f.write(f"Dataset: {dataset_name}\n")
1240
+ f.write(f"Evaluation Samples: {result_data['num_samples']}\n\n")
1241
+
1242
+ # 平均指标
1243
+ f.write("=== Average Metrics ===\n")
1244
+ for category, data in result_data['avg_metrics'].items():
1245
+ if isinstance(data, dict): # 处理多分辨率情况
1246
+ f.write(f"\n{category.upper()}:\n")
1247
+ for res, metrics in data.items():
1248
+ f.write(f" Resolution {res}:\n")
1249
+ for k, v in metrics.items():
1250
+ # 确保值是数字类型后再格式化
1251
+ if isinstance(v, (int, float)):
1252
+ f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
1253
+ else:
1254
+ f.write(f" {str(k).ljust(15)}: {str(v)}\n")
1255
+ else: # 处理非多分辨率情况
1256
+ f.write(f"\n{category.upper()}:\n")
1257
+ for k, v in data.items():
1258
+ if isinstance(v, (int, float)):
1259
+ f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
1260
+ else:
1261
+ f.write(f" {str(k).ljust(15)}: {str(v)}\n")
1262
+
1263
+ # 样本级详细统计
1264
+ f.write("\n\n=== Detailed Per-Sample Metrics ===\n")
1265
+ for name, vertex_metrics, edge_metrics in zip(
1266
+ result_data['per_sample_metrics']['sample_names'],
1267
+ zip(*[result_data['per_sample_metrics']['vertex'][res] for res in [128, 256]]),
1268
+ zip(*[result_data['per_sample_metrics']['edge'][res] for res in [128, 256]])
1269
+ ):
1270
+ # 样本标题
1271
+ f.write(f"\n◆ Sample: {name}\n")
1272
+
1273
+ # 顶点指标
1274
+ f.write(f"[Vertex Prediction]\n")
1275
+ f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
1276
+ for res, metrics in zip([128, 256], vertex_metrics):
1277
+ f.write(f" {str(res).ljust(10)} "
1278
+ f"{metrics['recall']:.4f} "
1279
+ f"{metrics['precision']:.4f} "
1280
+ f"{metrics['f1']:.4f} "
1281
+ f"{metrics['num_pred']}/{metrics['num_gt']}\n")
1282
+
1283
+ # Edge指标
1284
+ f.write(f"[Edge Prediction]\n")
1285
+ f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
1286
+ for res, metrics in zip([128, 256], edge_metrics):
1287
+ f.write(f" {str(res).ljust(10)} "
1288
+ f"{metrics['recall']:.4f} "
1289
+ f"{metrics['precision']:.4f} "
1290
+ f"{metrics['f1']:.4f} "
1291
+ f"{metrics['num_pred']}/{metrics['num_gt']}\n")
1292
+
1293
+ f.write("-"*60 + "\n")
1294
+
1295
+ print(f"Saved summary to: {summary_path}")
1296
+ return result_data
1297
+
1298
+
1299
+ if __name__ == '__main__':
1300
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
1301
+ evaluate_all_checkpoints = True # 设置为 True 启用范围过滤
1302
+ EPOCH_START = 12
1303
+ EPOCH_END = 12
1304
+ CHAMFER_EDGE_THRESHOLD=0.5
1305
+ NUM_SAMPLES=50
1306
+ VISUALIZE=True
1307
+ THRESHOLD=1.5
1308
+ VISUAL_FIELD=False
1309
+
1310
+ ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to256_dir_sorted_dora_head_small/checkpoint_epoch13_batch6000_loss0.1381.pt'
1311
+ dataset_path = '/gemini/user/private/zhaotianhao/data/test_mesh'
1312
+
1313
+ if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
1314
+ RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
1315
+ else:
1316
+ RENDERS_DIR = ''
1317
+
1318
+
1319
+ ckpt_dir = os.path.dirname(ckpt_path)
1320
+ eval_dir = os.path.join(ckpt_dir, "evaluate")
1321
+ os.makedirs(eval_dir, exist_ok=True)
1322
+
1323
+ if False:
1324
+ for i in range(NUM_SAMPLES):
1325
+ print("--- Starting Latent Space PCA Visualization ---")
1326
+ tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
1327
+ tester.visualize_latent_space_pca(sample_idx=i)
1328
+ print("--- PCA Visualization Finished ---")
1329
+
1330
+ if not evaluate_all_checkpoints:
1331
+ evaluate_checkpoint(ckpt_path, dataset_path, eval_dir)
1332
+ else:
1333
+ pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.endswith('.pt')])
1334
+
1335
+ filtered_pt_files = []
1336
+ for f in pt_files:
1337
+ try:
1338
+ parts = f.split('_')
1339
+ epoch_str = parts[1].replace('epoch', '')
1340
+ epoch = int(epoch_str)
1341
+ if EPOCH_START <= epoch <= EPOCH_END:
1342
+ filtered_pt_files.append(f)
1343
+ except Exception as e:
1344
+ print(f"Warning: Could not parse epoch from {f}: {e}")
1345
+ continue
1346
+
1347
+ for pt_file in filtered_pt_files:
1348
+ full_ckpt_path = os.path.join(ckpt_dir, pt_file)
1349
+ evaluate_checkpoint(full_ckpt_path, dataset_path, eval_dir)
test_slat_vae_128to512_pointnet_vae_head.py ADDED
@@ -0,0 +1,1636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ import numpy as np
5
+ import random
6
+ from tqdm import tqdm
7
+ from collections import defaultdict
8
+ import plotly.graph_objects as go
9
+ from plotly.subplots import make_subplots
10
+ from torch.utils.data import DataLoader, Subset
11
+ from triposf.modules.sparse.basic import SparseTensor
12
+
13
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
14
+
15
+ from vertex_encoder import VoxelFeatureEncoder_edge, VoxelFeatureEncoder_vtx, VoxelFeatureEncoder_active, VoxelFeatureEncoder_active_pointnet, ConnectionHead
16
+ from utils import load_pretrained_woself
17
+
18
+ from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
19
+
20
+
21
+ from functools import partial
22
+ import itertools
23
+ from typing import List, Tuple, Set
24
+ from collections import OrderedDict
25
+ from scipy.spatial import cKDTree
26
+ from sklearn.neighbors import KDTree
27
+
28
+ import trimesh
29
+
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import time
33
+
34
+ from sklearn.decomposition import PCA
35
+ import matplotlib.pyplot as plt
36
+
37
+ import networkx as nx
38
+
39
+ def predict_mesh_connectivity(
40
+ connection_head,
41
+ vtx_feats,
42
+ vtx_coords,
43
+ batch_size=10000,
44
+ threshold=0.5,
45
+ k_neighbors=64, # 限制每个点只检测最近的 K 个邻居,设为 -1 则全连接检测
46
+ device='cuda'
47
+ ):
48
+ """
49
+ Args:
50
+ connection_head: 训练好的 MLP 模型
51
+ vtx_feats: [N, C] 顶点特征
52
+ vtx_coords: [N, 3] 顶点坐标 (用于 KNN 筛选候选边)
53
+ batch_size: MLP 推理的 batch size
54
+ threshold: 判定连接的概率阈值
55
+ k_neighbors: K-NN 数量。如果是 None 或 -1,则检测所有 N*(N-1)/2 对。
56
+ """
57
+ num_verts = vtx_feats.shape[0]
58
+ if num_verts < 3:
59
+ return [], [] # 无法构成三角形
60
+
61
+ connection_head.eval()
62
+
63
+ # --- 1. 生成候选边 (Candidate Edges) ---
64
+ if k_neighbors is not None and k_neighbors > 0 and k_neighbors < num_verts:
65
+ # 策略 A: 局部 KNN (推荐)
66
+ # 计算距离矩阵可能会 OOM,使用分块或 KDTree/Faiss,这里用 PyTorch 的 cdist 分块简化版
67
+ # 或者直接暴力 cdist 如果 N < 10000
68
+
69
+ # 为了简单且高效,这里演示简单的 cdist (注意显存)
70
+ # 如果 N 很大 (>5000),建议使用 faiss 或 scipy.spatial.cKDTree
71
+ dist_mat = torch.cdist(vtx_coords.float(), vtx_coords.float()) # [N, N]
72
+
73
+ # 取 topk (smallest distance),排除自己
74
+ # values: [N, K], indices: [N, K]
75
+ _, indices = torch.topk(dist_mat, k=k_neighbors + 1, dim=1, largest=False)
76
+ neighbor_indices = indices[:, 1:] # 去掉第一列(自己)
77
+
78
+ # 构建 source, target 索引
79
+ src = torch.arange(num_verts, device=device).unsqueeze(1).repeat(1, k_neighbors).flatten()
80
+ dst = neighbor_indices.flatten()
81
+
82
+ # 此时得到的边是双向的 (u->v 和 v->u 可能都存在),为了效率可以去重
83
+ # 但为了利用你的 symmetric MLP,保留双向或者只保留 u < v 均可
84
+ # 这里为了简单,我们生成 u < v 的 mask
85
+ mask = src < dst
86
+ u_indices = src[mask]
87
+ v_indices = dst[mask]
88
+
89
+ else:
90
+ # 策略 B: 全连接 (O(N^2)) - 仅当 N 较小时使用
91
+ u_indices, v_indices = torch.triu_indices(num_verts, num_verts, offset=1, device=device)
92
+
93
+ # --- 2. 批量推理 ---
94
+ all_probs = []
95
+ num_candidates = u_indices.shape[0]
96
+
97
+ with torch.no_grad():
98
+ for i in range(0, num_candidates, batch_size):
99
+ end = min(i + batch_size, num_candidates)
100
+ batch_u = u_indices[i:end]
101
+ batch_v = v_indices[i:end]
102
+
103
+ feat_u = vtx_feats[batch_u]
104
+ feat_v = vtx_feats[batch_v]
105
+
106
+ # Symmetric Forward (和你训练时保持一致)
107
+ # A -> B
108
+ input_uv = torch.cat([feat_u, feat_v], dim=-1)
109
+ logits_uv = connection_head(input_uv)
110
+
111
+ # B -> A
112
+ input_vu = torch.cat([feat_v, feat_u], dim=-1)
113
+ logits_vu = connection_head(input_vu)
114
+
115
+ # Sum logits
116
+ logits = (logits_uv + logits_vu) / 2.
117
+ probs = torch.sigmoid(logits)
118
+ all_probs.append(probs)
119
+
120
+ all_probs = torch.cat(all_probs).squeeze() # [M]
121
+
122
+ # --- 3. 筛选连接边 ---
123
+ connected_mask = all_probs > threshold
124
+ final_u = u_indices[connected_mask].cpu().numpy()
125
+ final_v = v_indices[connected_mask].cpu().numpy()
126
+
127
+ edges = np.stack([final_u, final_v], axis=1) # [E, 2]
128
+
129
+ return edges
130
+
131
+ def build_triangles_from_edges(edges, num_verts):
132
+ """
133
+ 从边列表构建三角形。
134
+ 寻找图中所有的 3-Cliques (三元环)。
135
+ 这在图论中是一个经典问题,可以使用 networkx 库。
136
+ """
137
+ if len(edges) == 0:
138
+ return np.empty((0, 3), dtype=int)
139
+
140
+ G = nx.Graph()
141
+ G.add_nodes_from(range(num_verts))
142
+ G.add_edges_from(edges)
143
+
144
+ # 寻找所有的 3-cliques (三角形)
145
+ # enumerate_all_cliques 返回所有大小的 clique,我们需要过滤大小为 3 的
146
+ # 或者使用 nx.triangles ? 不,那个只返回数量
147
+ # 使用 nx.enumerate_all_cliques 效率可能较低,对于稀疏图还可以
148
+
149
+ # 更快的方法:迭代每条边 (u, v),查找 u 和 v 的公共邻居 w
150
+ triangles = []
151
+ adj = [set(G.neighbors(n)) for n in range(num_verts)]
152
+
153
+ # 为了避免重复 (u, v, w), (v, w, u)... 我们可以强制 u < v < w
154
+ # 既然 edges 已经是 u < v (如果我们之前做了 triu),则只需要找 w > v 且 w in adj[u]
155
+
156
+ # 优化算法:
157
+ for u, v in edges:
158
+ if u > v: u, v = v, u # 确保有序
159
+
160
+ # 找公共邻居
161
+ common = adj[u].intersection(adj[v])
162
+ for w in common:
163
+ if w > v: # 强制顺序 u < v < w 防止重复
164
+ triangles.append([u, v, w])
165
+
166
+ return np.array(triangles)
167
+
168
+ def downsample_voxels(
169
+ voxels: torch.Tensor,
170
+ input_resolution: int,
171
+ output_resolution: int
172
+ ) -> torch.Tensor:
173
+ if input_resolution % output_resolution != 0:
174
+ raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
175
+ f"by output_resolution ({output_resolution}).")
176
+
177
+ factor = input_resolution // output_resolution
178
+
179
+ downsampled_voxels = voxels.clone().to(torch.long)
180
+
181
+ downsampled_voxels[:, 1:] //= factor
182
+
183
+ unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
184
+ return unique_downsampled_voxels
185
+
186
+ def visualize_colored_points_ply(coords, vectors, filename):
187
+ """
188
+ 可视化点云,并用向量方向的颜色来表示,保存为 PLY 文件。
189
+
190
+ Args:
191
+ coords (torch.Tensor or np.ndarray): 3D坐标,形状为 (N, 3)。
192
+ vectors (torch.Tensor or np.ndarray): 方向向量,形状为 (N, 3)。
193
+ filename (str): 保存输出文件的名称,必须是 .ply 格式。
194
+ """
195
+ # 确保输入是 numpy 数组
196
+ if isinstance(coords, torch.Tensor):
197
+ coords = coords.detach().cpu().numpy()
198
+ if isinstance(vectors, torch.Tensor):
199
+ vectors = vectors.detach().cpu().to(torch.float32).numpy()
200
+
201
+ # 检查输入数据是否为空,防止崩溃
202
+ if coords.size == 0 or vectors.size == 0:
203
+ print(f"警告:输入数据为空,未生成 {filename} 文件。")
204
+ return
205
+
206
+ # 将向量分量从 [-1, 1] 映射到 [0, 255]
207
+ # np.clip 用于将数值限制在 -1 和 1 之间,防止颜色溢出
208
+ # (vectors + 1) 将范围从 [-1, 1] 移动到 [0, 2]
209
+ # * 127.5 将范围从 [0, 2] 缩放到 [0, 255]
210
+ colors = np.clip((vectors + 1) * 127.5, 0, 255).astype(np.uint8)
211
+
212
+ # 创建一个点云对象,并传入颜色信息
213
+ # trimesh.PointCloud 能够自动处理带颜色的点
214
+ points = trimesh.points.PointCloud(coords, colors=colors)
215
+ # 导出为 PLY 文件
216
+ points.export(filename, file_type='ply')
217
+ print(f"可视化文件已成功保存为: {filename}")
218
+
219
+
220
+ def compute_vertex_matching(pred_coords, gt_coords, threshold=1.0):
221
+ # 转换为整数坐标并去重
222
+ print('len(pred_coords)', len(pred_coords))
223
+
224
+ pred_array = np.unique(pred_coords.detach().to(torch.float32).cpu().numpy(), axis=0)
225
+ gt_array = np.unique(gt_coords.detach().cpu().to(torch.float32).numpy(), axis=0)
226
+ print('len(pred_array)', len(pred_array))
227
+ pred_total = len(pred_array)
228
+ gt_total = len(gt_array)
229
+
230
+ # 如果没有点,直接返回
231
+ if pred_total == 0 or gt_total == 0:
232
+ return 0, 0.0, pred_total, gt_total
233
+
234
+ # 建立 KDTree(以 gt 为基准)
235
+ tree = KDTree(gt_array)
236
+
237
+ # 查找预测点到最近的 gt 点
238
+ dist, indices = tree.query(pred_array, k=1)
239
+ dist = dist.squeeze()
240
+ indices = indices.squeeze()
241
+
242
+ # 贪心去重:确保 1 对 1
243
+ matches = 0
244
+ used_gt = set()
245
+ for d, idx in zip(dist, indices):
246
+ if d <= threshold and idx not in used_gt:
247
+ matches += 1
248
+ used_gt.add(idx)
249
+
250
+ match_rate = matches / max(gt_total, 1)
251
+
252
+ return matches, match_rate, pred_total, gt_total
253
+
254
+ def flatten_coords_4d(coords_4d: torch.Tensor):
255
+ coords_4d_long = coords_4d.long()
256
+
257
+ base_x = 512
258
+ base_y = 512 * 512
259
+ base_z = 512 * 512 * 512
260
+
261
+ flat_coords = coords_4d_long[:, 0] * base_z + \
262
+ coords_4d_long[:, 1] * base_y + \
263
+ coords_4d_long[:, 2] * base_x + \
264
+ coords_4d_long[:, 3]
265
+ return flat_coords
266
+
267
+ class Tester:
268
+ def __init__(self, ckpt_path, config_path=None, dataset_path=None):
269
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
270
+ self.ckpt_path = ckpt_path
271
+
272
+ self.config = self._load_config(config_path)
273
+ self.dataset_path = dataset_path # or self.config['dataset']['path']
274
+ checkpoint = torch.load(self.ckpt_path, map_location='cpu')
275
+ self.epoch = checkpoint.get('epoch', 0)
276
+
277
+ self._init_models()
278
+ self._init_dataset()
279
+
280
+ self.result_dir = os.path.join(os.path.dirname(ckpt_path), "evaluation_results")
281
+ os.makedirs(self.result_dir, exist_ok=True)
282
+
283
+ dataset_name_clean = os.path.basename(self.dataset_path).replace('.npz', '').replace('.npy', '')
284
+ self.output_voxel_dir = os.path.join(os.path.dirname(ckpt_path),
285
+ f"epoch_{self.epoch}_{dataset_name_clean}_voxels_0_gs")
286
+ os.makedirs(self.output_voxel_dir, exist_ok=True)
287
+
288
+ self.output_obj_dir = os.path.join(os.path.dirname(ckpt_path),
289
+ f"epoch_{self.epoch}_{dataset_name_clean}_obj_0_gs")
290
+ os.makedirs(self.output_obj_dir, exist_ok=True)
291
+
292
+ def _save_logit_visualization(self, dense_vol, name, sample_name, ply_threshold=0.01):
293
+ """
294
+ 保存 Logit 的 3D .npy 文件、2D 最大投影热力图,以及带颜色和透明度的 3D .ply 点云
295
+
296
+ Args:
297
+ dense_vol: (H, W, D) numpy array, values in [0, 1]
298
+ name: str (e.g., "edge" or "vertex")
299
+ sample_name: str
300
+ ply_threshold: float, 只有概率大于此值的点才会被保存
301
+ """
302
+ # 1. 保存原始 Dense 数据 (可选)
303
+ npy_path = os.path.join(self.output_voxel_dir, f"{sample_name}_{name}_logits.npy")
304
+ # np.save(npy_path, dense_vol)
305
+
306
+ # 2. 生成 2D 投影热力图 (保持不变)
307
+ proj_x = np.max(dense_vol, axis=0)
308
+ proj_y = np.max(dense_vol, axis=1)
309
+ proj_z = np.max(dense_vol, axis=2)
310
+
311
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
312
+ im0 = axes[0].imshow(proj_x, cmap='turbo', vmin=0, vmax=1, origin='lower')
313
+ axes[0].set_title(f"{name} Max-Proj (YZ)")
314
+ im1 = axes[1].imshow(proj_y, cmap='turbo', vmin=0, vmax=1, origin='lower')
315
+ axes[1].set_title(f"{name} Max-Proj (XZ)")
316
+ im2 = axes[2].imshow(proj_z, cmap='turbo', vmin=0, vmax=1, origin='lower')
317
+ axes[2].set_title(f"{name} Max-Proj (XY)")
318
+
319
+ fig.colorbar(im2, ax=axes, orientation='vertical', fraction=0.02, pad=0.04)
320
+ plt.suptitle(f"{sample_name} - {name} Occupancy Probability")
321
+
322
+ png_path = os.path.join(self.output_voxel_dir, f"{sample_name}_{name}_heatmap.png")
323
+ plt.savefig(png_path, dpi=150)
324
+ plt.close(fig)
325
+
326
+ # ------------------------------------------------------------------
327
+ # 3. 保存为带颜色和透明度(RGBA)的 PLY 点云
328
+ # ------------------------------------------------------------------
329
+ # 筛选出概率大于阈值的点坐标
330
+ indices = np.argwhere(dense_vol > ply_threshold)
331
+
332
+ if len(indices) > 0:
333
+ # 获取这些点的概率值 [0, 1]
334
+ values = dense_vol[indices[:, 0], indices[:, 1], indices[:, 2]]
335
+
336
+ # 使用 matplotlib 的 colormap 进行颜色映射
337
+ import matplotlib.cm as cm
338
+ cmap = cm.get_cmap('turbo')
339
+
340
+ # map values [0, 1] to RGBA [0, 1] (N, 4)
341
+ colors_float = cmap(values)
342
+
343
+ # -------------------------------------------------------
344
+ # 【核心修改】:修改 Alpha 通道 (透明度)
345
+ # -------------------------------------------------------
346
+ # 让透明度直接等于概率值。
347
+ # 概率 1.0 -> Alpha 1.0 (完全不透明/颜色深)
348
+ # 概率 0.1 -> Alpha 0.1 (非常透明/颜色浅)
349
+ colors_float[:, 3] = values
350
+
351
+ # 转换为 uint8 [0, 255],保留 4 个通道 (R, G, B, A)
352
+ colors_uint8 = (colors_float * 255).astype(np.uint8)
353
+
354
+ # 坐标转换
355
+ vertices = indices
356
+
357
+ ply_filename = f"{sample_name}_{name}_logits_colored.ply"
358
+ ply_save_path = os.path.join(self.output_voxel_dir, ply_filename)
359
+
360
+ try:
361
+ # 使用 Trimesh 保存 (Trimesh 支持 (N, 4) 的 colors)
362
+ pcd = trimesh.points.PointCloud(vertices=vertices, colors=colors_uint8)
363
+ pcd.export(ply_save_path)
364
+ print(f"Saved colored RGBA logit PLY to {ply_save_path}")
365
+ except Exception as e:
366
+ print(f"Failed to save PLY with trimesh: {e}")
367
+ # Fallback: 手动写入 PLY (需要添加 alpha 属性)
368
+ with open(ply_save_path, 'w') as f:
369
+ f.write("ply\n")
370
+ f.write("format ascii 1.0\n")
371
+ f.write(f"element vertex {len(vertices)}\n")
372
+ f.write("property float x\n")
373
+ f.write("property float y\n")
374
+ f.write("property float z\n")
375
+ f.write("property uchar red\n")
376
+ f.write("property uchar green\n")
377
+ f.write("property uchar blue\n")
378
+ f.write("property uchar alpha\n") # 新增 Alpha 属性
379
+ f.write("end_header\n")
380
+ for i in range(len(vertices)):
381
+ v = vertices[i]
382
+ c = colors_uint8[i] # c is now (R, G, B, A)
383
+ f.write(f"{v[0]} {v[1]} {v[2]} {c[0]} {c[1]} {c[2]} {c[3]}\n")
384
+
385
+ def _point_line_segment_distance(self, px, py, pz, x1, y1, z1, x2, y2, z2):
386
+ """
387
+ 计算点 (px,py,pz) 到线段 (x1,y1,z1)-(x2,y2,z2) 的最短距离的平方。
388
+ 全部输入为 Tensor,支持广播。
389
+ """
390
+ # 线段向量 AB
391
+ ABx = x2 - x1
392
+ ABy = y2 - y1
393
+ ABz = z2 - z1
394
+
395
+ # 向量 AP
396
+ APx = px - x1
397
+ APy = py - y1
398
+ APz = pz - z1
399
+
400
+ # AB 的长度平方
401
+ AB_sq = ABx**2 + ABy**2 + ABz**2
402
+
403
+ # 避免除以0 (如果两端点重合)
404
+ AB_sq = torch.clamp(AB_sq, min=1e-6)
405
+
406
+ # 投影系数 t = (AP · AB) / |AB|^2
407
+ t = (APx * ABx + APy * ABy + APz * ABz) / AB_sq
408
+
409
+ # 限制 t 在 [0, 1] 之间(线段约束)
410
+ t = torch.clamp(t, 0.0, 1.0)
411
+
412
+ # 最近点 (Projection)
413
+ closestX = x1 + t * ABx
414
+ closestY = y1 + t * ABy
415
+ closestZ = z1 + t * ABz
416
+
417
+ # 距离平方
418
+ dx = px - closestX
419
+ dy = py - closestY
420
+ dz = pz - closestZ
421
+
422
+ return dx**2 + dy**2 + dz**2
423
+
424
+ def _extract_mesh_projection_based(
425
+ self,
426
+ vtx_result: dict,
427
+ edge_result: dict,
428
+ resolution: int = 1024,
429
+ vtx_prob_threshold: float = 0.5,
430
+
431
+ # --- 你的新逻辑参数 ---
432
+ search_radius: float = 128.0, # 1. 候选边最大长度
433
+ project_dist_thresh: float = 1.5, # 2. 投影距离阈值 (管子半径,单位:voxel)
434
+ dir_align_threshold: float = 0.6, # 3. 方向相似度阈值 (cos theta)
435
+ connect_ratio_threshold: float = 0.4, # 4. 最终连接阈值 (匹配点数 / 理论长度)
436
+
437
+ edge_prob_threshold: float = 0.1, # 仅仅用于提取"存在的"体素
438
+ ):
439
+ t_start = time.perf_counter()
440
+
441
+ # ---------------------------------------------------------------------
442
+ # 1. 准备全局数据:提取所有"活着"的 Edge Voxels (作为点云处理)
443
+ # ---------------------------------------------------------------------
444
+ e_probs = torch.sigmoid(edge_result['occ_probs'][:, 0])
445
+ e_coords = edge_result['coords_4d'][:, 1:].float() # (N, 3)
446
+
447
+ # 获取方向向量
448
+ if 'predicted_direction_feats' in edge_result:
449
+ e_dirs = edge_result['predicted_direction_feats'] # (N, 3)
450
+ # 归一化方向
451
+ e_dirs = F.normalize(e_dirs, p=2, dim=1)
452
+ else:
453
+ print("Warning: No direction features, using dummy.")
454
+ e_dirs = torch.zeros_like(e_coords)
455
+
456
+ # 筛选有效的 Edge Voxels (Global Point Cloud)
457
+ valid_mask = e_probs > edge_prob_threshold
458
+
459
+ cloud_coords = e_coords[valid_mask] # (M, 3)
460
+ cloud_dirs = e_dirs[valid_mask] # (M, 3)
461
+
462
+ num_cloud = cloud_coords.shape[0]
463
+ print(f"[Projection] Global active edge voxels: {num_cloud}")
464
+
465
+ if num_cloud == 0:
466
+ return [], []
467
+
468
+ # ---------------------------------------------------------------------
469
+ # 2. 准备顶点和候选边
470
+ # ---------------------------------------------------------------------
471
+ v_probs = torch.sigmoid(vtx_result['occ_probs'][:, 0])
472
+ v_coords = vtx_result['coords_4d'][:, 1:].float()
473
+ v_mask = v_probs > vtx_prob_threshold
474
+ valid_v_coords = v_coords[v_mask] # (V, 3)
475
+
476
+ if valid_v_coords.shape[0] < 2:
477
+ return valid_v_coords.cpu().numpy() / resolution, []
478
+
479
+ # 生成所有可能的候选边 (基于距离粗筛)
480
+ dists = torch.cdist(valid_v_coords, valid_v_coords)
481
+ triu_mask = torch.triu(torch.ones_like(dists), diagonal=1).bool()
482
+ cand_mask = (dists < search_radius) & triu_mask
483
+ cand_indices = torch.nonzero(cand_mask, as_tuple=False) # (E_cand, 2)
484
+
485
+ p1s = valid_v_coords[cand_indices[:, 0]] # (E, 3)
486
+ p2s = valid_v_coords[cand_indices[:, 1]] # (E, 3)
487
+
488
+ num_candidates = p1s.shape[0]
489
+ print(f"[Projection] Checking {num_candidates} candidate pairs...")
490
+
491
+ # ---------------------------------------------------------------------
492
+ # 3. 循环处理候选边 (使用 Bounding Box 快速裁剪)
493
+ # ---------------------------------------------------------------------
494
+ final_edges = []
495
+
496
+ # 预计算所有候选边的方向和长度
497
+ edge_vecs = p2s - p1s
498
+ edge_lengths = torch.norm(edge_vecs, dim=1)
499
+ edge_dirs = F.normalize(edge_vecs, p=2, dim=1)
500
+
501
+ # 为了避免显存爆炸,也不要在 Python 里做太慢的循环
502
+ # 我们对点云进行操作太慢,对每一条边去遍历整个点云也太慢。
503
+ # 策略:
504
+ # 我们循环“边”,但在循环内部利用 mask 快速筛选点云。
505
+ # 由于 Python 循环 10000 次会很慢,我们只处理那些有希望的边。
506
+ # 这里为了演示逻辑的准确性,我们使用简单的循环,但在 GPU 上做计算。
507
+
508
+ # 将全局点云拆分到各个坐标轴,便于快速 BBox 筛选
509
+ cx, cy, cz = cloud_coords[:, 0], cloud_coords[:, 1], cloud_coords[:, 2]
510
+
511
+ # 优化:如果候选边太多,可以分块。这里假设边在 5万以内,点在 10万以内,可以处理。
512
+
513
+ # 这一步是瓶颈,我们尝试用 Python 循环,但只对局部点计算
514
+ # 为了加速,我们可以将点云放入 HashGrid 或者只是简单的 BBox Check。
515
+
516
+ # 让我们用简单的逻辑:对于每条边,找出 BBox 内的点,算距离。
517
+ # 这里的 batch_size 是指一次并行处理多少条边
518
+
519
+ batch_size = 128 # 每次处理 128 条边
520
+
521
+ for i in range(0, num_candidates, batch_size):
522
+ end = min(i + batch_size, num_candidates)
523
+
524
+ # 当前批次的边数据
525
+ b_p1 = p1s[i:end] # (B, 3)
526
+ b_p2 = p2s[i:end] # (B, 3)
527
+ b_dirs = edge_dirs[i:end] # (B, 3)
528
+ b_lens = edge_lengths[i:end] # (B,)
529
+
530
+ # --- 步骤 A: 投影 & 距离检查 ---
531
+ # 这是一个 (B, M) 的大矩阵计算,容易 OOM。
532
+ # M (点云数) 可能很大。
533
+ # 解决方法:我们反过来思考。
534
+ # 不计算矩阵,我们只对单个边进行循环?太慢。
535
+
536
+ # 实用优化:只对 bounding box 内的点进行距离计算。
537
+ # 由于 GPU 难以动态索引不规则数据,我们还是逐个边循环比较稳妥,
538
+ # 但为了 Python 速度,必须尽可能向量化。
539
+
540
+ # 这里我采用一种折中方案:逐个处理边,但是利用 torch.where 快速定位。
541
+ # 实际上,对于 Python 里的 for loop,几千次是可以接受的。
542
+
543
+ current_edges_indices = cand_indices[i:end]
544
+
545
+ for j in range(len(b_p1)):
546
+ # 单条边处理
547
+ p1 = b_p1[j]
548
+ p2 = b_p2[j]
549
+ e_dir = b_dirs[j]
550
+ e_len = b_lens[j].item()
551
+
552
+ # 1. Bounding Box Filter (快速大幅裁剪)
553
+ # 找出这条边 BBox 范围内的所有点 (+ padding)
554
+ padding = project_dist_thresh + 2.0
555
+ min_xyz = torch.min(p1, p2) - padding
556
+ max_xyz = torch.max(p1, p2) + padding
557
+
558
+ # 利用 boolean mask 筛选
559
+ mask_x = (cx >= min_xyz[0]) & (cx <= max_xyz[0])
560
+ mask_y = (cy >= min_xyz[1]) & (cy <= max_xyz[1])
561
+ mask_z = (cz >= min_xyz[2]) & (cz <= max_xyz[2])
562
+ bbox_mask = mask_x & mask_y & mask_z
563
+
564
+ subset_coords = cloud_coords[bbox_mask]
565
+ subset_dirs = cloud_dirs[bbox_mask]
566
+
567
+ if subset_coords.shape[0] == 0:
568
+ continue
569
+
570
+ # 2. 精确距离计算 (Projection Distance)
571
+ # 计算 subset 中每个点到线段 p1-p2 的距离平方
572
+ dist_sq = self._point_line_segment_distance(
573
+ subset_coords[:, 0], subset_coords[:, 1], subset_coords[:, 2],
574
+ p1[0], p1[1], p1[2],
575
+ p2[0], p2[1], p2[2]
576
+ )
577
+
578
+ # 3. 距离阈值过滤 (Keep voxels inside the tube)
579
+ dist_mask = dist_sq < (project_dist_thresh ** 2)
580
+
581
+ # 获取在管子内部的体素
582
+ tube_dirs = subset_dirs[dist_mask]
583
+
584
+ if tube_dirs.shape[0] == 0:
585
+ continue
586
+
587
+ # 4. 方向一致性检查 (Direction Check)
588
+ # 计算点积 (cos theta)
589
+ # e_dir 是 (3,), tube_dirs 是 (K, 3)
590
+ dot_prod = torch.matmul(tube_dirs, e_dir)
591
+
592
+ # 这里使用 abs,因为边可能是无向的,或者网络预测可能反向
593
+ # 如果你的网络严格预测流向,可以去掉 abs
594
+ dir_sim = torch.abs(dot_prod)
595
+
596
+ # 统计方向符合要求的体素数量
597
+ valid_voxel_count = (dir_sim > dir_align_threshold).sum().item()
598
+
599
+ # 5. 比值判决 (Ratio Check)
600
+ # 量化出的 Voxel 数目 ≈ 边的长度 (e_len)
601
+ # 如果 e_len 很小(比如<1),我们设为1防止除以0
602
+ theoretical_count = max(e_len, 1.0)
603
+
604
+ ratio = valid_voxel_count / theoretical_count
605
+
606
+ if ratio > connect_ratio_threshold:
607
+ # 找到了!
608
+ global_idx = i + j
609
+ edge_tuple = cand_indices[global_idx].cpu().numpy().tolist()
610
+ final_edges.append(edge_tuple)
611
+
612
+ t_end = time.perf_counter()
613
+ print(f"[Projection] Logic finished. Accepted {len(final_edges)} edges. Time={t_end - t_start:.4f}s")
614
+
615
+ out_vertices = valid_v_coords.cpu().numpy() / resolution
616
+ return out_vertices, final_edges
617
+
618
+ def _save_voxel_ply(self, coords: torch.Tensor, labels: torch.Tensor, filename: str):
619
+ if coords.numel() == 0:
620
+ return
621
+
622
+ coords_np = coords.cpu().to(torch.float32).numpy()
623
+ labels_np = labels.cpu().to(torch.float32).numpy()
624
+
625
+ colors = np.zeros((coords_np.shape[0], 3), dtype=np.uint8)
626
+ colors[labels_np == 0] = [255, 0, 0]
627
+ colors[labels_np == 1] = [0, 0, 255]
628
+
629
+ try:
630
+ import trimesh
631
+ point_cloud = trimesh.PointCloud(vertices=coords_np, colors=colors)
632
+ ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
633
+ point_cloud.export(ply_path)
634
+ except ImportError:
635
+ ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
636
+ with open(ply_path, 'w') as f:
637
+ f.write("ply\n")
638
+ f.write("format ascii 1.0\n")
639
+ f.write(f"element vertex {coords_np.shape[0]}\n")
640
+ f.write("property float x\n")
641
+ f.write("property float y\n")
642
+ f.write("property float z\n")
643
+ f.write("property uchar red\n")
644
+ f.write("property uchar green\n")
645
+ f.write("property uchar blue\n")
646
+ f.write("end_header\n")
647
+ for i in range(coords_np.shape[0]):
648
+ f.write(f"{coords_np[i,0]} {coords_np[i,1]} {coords_np[i,2]} {colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
649
+
650
+ def _load_config(self, config_path=None):
651
+ if config_path and os.path.exists(config_path):
652
+ with open(config_path) as f:
653
+ return yaml.safe_load(f)
654
+
655
+ ckpt_dir = os.path.dirname(self.ckpt_path)
656
+ possible_configs = [
657
+ os.path.join(ckpt_dir, "config.yaml"),
658
+ os.path.join(os.path.dirname(ckpt_dir), "config.yaml")
659
+ ]
660
+
661
+ for config_file in possible_configs:
662
+ if os.path.exists(config_file):
663
+ with open(config_file) as f:
664
+ print(f"Loaded config from: {config_file}")
665
+ return yaml.safe_load(f)
666
+
667
+ checkpoint = torch.load(self.ckpt_path, map_location='cpu')
668
+ if 'config' in checkpoint:
669
+ print("Loaded config from checkpoint")
670
+ return checkpoint['config']
671
+
672
+ raise FileNotFoundError("Could not find config_edge.yaml in checkpoint directory or parent, and config not saved in checkpoint.")
673
+
674
+ def _init_models(self):
675
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
676
+ in_channels=15,
677
+ hidden_dim=256,
678
+ out_channels=1024,
679
+ scatter_type='mean',
680
+ n_blocks=5,
681
+ resolution=128,
682
+
683
+ ).to(self.device)
684
+
685
+ self.connection_head = ConnectionHead(
686
+ channels=128 * 2,
687
+ out_channels=1,
688
+ mlp_ratio=4,
689
+ ).to(self.device)
690
+
691
+ self.vae = VoxelVAE( # abalation: VoxelVAE_1volume_dilation
692
+ in_channels=self.config['model']['in_channels'],
693
+ latent_dim=self.config['model']['latent_dim'],
694
+ encoder_blocks=self.config['model']['encoder_blocks'],
695
+ # decoder_blocks=self.config['model']['decoder_blocks'],
696
+ decoder_blocks_vtx=self.config['model']['decoder_blocks_vtx'],
697
+ decoder_blocks_edge=self.config['model']['decoder_blocks_edge'],
698
+ num_heads=8,
699
+ num_head_channels=64,
700
+ mlp_ratio=4.0,
701
+ attn_mode="swin",
702
+ window_size=8,
703
+ pe_mode="ape",
704
+ use_fp16=False,
705
+ use_checkpoint=False,
706
+ qk_rms_norm=False,
707
+ using_subdivide=True,
708
+ using_attn=self.config['model']['using_attn'],
709
+ attn_first=self.config['model'].get('attn_first', True),
710
+ pred_direction=self.config['model'].get('pred_direction', False),
711
+ ).to(self.device)
712
+
713
+ load_pretrained_woself(
714
+ checkpoint_path=self.ckpt_path,
715
+ voxel_encoder=self.voxel_encoder,
716
+ connection_head=self.connection_head,
717
+ vae=self.vae,
718
+ )
719
+ # --- 【新增】在这里添加权重检查逻辑 ---
720
+ print(f"--- 正在检查权重文件中的 NaN/Inf 值... ---")
721
+ has_nan_inf = False
722
+ if self._check_weights_for_nan_inf(self.vae, "VoxelVAE"):
723
+ has_nan_inf = True
724
+
725
+ if self._check_weights_for_nan_inf(self.voxel_encoder, "Vertex Encoder"):
726
+ has_nan_inf = True
727
+
728
+ if self._check_weights_for_nan_inf(self.connection_head, "Connection Head"):
729
+ has_nan_inf = True
730
+
731
+ if not has_nan_inf:
732
+ print("--- 权重检查通过。未发现 NaN/Inf 值。 ---")
733
+ else:
734
+ # 如果发现坏值,直接抛出异常,因为评估无法继续
735
+ raise ValueError(f"在检查点 '{self.ckpt_path}' 中发现了 NaN 或 Inf 值。请检查导致训练不稳定的权重文件。")
736
+ # --- 检查逻辑结束 ---
737
+
738
+ self.vae.eval()
739
+ self.voxel_encoder.eval()
740
+ self.connection_head.eval()
741
+
742
+ def _init_dataset(self):
743
+ self.dataset = VoxelVertexDataset_edge(
744
+ root_dir=self.dataset_path,
745
+ base_resolution=self.config['dataset']['base_resolution'],
746
+ min_resolution=self.config['dataset']['min_resolution'],
747
+ cache_dir='/gemini/user/private/zhaotianhao/dataset_cache/test_15c_dora',
748
+ # cache_dir=self.config['dataset']['cache_dir'],
749
+ renders_dir=self.config['dataset']['renders_dir'],
750
+
751
+ # filter_active_voxels=self.config['dataset']['filter_active_voxels'],
752
+ filter_active_voxels=False,
753
+ cache_filter_path=self.config['dataset']['cache_filter_path'],
754
+
755
+ sample_type=self.config['dataset']['sample_type'],
756
+ active_voxel_res=128,
757
+ pc_sample_number=819200,
758
+
759
+ )
760
+
761
+ self.dataloader = DataLoader(
762
+ self.dataset,
763
+ batch_size=1,
764
+ shuffle=False,
765
+ collate_fn=partial(collate_fn_pointnet),
766
+ num_workers=0,
767
+ pin_memory=True,
768
+ # prefetch_factor=4,
769
+ )
770
+
771
+ def _check_weights_for_nan_inf(self, model: torch.nn.Module, model_name: str) -> bool:
772
+ """
773
+ 检查模型的所有参数中是否存在 NaN 或 Inf 值。
774
+
775
+ Args:
776
+ model (torch.nn.Module): 要检查的模型。
777
+ model_name (str): 模型的名称,用于打印日志。
778
+
779
+ Returns:
780
+ bool: 如果找到 NaN 或 Inf,则返回 True,否则返回 False。
781
+ """
782
+ found_issue = False
783
+ for name, param in model.named_parameters():
784
+ if torch.isnan(param.data).any():
785
+ print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 NaN 值!")
786
+ found_issue = True
787
+ if torch.isinf(param.data).any():
788
+ print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 Inf 值!")
789
+ found_issue = True
790
+ return found_issue
791
+
792
+
793
+ def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
794
+ """
795
+ 修改后的函数,确保一对一匹配,并优先匹配最近的点对。
796
+ """
797
+ pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
798
+ gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
799
+
800
+ pred_total = len(pred_array)
801
+ gt_total = len(gt_array)
802
+
803
+ if pred_total == 0 or gt_total == 0:
804
+ return {
805
+ 'recall': 0.0,
806
+ 'precision': 0.0,
807
+ 'f1': 0.0,
808
+ 'matches': 0,
809
+ 'pred_count': pred_total,
810
+ 'gt_count': gt_total
811
+ }
812
+
813
+ # 依然在预测点上构建KD-Tree,为每个真实点查找最近的预测点
814
+ tree = cKDTree(pred_array)
815
+ dists, pred_idxs = tree.query(gt_array, k=1)
816
+
817
+ # --- 核心修改部分 ---
818
+
819
+ # 1. 创建一个列表,包含 (距离, 真实点索引, 预测点索引)
820
+ # 这样我们就可以按距离对所有可能的匹配进行排序
821
+ possible_matches = []
822
+ for gt_idx, (dist, pred_idx) in enumerate(zip(dists, pred_idxs)):
823
+ if dist <= threshold:
824
+ possible_matches.append((dist, gt_idx, pred_idx))
825
+
826
+ # 2. 按距离从小到大排序(贪心策略)
827
+ possible_matches.sort(key=lambda x: x[0])
828
+
829
+ matches = 0
830
+ # 使用集合来跟踪已经使用过的预测点和真实点,确保一对一匹配
831
+ used_pred_indices = set()
832
+ used_gt_indices = set() # 虽然当前逻辑下gt不会重复,但加上更严谨
833
+
834
+ # 3. 遍历排序后的可能匹配,进行一对一分配
835
+ for dist, gt_idx, pred_idx in possible_matches:
836
+ # 如果这个预测点和这个真实点都还没有被使用过
837
+ if pred_idx not in used_pred_indices and gt_idx not in used_gt_indices:
838
+ matches += 1
839
+ used_pred_indices.add(pred_idx)
840
+ used_gt_indices.add(gt_idx)
841
+
842
+ # --- 修改结束 ---
843
+
844
+ # matches 现在是真正的 True Positives 数量,它绝不会超过 pred_total 或 gt_total
845
+ recall = matches / gt_total if gt_total > 0 else 0.0
846
+ precision = matches / pred_total if pred_total > 0 else 0.0
847
+
848
+ # 计算F1时,使用标准的 Precision 和 Recall 定义
849
+ if (precision + recall) == 0:
850
+ f1 = 0.0
851
+ else:
852
+ f1 = 2 * (precision * recall) / (precision + recall)
853
+
854
+ return {
855
+ 'recall': recall,
856
+ 'precision': precision,
857
+ 'f1': f1,
858
+ 'matches': matches,
859
+ 'pred_count': pred_total,
860
+ 'gt_count': gt_total
861
+ }
862
+
863
+ def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
864
+ """
865
+ 一个折衷的顶点指标计算方案。
866
+ 它沿用“为每个真实点寻找最近预测点”的逻辑,
867
+ 但通过修正计算方式,确保Precision和F1值不会超过1.0。
868
+ """
869
+ # 假设 pred_coords 和 gt_coords 是 PyTorch 张量
870
+ pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
871
+ gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
872
+
873
+ pred_total = len(pred_array)
874
+ gt_total = len(gt_array)
875
+
876
+ if pred_total == 0 or gt_total == 0:
877
+ return {
878
+ 'recall': 0.0,
879
+ 'precision': 0.0,
880
+ 'f1': 0.0,
881
+ 'matches': 0,
882
+ 'pred_count': pred_total,
883
+ 'gt_count': gt_total
884
+ }
885
+
886
+ # 在预测点上构建KD-Tree,为每个真实点查找最近的预测点
887
+ tree = cKDTree(pred_array)
888
+ dists, _ = tree.query(gt_array, k=1) # 我们在这里其实不需要 pred 的索引
889
+
890
+ # 1. 计算从 gt 角度出发的匹配数 (True Positives for Recall)
891
+ # 这和您的第一个函数完全一样。
892
+ # 这个值代表了“有多少个真实点被成功找到了”。
893
+ matches_from_gt = np.sum(dists <= threshold)
894
+
895
+ # 2. 计算 Recall (召回率)
896
+ # 召回率的分母是真实点的总数,所以这里的计算是合理的。
897
+ recall = matches_from_gt / gt_total if gt_total > 0 else 0.0
898
+
899
+ # 3. 计算 Precision (精确率) - ✅ 这是核心修正点
900
+ # 精确率的分母是预测点的总数。
901
+ # 分子(True Positives)不能超过预测点的总数。
902
+ # 因此,我们取 matches_from_gt 和 pred_total 中的较小值。
903
+ # 这解决了 Precision > 1 的问题。
904
+ tp_for_precision = min(matches_from_gt, pred_total)
905
+ precision = tp_for_precision / pred_total if pred_total > 0 else 0.0
906
+
907
+ # 4. 使用标准的F1分数公式
908
+ # 您原来的 F1 公式 `2 * matches / (pred + gt)` 是 L1-Score,
909
+ # 更常用的是基于 Precision 和 Recall 的调和平均数。
910
+ if (precision + recall) == 0:
911
+ f1 = 0.0
912
+ else:
913
+ f1 = 2 * (precision * recall) / (precision + recall)
914
+
915
+ return {
916
+ 'recall': recall,
917
+ 'precision': precision,
918
+ 'f1': f1,
919
+ 'matches': matches_from_gt, # 仍然报告原始的匹配数,便于观察
920
+ 'pred_count': pred_total,
921
+ 'gt_count': gt_total
922
+ }
923
+
924
+ def _compute_chamfer_distance(self, p1: torch.Tensor, p2: torch.Tensor, one_sided: bool = False):
925
+ if len(p1) == 0 or len(p2) == 0:
926
+ return float('nan')
927
+
928
+ dist_p1_p2 = torch.min(torch.cdist(p1, p2), dim=1)[0].mean()
929
+
930
+ if one_sided:
931
+ return dist_p1_p2.item()
932
+ else:
933
+ dist_p2_p1 = torch.min(torch.cdist(p2, p1), dim=1)[0].mean()
934
+ return (dist_p1_p2 + dist_p2_p1).item() / 2
935
+
936
+ def visualize_latent_space_pca(self, sample_idx: int):
937
+ """
938
+ Encodes a sample, performs PCA on its latent features, and saves a
939
+ colored PLY file for visualization.
940
+
941
+ The position of each point in the PLY file corresponds to the spatial
942
+ location in the latent grid.
943
+
944
+ The color of each point represents the first three principal components
945
+ of its feature vector.
946
+ """
947
+ print(f"--- Starting Latent Space PCA Visualization for Sample {sample_idx} ---")
948
+ self.vae.eval()
949
+
950
+ try:
951
+ # 1. Get the latent representation for the sample
952
+ latent = self._get_latent_for_sample(sample_idx)
953
+ except ValueError as e:
954
+ print(f"Error: {e}")
955
+ return
956
+
957
+ latent_coords = latent.coords.detach().cpu().numpy()
958
+ latent_feats = latent.feats.detach().cpu().numpy()
959
+
960
+ if latent_feats.shape[0] < 3:
961
+ print(f"Warning: Not enough latent points ({latent_feats.shape[0]}) to perform PCA. Skipping.")
962
+ return
963
+
964
+ print(f"--> Performing PCA on {latent_feats.shape[0]} latent vectors of dimension {latent_feats.shape[1]}...")
965
+
966
+ # 2. Perform PCA to reduce feature dimensions to 3
967
+ pca = PCA(n_components=3)
968
+ pca_features = pca.fit_transform(latent_feats)
969
+
970
+ print(f" Explained variance ratio by 3 components: {pca.explained_variance_ratio_}")
971
+ print(f" Total explained variance: {np.sum(pca.explained_variance_ratio_):.4f}")
972
+
973
+ # 3. Normalize the PCA components to be used as RGB colors [0, 255]
974
+ # We normalize each component independently to maximize color contrast
975
+ normalized_colors = np.zeros_like(pca_features)
976
+ for i in range(3):
977
+ min_val = pca_features[:, i].min()
978
+ max_val = pca_features[:, i].max()
979
+ if max_val - min_val > 1e-6:
980
+ normalized_colors[:, i] = (pca_features[:, i] - min_val) / (max_val - min_val)
981
+ else:
982
+ normalized_colors[:, i] = 0.5 # Handle case of constant value
983
+
984
+ colors_uint8 = (normalized_colors * 255).astype(np.uint8)
985
+
986
+ # 4. Prepare spatial coordinates for the point cloud
987
+ # latent_coords is (batch_idx, x, y, z), we want the xyz part
988
+ spatial_coords = latent_coords[:, 1:]
989
+
990
+ # 5. Create and save the colored PLY file
991
+ try:
992
+ # Create a Trimesh PointCloud object
993
+ point_cloud = trimesh.points.PointCloud(vertices=spatial_coords, colors=colors_uint8)
994
+
995
+ # Define the output filename
996
+ filename = f"sample_{sample_idx}_latent_pca.ply"
997
+ ply_path = os.path.join(self.output_voxel_dir, filename)
998
+
999
+ # Export the file
1000
+ point_cloud.export(ply_path)
1001
+ print(f"--> Successfully saved PCA visualization to: {ply_path}")
1002
+
1003
+ except Exception as e:
1004
+ print(f"Error during Trimesh export: {e}")
1005
+ print("Please ensure 'trimesh' is installed correctly.")
1006
+
1007
+ def _get_latent_for_sample(self, sample_idx: int) -> SparseTensor:
1008
+ """
1009
+ Encodes a single sample and returns its latent representation.
1010
+ """
1011
+ print(f"--> Encoding sample {sample_idx} to get its latent vector...")
1012
+ # Get data for the specified sample
1013
+ batch_data = self.dataset[sample_idx]
1014
+ if batch_data is None:
1015
+ raise ValueError(f"Sample at index {sample_idx} could not be loaded.")
1016
+
1017
+ # Use the collate function to form a batch
1018
+ batch_data = collate_fn_pointnet([batch_data])
1019
+
1020
+ with torch.no_grad():
1021
+ active_coords = batch_data['active_voxels_128'].to(self.device)
1022
+ point_cloud = batch_data['point_cloud_128'].to(self.device)
1023
+
1024
+ active_voxel_feats = self.voxel_encoder(
1025
+ p=point_cloud,
1026
+ sparse_coords=active_coords,
1027
+ res=128,
1028
+ bbox_size=(-0.5, 0.5),
1029
+ )
1030
+
1031
+ sparse_input = SparseTensor(
1032
+ feats=active_voxel_feats,
1033
+ coords=active_coords.int()
1034
+ )
1035
+
1036
+ # 2. Encode to get the latent representation
1037
+ latent_128, posterior = self.vae.encode(sparse_input, sample_posterior=True,)
1038
+ print(f" Latent for sample {sample_idx} obtained. Shape: {latent_128.feats.shape}")
1039
+ return latent_128
1040
+
1041
+
1042
+
1043
+ def evaluate(self, num_samples=None, visualize=False, chamfer_threshold=0.9, threshold=1.):
1044
+ total_samples = len(self.dataset)
1045
+ eval_samples = min(num_samples or total_samples, total_samples)
1046
+ sample_indices = random.sample(range(total_samples), eval_samples) if num_samples else range(total_samples)
1047
+ # sample_indices = range(eval_samples)
1048
+
1049
+ eval_dataset = Subset(self.dataset, sample_indices)
1050
+ eval_loader = DataLoader(
1051
+ eval_dataset,
1052
+ batch_size=1,
1053
+ shuffle=False,
1054
+ collate_fn=partial(collate_fn_pointnet),
1055
+ num_workers=self.config['training']['num_workers'],
1056
+ pin_memory=True,
1057
+ )
1058
+
1059
+ per_sample_metrics = {
1060
+ 'vertex': {res: [] for res in [128, 256, 512]},
1061
+ 'edge': {res: [] for res in [128, 256, 512]},
1062
+ 'sample_names': []
1063
+ }
1064
+ avg_metrics = {
1065
+ 'vertex': {res: defaultdict(list) for res in [128, 256, 512]},
1066
+ 'edge': {res: defaultdict(list) for res in [128, 256, 512]},
1067
+ }
1068
+
1069
+ self.vae.eval()
1070
+
1071
+ for batch_idx, batch_data in enumerate(tqdm(eval_loader, desc="Evaluating")):
1072
+ if batch_data is None:
1073
+ continue
1074
+ sample_idx = sample_indices[batch_idx]
1075
+ sample_name = f'sample_{sample_idx}'
1076
+ per_sample_metrics['sample_names'].append(sample_name)
1077
+
1078
+ # batch_save_path = f"/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/215000_sample_active_vis_42seed_1000complex/gt_data_batch_{batch_idx}.pt"
1079
+ # if not os.path.exists(batch_save_path):
1080
+ # print(f"Warning: Saved batch file not found: {batch_save_path}")
1081
+ # continue
1082
+ # batch_data = torch.load(batch_save_path, map_location=self.device)
1083
+
1084
+ with torch.no_grad():
1085
+ # 1. Get input data
1086
+ combined_voxels_512 = batch_data['combined_voxels_512'].to(self.device)
1087
+ combined_voxel_labels_512 = batch_data['combined_voxel_labels_512'].to(self.device)
1088
+ gt_combined_endpoints_512 = batch_data['gt_combined_endpoints_512'].to(self.device)
1089
+ gt_combined_errors_512 = batch_data['gt_combined_errors_512'].to(self.device)
1090
+
1091
+ edge_mask = (combined_voxel_labels_512 == 1)
1092
+
1093
+ gt_edge_endpoints_512 = gt_combined_endpoints_512[edge_mask].to(self.device)
1094
+
1095
+ gt_edge_voxels_512 = combined_voxels_512[edge_mask].to(self.device)
1096
+
1097
+ p1 = gt_edge_endpoints_512[:, 1:4].float()
1098
+ p2 = gt_edge_endpoints_512[:, 4:7].float()
1099
+
1100
+ mask = ( (p1[:,0] < p2[:,0]) |
1101
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
1102
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
1103
+
1104
+ pA = torch.where(mask[:, None], p1, p2) # smaller one
1105
+ pB = torch.where(mask[:, None], p2, p1) # larger one
1106
+
1107
+ d = pB - pA
1108
+ dir_gt = F.normalize(d, dim=-1, eps=1e-6)
1109
+
1110
+ gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'].to(self.device).int()
1111
+
1112
+ vtx_128 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=128)
1113
+ vtx_256 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=256)
1114
+
1115
+ edge_128 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=128)
1116
+ edge_256 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=256)
1117
+ edge_512 = combined_voxels_512
1118
+
1119
+
1120
+ gt_edge_voxels_list = [
1121
+ edge_128,
1122
+ edge_256,
1123
+ edge_512,
1124
+ ]
1125
+
1126
+ active_coords = batch_data['active_voxels_128'].to(self.device)
1127
+ point_cloud = batch_data['point_cloud_128'].to(self.device)
1128
+
1129
+
1130
+ active_voxel_feats = self.voxel_encoder(
1131
+ p=point_cloud,
1132
+ sparse_coords=active_coords,
1133
+ res=128,
1134
+ bbox_size=(-0.5, 0.5),
1135
+ )
1136
+
1137
+ sparse_input = SparseTensor(
1138
+ feats=active_voxel_feats,
1139
+ coords=active_coords.int()
1140
+ )
1141
+
1142
+ latent_128, posterior = self.vae.encode(sparse_input)
1143
+
1144
+
1145
+ # load_path = f'/gemini/user/private/zhaotianhao/checkpoints/output_slat_flow_matching_active/8w_128to256_head_rope/215000_sample_active_vis_42seed_1000complex/sample_latent_{batch_idx}.pt'
1146
+ # latent_128 = torch.load(load_path, map_location=self.device)
1147
+
1148
+ print('latent_128.feats.mean()', latent_128.feats.mean(), 'latent_128.feats.std()', latent_128.feats.std())
1149
+ print('posterior.mean', posterior.mean.mean(), 'posterior.std', posterior.std.mean(), 'posterior.var', posterior.var.mean())
1150
+ print('latent_128.coords.shape', latent_128.coords.shape)
1151
+
1152
+ # latent_128 = torch.load(f"/root/Trisf/output_slat_flow_matching/ckpts/1100_chair_sample/110000step_sample/sample_results_samples_{batch_idx}.pt", map_location=self.device)
1153
+ latent_128 = SparseTensor(
1154
+ coords=latent_128.coords,
1155
+ feats=latent_128.feats + 0. * torch.randn_like(latent_128.feats),
1156
+ )
1157
+
1158
+ # self.output_voxel_dir = os.path.dirname(load_path)
1159
+ # self.output_obj_dir = os.path.dirname(load_path)
1160
+
1161
+ # 7. Decoding with separate vertex and edge processing
1162
+ decoded_results = self.vae.decode(
1163
+ latent_128,
1164
+ gt_vertex_voxels_list=[],
1165
+ gt_edge_voxels_list=[],
1166
+ training=False,
1167
+
1168
+ inference_threshold=0.5,
1169
+ vis_last_layer=False,
1170
+ )
1171
+
1172
+ error = 0 # decoded_results[-1]['edge']['predicted_offset_feats']
1173
+
1174
+ if self.config['model'].get('pred_direction', False):
1175
+ pred_dir = decoded_results[-1]['edge']['predicted_direction_feats']
1176
+ zero_mask = (pred_dir == 0).all(dim=1) # [N],True 表示这一行全为0
1177
+ num_zeros = zero_mask.sum().item()
1178
+ print("Number of zero vectors:", num_zeros)
1179
+
1180
+ pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
1181
+ print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
1182
+ print('pred_dir.shape', pred_dir.shape)
1183
+ if pred_edge_coords_3d.shape[-1] == 4:
1184
+ pred_edge_coords_3d = pred_edge_coords_3d[:, 1:]
1185
+
1186
+ save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction.ply")
1187
+ visualize_colored_points_ply(pred_edge_coords_3d, pred_dir, save_pth)
1188
+
1189
+ save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction_gt.ply")
1190
+ visualize_colored_points_ply((gt_edge_voxels_512[:, 1:]), dir_gt, save_pth)
1191
+
1192
+
1193
+ pred_vtx_coords_3d = decoded_results[-1]['vertex']['coords']
1194
+ pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
1195
+
1196
+
1197
+ gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'][:, 1:].to(self.device)
1198
+ gt_edge_voxels_512 = batch_data['gt_edge_voxels_512'][:, 1:].to(self.device)
1199
+
1200
+
1201
+ # Calculate metrics and save results
1202
+ matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_vtx_coords_3d, gt_vertex_voxels_512, threshold=threshold,)
1203
+ print(f"\n----- Resolution {512} vtx -----")
1204
+ print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
1205
+ print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
1206
+
1207
+ self._save_voxel_ply(pred_vtx_coords_3d / 512., torch.zeros(len(pred_vtx_coords_3d)), f"{sample_name}_pred_vtx")
1208
+ self._save_voxel_ply((pred_edge_coords_3d) / 512, torch.zeros(len(pred_edge_coords_3d)), f"{sample_name}_pred_edge")
1209
+
1210
+ self._save_voxel_ply(gt_vertex_voxels_512 / 512, torch.zeros(len(gt_vertex_voxels_512)), f"{sample_name}_gt_vertex")
1211
+ self._save_voxel_ply((combined_voxels_512[:, 1:]) / 512., torch.zeros(len(gt_combined_errors_512)), f"{sample_name}_gt_edge")
1212
+
1213
+
1214
+ # Calculate vertex-specific metrics
1215
+ matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_edge_coords_3d, combined_voxels_512[:, 1:], threshold=threshold,)
1216
+ print(f"\n----- Resolution {512} edge -----")
1217
+ print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
1218
+ print('gt_edge_voxels_512.shape', gt_edge_voxels_512.shape)
1219
+
1220
+ print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
1221
+ print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
1222
+
1223
+ pred_vertex_coords_np = np.round(pred_vtx_coords_3d.cpu().numpy()).astype(int)
1224
+ pred_edges = []
1225
+ gt_vertex_coords_np = np.round(gt_vertex_voxels_512.cpu().numpy()).astype(int)
1226
+ if visualize:
1227
+ if pred_vtx_coords_3d.shape[-1] == 4:
1228
+ pred_vtx_coords_float = pred_vtx_coords_3d[:, 1:].float()
1229
+ else:
1230
+ pred_vtx_coords_float = pred_vtx_coords_3d.float()
1231
+
1232
+ pred_vtx_feats = decoded_results[-1]['vertex']['feats']
1233
+
1234
+ # ==========================================
1235
+ # Link Prediction & Mesh Generation
1236
+ # ==========================================
1237
+ print("Predicting connectivity...")
1238
+
1239
+ # 1. 预测边
1240
+ # 注意:K_neighbors 的设置。如果是物体,64 足够了。
1241
+ # 如果点非常稀疏,可能需要更大。
1242
+ pred_edges = predict_mesh_connectivity(
1243
+ connection_head=self.connection_head, # 或者是 self.connection_head,取决于你在哪里定义的
1244
+ vtx_feats=pred_vtx_feats,
1245
+ vtx_coords=pred_vtx_coords_float,
1246
+ batch_size=4096,
1247
+ threshold=0.5,
1248
+ k_neighbors=None,
1249
+ device=self.device
1250
+ )
1251
+ print(f"Predicted {len(pred_edges)} edges.")
1252
+
1253
+ # 2. 构建三角形
1254
+ num_verts = pred_vtx_coords_float.shape[0]
1255
+ pred_faces = build_triangles_from_edges(pred_edges, num_verts)
1256
+ print(f"Constructed {len(pred_faces)} triangles.")
1257
+
1258
+ # 3. 保存 OBJ
1259
+ import trimesh
1260
+
1261
+ # 坐标归一化/还原 (根据你的需求,这里假设你是 0-512 的体素坐标)
1262
+ # 如果想保存为归一化坐标:
1263
+ mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0
1264
+
1265
+ # 如果有 error offset,记得加上!
1266
+ # 你之前的代码好像没有对 vertex 加 offset,只对 edge 加了
1267
+ # 如果 vertex 也有 offset (如 dual contouring),在这里加上
1268
+
1269
+ # 移动到中心 (可选)
1270
+ mesh_verts = mesh_verts - 0.5
1271
+
1272
+ mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
1273
+
1274
+ # 过滤孤立点 (可选)
1275
+ # mesh.remove_unreferenced_vertices()
1276
+
1277
+ output_obj_path = os.path.join(self.output_voxel_dir, f"{sample_name}_recon.obj")
1278
+ mesh.export(output_obj_path)
1279
+ print(f"Saved mesh to {output_obj_path}")
1280
+
1281
+ # 保存边线 (用于 Debug)
1282
+ # 有时候三角形很难形成,只看边也很有用
1283
+ edges_path = os.path.join(self.output_voxel_dir, f"{sample_name}_edges.ply")
1284
+ # self._visualize_vertices(pred_edge_coords_np, gt_edge_coords_np, f"{sample_name}_edge_comparison")
1285
+
1286
+
1287
+ # Process results at different resolutions
1288
+ for i, res in enumerate([128, 256, 512]):
1289
+ if i >= len(decoded_results):
1290
+ continue
1291
+
1292
+ gt_key = f'gt_vertex_voxels_{res}'
1293
+ if gt_key not in batch_data:
1294
+ continue
1295
+ if i == 0:
1296
+ pred_coords_res = decoded_results[i]['vtx_sp'].coords[:, 1:].float()
1297
+ gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
1298
+ else:
1299
+ pred_coords_res = decoded_results[i]['vertex']['coords'].float()
1300
+ gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
1301
+
1302
+
1303
+ v_metrics = self._compute_vertex_metrics(pred_coords_res, gt_coords_res, threshold=threshold)
1304
+
1305
+ per_sample_metrics['vertex'][res].append({
1306
+ 'recall': v_metrics['recall'],
1307
+ 'precision': v_metrics['precision'],
1308
+ 'f1': v_metrics['f1'],
1309
+ 'num_pred': len(pred_coords_res),
1310
+ 'num_gt': len(gt_coords_res)
1311
+ })
1312
+
1313
+ avg_metrics['vertex'][res]['recall'].append(v_metrics['recall'])
1314
+ avg_metrics['vertex'][res]['precision'].append(v_metrics['precision'])
1315
+ avg_metrics['vertex'][res]['f1'].append(v_metrics['f1'])
1316
+
1317
+ gt_edge_key = f'gt_edge_voxels_{res}'
1318
+ if gt_edge_key not in batch_data:
1319
+ continue
1320
+
1321
+ if i == 0:
1322
+ pred_edge_coords_res = decoded_results[i]['edge_sp'].coords[:, 1:].float()
1323
+ # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1324
+ idx = i
1325
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1326
+ elif i == 1:
1327
+ idx = i
1328
+ #################################
1329
+ # pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() - error / 2. + 0.5
1330
+ # # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1331
+ # gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) - gt_combined_errors_512[:, 1:].to(self.device) + 0.5
1332
+
1333
+
1334
+ pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
1335
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1336
+
1337
+ # self._save_voxel_ply(gt_edge_voxels_list[idx][:, 1:].float().to(self.device) / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res_wooffset")
1338
+ # self._save_voxel_ply(decoded_results[i]['edge']['coords'].float() / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res_wooffset")
1339
+
1340
+ else:
1341
+ idx = i
1342
+ pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
1343
+ # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1344
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1345
+
1346
+ # self._save_voxel_ply(gt_edge_coords_res / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res")
1347
+ # self._save_voxel_ply(pred_edge_coords_res / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res")
1348
+
1349
+ e_metrics = self._compute_vertex_metrics(pred_edge_coords_res, gt_edge_coords_res, threshold=threshold)
1350
+
1351
+ per_sample_metrics['edge'][res].append({
1352
+ 'recall': e_metrics['recall'],
1353
+ 'precision': e_metrics['precision'],
1354
+ 'f1': e_metrics['f1'],
1355
+ 'num_pred': len(pred_edge_coords_res),
1356
+ 'num_gt': len(gt_edge_coords_res)
1357
+ })
1358
+
1359
+ avg_metrics['edge'][res]['recall'].append(e_metrics['recall'])
1360
+ avg_metrics['edge'][res]['precision'].append(e_metrics['precision'])
1361
+ avg_metrics['edge'][res]['f1'].append(e_metrics['f1'])
1362
+
1363
+ avg_metrics_processed = {}
1364
+ for category, res_dict in avg_metrics.items():
1365
+ avg_metrics_processed[category] = {}
1366
+ for resolution, metric_dict in res_dict.items():
1367
+ avg_metrics_processed[category][resolution] = {
1368
+ metric_name: np.mean(values) if values else float('nan')
1369
+ for metric_name, values in metric_dict.items()
1370
+ }
1371
+
1372
+ result_data = {
1373
+ 'config': self.config,
1374
+ 'checkpoint': self.ckpt_path,
1375
+ 'dataset': self.dataset_path,
1376
+ 'num_samples': eval_samples,
1377
+ 'per_sample_metrics': per_sample_metrics,
1378
+ 'avg_metrics': avg_metrics_processed
1379
+ }
1380
+
1381
+ results_file_path = os.path.join(self.result_dir, f"evaluation_results_epoch{self.epoch}.yaml")
1382
+ with open(results_file_path, 'w') as f:
1383
+ yaml.dump(result_data, f, default_flow_style=False)
1384
+
1385
+ return result_data
1386
+
1387
+ def _generate_line_voxels(
1388
+ self,
1389
+ p1: torch.Tensor,
1390
+ p2: torch.Tensor
1391
+ ) -> Tuple[
1392
+ List[Tuple[int, int, int]],
1393
+ List[Tuple[torch.Tensor, torch.Tensor]],
1394
+ List[np.ndarray]
1395
+ ]:
1396
+ """
1397
+ Improved version using better sampling strategy
1398
+ """
1399
+ p1_np = p1 #.cpu().numpy()
1400
+ p2_np = p2 #.cpu().numpy()
1401
+ voxel_dict = OrderedDict()
1402
+
1403
+ # Use proper 3D line voxelization algorithm
1404
+ def bresenham_3d(p1, p2):
1405
+ """3D Bresenham's line algorithm"""
1406
+ x1, y1, z1 = np.round(p1).astype(int)
1407
+ x2, y2, z2 = np.round(p2).astype(int)
1408
+
1409
+ points = []
1410
+ dx = abs(x2 - x1)
1411
+ dy = abs(y2 - y1)
1412
+ dz = abs(z2 - z1)
1413
+
1414
+ xs = 1 if x2 > x1 else -1
1415
+ ys = 1 if y2 > y1 else -1
1416
+ zs = 1 if z2 > z1 else -1
1417
+
1418
+ # Driving axis is X
1419
+ if dx >= dy and dx >= dz:
1420
+ err_1 = 2 * dy - dx
1421
+ err_2 = 2 * dz - dx
1422
+ for i in range(dx + 1):
1423
+ points.append((x1, y1, z1))
1424
+ if err_1 > 0:
1425
+ y1 += ys
1426
+ err_1 -= 2 * dx
1427
+ if err_2 > 0:
1428
+ z1 += zs
1429
+ err_2 -= 2 * dx
1430
+ err_1 += 2 * dy
1431
+ err_2 += 2 * dz
1432
+ x1 += xs
1433
+
1434
+ # Driving axis is Y
1435
+ elif dy >= dx and dy >= dz:
1436
+ err_1 = 2 * dx - dy
1437
+ err_2 = 2 * dz - dy
1438
+ for i in range(dy + 1):
1439
+ points.append((x1, y1, z1))
1440
+ if err_1 > 0:
1441
+ x1 += xs
1442
+ err_1 -= 2 * dy
1443
+ if err_2 > 0:
1444
+ z1 += zs
1445
+ err_2 -= 2 * dy
1446
+ err_1 += 2 * dx
1447
+ err_2 += 2 * dz
1448
+ y1 += ys
1449
+
1450
+ # Driving axis is Z
1451
+ else:
1452
+ err_1 = 2 * dx - dz
1453
+ err_2 = 2 * dy - dz
1454
+ for i in range(dz + 1):
1455
+ points.append((x1, y1, z1))
1456
+ if err_1 > 0:
1457
+ x1 += xs
1458
+ err_1 -= 2 * dz
1459
+ if err_2 > 0:
1460
+ y1 += ys
1461
+ err_2 -= 2 * dz
1462
+ err_1 += 2 * dx
1463
+ err_2 += 2 * dy
1464
+ z1 += zs
1465
+
1466
+ return points
1467
+
1468
+ # Get all voxels using Bresenham algorithm
1469
+ voxel_coords = bresenham_3d(p1_np, p2_np)
1470
+
1471
+ # Add all voxels to dictionary
1472
+ for coord in voxel_coords:
1473
+ voxel_dict[tuple(coord)] = (p1, p2)
1474
+
1475
+ voxel_coords = list(voxel_dict.keys())
1476
+ endpoint_pairs = list(voxel_dict.values())
1477
+
1478
+ # --- compute error vectors ---
1479
+ error_vectors = []
1480
+ diff = p2_np - p1_np
1481
+ d_norm_sq = np.dot(diff, diff)
1482
+
1483
+ for v in voxel_coords:
1484
+ v_center = np.array(v, dtype=float) + 0.5
1485
+ if d_norm_sq == 0: # degenerate line
1486
+ closest = p1_np
1487
+ else:
1488
+ t = np.dot(v_center - p1_np, diff) / d_norm_sq
1489
+ t = np.clip(t, 0.0, 1.0)
1490
+ closest = p1_np + t * diff
1491
+ error_vectors.append(v_center - closest)
1492
+
1493
+ return voxel_coords, endpoint_pairs, error_vectors
1494
+
1495
+
1496
+ # 使用示例
1497
+ def set_seed(seed: int):
1498
+ random.seed(seed)
1499
+ np.random.seed(seed)
1500
+ torch.manual_seed(seed)
1501
+ if torch.cuda.is_available():
1502
+ torch.cuda.manual_seed(seed)
1503
+ torch.cuda.manual_seed_all(seed)
1504
+ torch.backends.cudnn.deterministic = True
1505
+ torch.backends.cudnn.benchmark = False
1506
+
1507
+ def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
1508
+ set_seed(42)
1509
+ tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
1510
+ result_data = tester.evaluate(num_samples=NUM_SAMPLES, visualize=VISUALIZE, chamfer_threshold=CHAMFER_EDGE_THRESHOLD, threshold=THRESHOLD)
1511
+
1512
+ # 生成文件名
1513
+ epoch_str = os.path.basename(ckpt_path).split('_')[1].split('.')[0]
1514
+ dataset_name = os.path.basename(os.path.normpath(dataset_path))
1515
+
1516
+ # 保存简版报告(TXT)
1517
+ summary_path = os.path.join(eval_dir, f"epoch{epoch_str}_{dataset_name}_summary_threshold{THRESHOLD}_one2one.txt")
1518
+ with open(summary_path, 'w') as f:
1519
+ # 头部信息
1520
+ f.write(f"Checkpoint: {os.path.basename(ckpt_path)}\n")
1521
+ f.write(f"Dataset: {dataset_name}\n")
1522
+ f.write(f"Evaluation Samples: {result_data['num_samples']}\n\n")
1523
+
1524
+ # 平均指标
1525
+ f.write("=== Average Metrics ===\n")
1526
+ for category, data in result_data['avg_metrics'].items():
1527
+ if isinstance(data, dict): # 处理多分辨率情况
1528
+ f.write(f"\n{category.upper()}:\n")
1529
+ for res, metrics in data.items():
1530
+ f.write(f" Resolution {res}:\n")
1531
+ for k, v in metrics.items():
1532
+ # 确保值是数字类型后再格式化
1533
+ if isinstance(v, (int, float)):
1534
+ f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
1535
+ else:
1536
+ f.write(f" {str(k).ljust(15)}: {str(v)}\n")
1537
+ else: # 处理非多分辨率情况
1538
+ f.write(f"\n{category.upper()}:\n")
1539
+ for k, v in data.items():
1540
+ if isinstance(v, (int, float)):
1541
+ f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
1542
+ else:
1543
+ f.write(f" {str(k).ljust(15)}: {str(v)}\n")
1544
+
1545
+ # 样本级详细统计
1546
+ f.write("\n\n=== Detailed Per-Sample Metrics ===\n")
1547
+ for name, vertex_metrics, edge_metrics in zip(
1548
+ result_data['per_sample_metrics']['sample_names'],
1549
+ zip(*[result_data['per_sample_metrics']['vertex'][res] for res in [128, 256, 512]]),
1550
+ zip(*[result_data['per_sample_metrics']['edge'][res] for res in [128, 256, 512]])
1551
+ ):
1552
+ # 样本标题
1553
+ f.write(f"\n◆ Sample: {name}\n")
1554
+
1555
+ # 顶点指标
1556
+ f.write(f"[Vertex Prediction]\n")
1557
+ f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
1558
+ for res, metrics in zip([128, 256, 512], vertex_metrics):
1559
+ f.write(f" {str(res).ljust(10)} "
1560
+ f"{metrics['recall']:.4f} "
1561
+ f"{metrics['precision']:.4f} "
1562
+ f"{metrics['f1']:.4f} "
1563
+ f"{metrics['num_pred']}/{metrics['num_gt']}\n")
1564
+
1565
+ # Edge指标
1566
+ f.write(f"[Edge Prediction]\n")
1567
+ f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
1568
+ for res, metrics in zip([128, 256, 512], edge_metrics):
1569
+ f.write(f" {str(res).ljust(10)} "
1570
+ f"{metrics['recall']:.4f} "
1571
+ f"{metrics['precision']:.4f} "
1572
+ f"{metrics['f1']:.4f} "
1573
+ f"{metrics['num_pred']}/{metrics['num_gt']}\n")
1574
+
1575
+ f.write("-"*60 + "\n")
1576
+
1577
+ print(f"Saved summary to: {summary_path}")
1578
+ return result_data
1579
+
1580
+
1581
+ if __name__ == '__main__':
1582
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
1583
+ evaluate_all_checkpoints = True # 设置为 True 启用范围过滤
1584
+ EPOCH_START = 0
1585
+ EPOCH_END = 12
1586
+ CHAMFER_EDGE_THRESHOLD=0.5
1587
+ NUM_SAMPLES=20
1588
+ VISUALIZE=True
1589
+ THRESHOLD=1.5
1590
+ VISUAL_FIELD=False
1591
+
1592
+ ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch10433_loss1.2657.pt'
1593
+ ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch2000_loss0.3315.pt'
1594
+
1595
+ dataset_path = '/gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/test'
1596
+ dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized'
1597
+ # dataset_path = '/gemini/user/private/zhaotianhao/data/trellis500k_compress_glb'
1598
+ dataset_path = '/gemini/user/private/zhaotianhao/data/unique_files_glb_under6000face_2degree_30ratio_0.01'
1599
+
1600
+ if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
1601
+ RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
1602
+ else:
1603
+ RENDERS_DIR = ''
1604
+
1605
+
1606
+ ckpt_dir = os.path.dirname(ckpt_path)
1607
+ eval_dir = os.path.join(ckpt_dir, "evaluate")
1608
+ os.makedirs(eval_dir, exist_ok=True)
1609
+
1610
+ if False:
1611
+ for i in range(NUM_SAMPLES):
1612
+ print("--- Starting Latent Space PCA Visualization ---")
1613
+ tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
1614
+ tester.visualize_latent_space_pca(sample_idx=i)
1615
+ print("--- PCA Visualization Finished ---")
1616
+
1617
+ if not evaluate_all_checkpoints:
1618
+ evaluate_checkpoint(ckpt_path, dataset_path, eval_dir)
1619
+ else:
1620
+ pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.endswith('.pt')])
1621
+
1622
+ filtered_pt_files = []
1623
+ for f in pt_files:
1624
+ try:
1625
+ parts = f.split('_')
1626
+ epoch_str = parts[1].replace('epoch', '')
1627
+ epoch = int(epoch_str)
1628
+ if EPOCH_START <= epoch <= EPOCH_END:
1629
+ filtered_pt_files.append(f)
1630
+ except Exception as e:
1631
+ print(f"Warning: Could not parse epoch from {f}: {e}")
1632
+ continue
1633
+
1634
+ for pt_file in filtered_pt_files:
1635
+ full_ckpt_path = os.path.join(ckpt_dir, pt_file)
1636
+ evaluate_checkpoint(full_ckpt_path, dataset_path, eval_dir)
train_slat_flow_128to1024_pointnet.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # os.environ['ATTN_BACKEND'] = 'xformers' # xformers is generally compatible with DDP
3
+ import numpy as np
4
+ import torch
5
+ import yaml
6
+ from torch.utils.data import DataLoader, DistributedSampler
7
+ from functools import partial
8
+ import torch.nn.functional as F
9
+ from torch.optim import AdamW
10
+ from torch.amp import GradScaler, autocast
11
+ from typing import *
12
+ from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig
13
+ import torch.distributed as dist
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+
16
+ # --- Updated Imports based on VAE script ---
17
+ from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
18
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
19
+ from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
20
+ from triposf.modules.sparse.basic import SparseTensor
21
+
22
+ from trellis.models.structured_latent_flow import SLatFlowModel
23
+ from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
24
+ from safetensors.torch import load_file
25
+ import torch.multiprocessing as mp
26
+ import open3d as o3d
27
+ from PIL import Image
28
+ import torch.nn as nn
29
+
30
+ from triposf.modules.utils import DiagonalGaussianDistribution
31
+ import torchvision.transforms as transforms
32
+ import re
33
+
34
+ # --- Distributed Setup Functions ---
35
+ def setup_distributed(backend="nccl"):
36
+ """Initializes the distributed environment."""
37
+ if not dist.is_initialized():
38
+ rank = int(os.environ["RANK"])
39
+ world_size = int(os.environ["WORLD_SIZE"])
40
+ local_rank = int(os.environ["LOCAL_RANK"])
41
+
42
+ torch.cuda.set_device(local_rank)
43
+ dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
44
+
45
+ return int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
46
+
47
+ def cleanup_distributed():
48
+ dist.destroy_process_group()
49
+
50
+ # --- Modified Trainer Class ---
51
+ class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
52
+ def __init__(self, *args, rank: int, local_rank: int, world_size: int, **kwargs):
53
+ super().__init__(*args, **kwargs)
54
+ self.cfg = kwargs.pop('cfg', None)
55
+ if self.cfg is None:
56
+ raise ValueError("Configuration dictionary 'cfg' must be provided.")
57
+
58
+ # --- Distributed-related attributes ---
59
+ self.rank = rank
60
+ self.local_rank = local_rank
61
+ self.world_size = world_size
62
+ self.device = torch.device(f"cuda:{self.local_rank}")
63
+ self.is_master = (self.rank == 0)
64
+
65
+ self.i_save = self.cfg['training']['save_every']
66
+ self.save_dir = self.cfg['training']['output_dir']
67
+
68
+ self.resolution = 128
69
+ self.condition_type = 'image'
70
+ self.is_cond = False
71
+ self.img_res = 518
72
+
73
+ if self.is_master:
74
+ os.makedirs(self.save_dir, exist_ok=True)
75
+ print(f"Checkpoints and logs will be saved to: {self.save_dir}")
76
+
77
+ # Initialize components and set up for DDP
78
+ self._init_components(
79
+ clip_model_path=self.cfg['training'].get('clip_model_path', None),
80
+ dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None),
81
+ vae_path=self.cfg['training']['vae_path'],
82
+ )
83
+
84
+ self._setup_ddp()
85
+
86
+ self.denoiser_checkpoint_path = self.cfg['training'].get('denoiser_checkpoint_path', None)
87
+
88
+ trainable_params = list(self.denoiser.parameters())
89
+ self.optimizer = AdamW(trainable_params, lr=self.cfg['training'].get('lr', 0.0001), weight_decay=0.0)
90
+
91
+ self.scaler = GradScaler()
92
+
93
+ if self.is_master:
94
+ print("Using Automatic Mixed Precision (AMP) with GradScaler.")
95
+
96
+ def _init_components(self,
97
+ clip_model_path=None,
98
+ dinov2_model_path=None,
99
+ vae_path=None,
100
+ ):
101
+ """
102
+ Initializes VAE, VoxelEncoder (PointNet), and condition models.
103
+ """
104
+ def load_file_func(path, device='cpu'):
105
+ return torch.load(path, map_location=device)
106
+
107
+ def _load_and_broadcast(model, load_fn=None, path=None, strict=True):
108
+ if self.is_master:
109
+ try:
110
+ state = load_fn(path) if load_fn else model.state_dict()
111
+ except Exception as e:
112
+ raise RuntimeError(f"Failed to load weights from {path}: {e}")
113
+ else:
114
+ state = None
115
+
116
+ dist.barrier()
117
+ state_b = [state] if self.is_master else [None]
118
+ dist.broadcast_object_list(state_b, src=0)
119
+
120
+ try:
121
+ # Handle potential key mismatches (e.g. 'module.' prefix)
122
+ model.load_state_dict(state_b[0], strict=strict)
123
+ except Exception as e:
124
+ if self.is_master: print(f"Strict loading failed for {model.__class__.__name__}, trying non-strict: {e}")
125
+ model.load_state_dict(state_b[0], strict=False)
126
+
127
+ # ------------------------- Voxel Encoder (PointNet) -------------------------
128
+ # Matching the VAE script configuration
129
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
130
+ in_channels=15,
131
+ hidden_dim=256,
132
+ out_channels=1024,
133
+ scatter_type='mean',
134
+ n_blocks=5,
135
+ resolution=128,
136
+ add_label=False,
137
+ ).to(self.device)
138
+
139
+ # ------------------------- VAE -------------------------
140
+ self.vae = VoxelVAE(
141
+ in_channels=self.cfg['model']['in_channels'],
142
+ latent_dim=self.cfg['model']['latent_dim'],
143
+ encoder_blocks=self.cfg['model']['encoder_blocks'],
144
+ decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
145
+ decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
146
+ num_heads=8,
147
+ num_head_channels=64,
148
+ mlp_ratio=4.0,
149
+ attn_mode="swin",
150
+ window_size=8,
151
+ pe_mode="ape",
152
+ use_fp16=False,
153
+ use_checkpoint=True,
154
+ qk_rms_norm=False,
155
+ using_subdivide=True,
156
+ using_attn=self.cfg['model']['using_attn'],
157
+ attn_first=self.cfg['model'].get('attn_first', True),
158
+ pred_direction=self.cfg['model'].get('pred_direction', False),
159
+ ).to(self.device)
160
+
161
+
162
+ # ------------------------- Conditioning -------------------------
163
+ if self.condition_type == 'text':
164
+ self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
165
+ if self.is_master:
166
+ self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
167
+ else:
168
+ config = CLIPTextConfig.from_pretrained(clip_model_path)
169
+ self.condition_model = CLIPTextModel(config)
170
+ _load_and_broadcast(self.condition_model)
171
+
172
+ elif self.condition_type == 'image':
173
+ if self.is_master:
174
+ print("Initializing for IMAGE conditioning (DINOv2).")
175
+
176
+ # Update paths as per your environment
177
+ local_repo_path = "/root/Trisf/dinov2_resources/dinov2-main"
178
+ weights_path = "/root/Trisf/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"
179
+
180
+ dinov2_model = torch.hub.load(
181
+ repo_or_dir=local_repo_path,
182
+ model='dinov2_vitl14_reg',
183
+ source='local',
184
+ pretrained=False
185
+ )
186
+ self.condition_model = dinov2_model
187
+
188
+ _load_and_broadcast(self.condition_model, load_fn=torch.load, path=weights_path)
189
+
190
+ self.image_cond_model_transform = transforms.Compose([
191
+ transforms.ToTensor(),
192
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
193
+ ])
194
+ else:
195
+ raise ValueError(f"Unsupported condition type: {self.condition_type}")
196
+
197
+ self.condition_model.to(self.device).eval()
198
+ for p in self.condition_model.parameters(): p.requires_grad = False
199
+
200
+ # ------------------------- Load VAE/Encoder Weights -------------------------
201
+ # Load weights corresponding to the logic in VAE script's `load_pretrained_woself`
202
+ # Assuming checkpoint contains 'vae' and 'voxel_encoder' keys
203
+ _load_and_broadcast(self.vae,
204
+ load_fn=lambda p: load_file_func(p)['vae'],
205
+ path=vae_path)
206
+
207
+ _load_and_broadcast(self.voxel_encoder,
208
+ load_fn=lambda p: load_file_func(p)['voxel_encoder'],
209
+ path=vae_path)
210
+
211
+ self.vae.eval()
212
+ self.voxel_encoder.eval()
213
+ for p in self.vae.parameters(): p.requires_grad = False
214
+ for p in self.voxel_encoder.parameters(): p.requires_grad = False
215
+
216
+
217
+ def _load_denoiser(self):
218
+ """Loads a checkpoint for the denoiser."""
219
+ path = self.denoiser_checkpoint_path
220
+ if not path or not os.path.isfile(path):
221
+ if self.is_master:
222
+ print("No valid checkpoint path provided for denoiser. Starting from scratch.")
223
+ return
224
+
225
+ if self.is_master:
226
+ print(f"Loading denoiser checkpoint from: {path}")
227
+ checkpoint = torch.load(path, map_location=self.device)
228
+ else:
229
+ checkpoint = None
230
+
231
+ dist.barrier()
232
+ dist_list = [checkpoint] if self.is_master else [None]
233
+ dist.broadcast_object_list(dist_list, src=0)
234
+ checkpoint = dist_list[0]
235
+
236
+ try:
237
+ self.denoiser.module.load_state_dict(checkpoint['denoiser'])
238
+ if self.is_master: print("Denoiser weights loaded successfully.")
239
+ except Exception as e:
240
+ if self.is_master: print(f"[ERROR] Failed to load denoiser state_dict: {e}")
241
+
242
+ if 'step' in checkpoint and self.is_master:
243
+ print(f"Checkpoint from step {checkpoint['step']}.")
244
+
245
+ dist.barrier()
246
+
247
+ def _setup_ddp(self):
248
+ """Sets up DDP and DataLoaders."""
249
+ self.denoiser = self.denoiser.to(self.device)
250
+ self.denoiser = DDP(self.denoiser, device_ids=[self.local_rank])
251
+
252
+ for param in self.denoiser.parameters():
253
+ param.requires_grad = True
254
+
255
+ # Use the Dataset from the VAE script
256
+ self.dataset = VoxelVertexDataset_edge(
257
+ root_dir=self.cfg['dataset']['path'],
258
+ base_resolution=self.cfg['dataset']['base_resolution'],
259
+ min_resolution=self.cfg['dataset']['min_resolution'],
260
+ cache_dir=self.cfg['dataset']['cache_dir'],
261
+ renders_dir=self.cfg['dataset']['renders_dir'],
262
+
263
+ filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
264
+ cache_filter_path=self.cfg['dataset']['cache_filter_path'],
265
+
266
+ active_voxel_res=128,
267
+ pc_sample_number=819200,
268
+
269
+ sample_type=self.cfg['dataset']['sample_type'],
270
+ )
271
+
272
+ self.sampler = DistributedSampler(
273
+ self.dataset,
274
+ num_replicas=self.world_size,
275
+ rank=self.rank,
276
+ shuffle=True
277
+ )
278
+
279
+ # Use collate_fn_pointnet
280
+ self.dataloader = DataLoader(
281
+ self.dataset,
282
+ batch_size=self.cfg['training']['batch_size'],
283
+ shuffle=False,
284
+ collate_fn=partial(collate_fn_pointnet,),
285
+ num_workers=self.cfg['training']['num_workers'],
286
+ pin_memory=True,
287
+ sampler=self.sampler,
288
+ prefetch_factor=4,
289
+ persistent_workers=True,
290
+ drop_last=True,
291
+ )
292
+
293
+ @torch.no_grad()
294
+ def encode_image(self, images) -> torch.Tensor:
295
+ if isinstance(images, torch.Tensor):
296
+ batch_tensor = images.to(self.device)
297
+ elif isinstance(images, list):
298
+ assert all(isinstance(i, Image.Image) for i in images), "Image list should be list of PIL images"
299
+ image = [i.resize((518, 518), Image.LANCZOS) for i in images]
300
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
301
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
302
+ batch_tensor = torch.stack(image).to(self.device)
303
+ else:
304
+ raise ValueError(f"Unsupported type of image: {type(image)}")
305
+
306
+ if batch_tensor.shape[-2:] != (518, 518):
307
+ batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False)
308
+
309
+ features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
310
+ patchtokens = F.layer_norm(features, features.shape[-1:])
311
+ return patchtokens
312
+
313
+ def process_batch(self, batch):
314
+ preprocessed_images = batch['image']
315
+ cond_ = self.encode_image(preprocessed_images)
316
+ return cond_
317
+
318
+ def train(self, num_epochs=10000):
319
+ # Unconditional handling for text/image
320
+ if self.is_cond == False:
321
+ if self.condition_type == 'text':
322
+ txt = ['']
323
+ encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
324
+ tokens = encoding['input_ids'].to(self.device)
325
+ with torch.no_grad():
326
+ cond_ = self.condition_model(input_ids=tokens).last_hidden_state
327
+ else:
328
+ blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
329
+ with torch.no_grad():
330
+ dummy_cond = self.encode_image([blank_img])
331
+ cond_ = torch.zeros_like(dummy_cond)
332
+ if self.is_master: print(f"Generated unconditional image prompt with shape: {cond_.shape}")
333
+
334
+ self._load_denoiser()
335
+ self.denoiser.train()
336
+
337
+ # Step tracking
338
+ step = 0
339
+ if self.denoiser_checkpoint_path:
340
+ match = re.search(r'step(\d+)', self.denoiser_checkpoint_path)
341
+ if match:
342
+ step = int(match.group(1))
343
+ step = 0
344
+
345
+ for epoch in range(num_epochs):
346
+ self.sampler.set_epoch(epoch)
347
+ epoch_losses = []
348
+
349
+ for i, batch in enumerate(self.dataloader):
350
+ self.optimizer.zero_grad()
351
+
352
+ # --- Conditioning ---
353
+ if self.is_cond and self.condition_type == 'image':
354
+ cond_ = self.process_batch(batch)
355
+
356
+ # Retrieve Data from collate_fn_pointnet
357
+ point_cloud = batch['point_cloud_128'].to(self.device)
358
+ active_coords = batch['active_voxels_128'].to(self.device)
359
+
360
+ # Handle Batch Size for Conditioning
361
+ batch_size = int(active_coords[:, 0].max().item() + 1)
362
+ if cond_.shape[0] != batch_size:
363
+ cond_ = cond_.expand(batch_size, -1, -1).contiguous().to(self.device)
364
+ else:
365
+ cond_ = cond_.to(self.device)
366
+
367
+ with autocast(device_type='cuda', dtype=torch.bfloat16):
368
+ with torch.no_grad():
369
+ # 1. Encode Point Cloud to Features on Active Voxels
370
+ # The encoder processes point clouds and scatters features to `active_coords`
371
+ active_voxel_feats = self.voxel_encoder(
372
+ p=point_cloud,
373
+ sparse_coords=active_coords,
374
+ res=128,
375
+ bbox_size=(-0.5, 0.5),
376
+ )
377
+
378
+ # 2. Prepare Sparse Input for VAE
379
+ sparse_input = SparseTensor(
380
+ feats=active_voxel_feats,
381
+ coords=active_coords.int()
382
+ )
383
+
384
+ # 3. Get Latent Distribution from VAE
385
+ # We use the encode method of VoxelVAE to get posterior
386
+ latent_128, posterior = self.vae.encode(sparse_input)
387
+
388
+
389
+ # 5. Calculate Diffusion Loss
390
+ terms, _ = self.training_losses(x_0=latent_128, cond=cond_)
391
+ loss = terms['loss']
392
+
393
+ self.scaler.scale(loss).backward()
394
+ self.scaler.step(self.optimizer)
395
+ self.scaler.update()
396
+
397
+ with torch.no_grad():
398
+ avg_loss = loss.detach()
399
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
400
+
401
+ step += 1
402
+
403
+ # --- Logging and Saving ---
404
+ if self.is_master:
405
+ epoch_losses.append(avg_loss.item())
406
+ if step % 10 == 0:
407
+ print(f"Epoch {epoch+1} Step {step}: "
408
+ f"Rank0_Loss = {loss.item():.4f}, "
409
+ f"Global_Avg_Loss = {avg_loss.item():.4f}, "
410
+ f"Epoch_Mean = {np.mean(epoch_losses):.4f}")
411
+
412
+ if step % self.i_save == 0 or step == 1:
413
+ checkpoint = {
414
+ 'denoiser': self.denoiser.module.state_dict(),
415
+ 'step': step
416
+ }
417
+ loss_val_str = f"{loss.item():.6f}".replace('.', '_')
418
+ save_path = os.path.join(self.save_dir, f"checkpoint_step{step}_loss{loss_val_str}.pt")
419
+ torch.save(checkpoint, save_path)
420
+ print(f"Saved checkpoint to {save_path}")
421
+
422
+ if self.is_master:
423
+ avg_loss = np.mean(epoch_losses) if epoch_losses else 0
424
+ log_path = os.path.join(self.save_dir, "loss_log.txt")
425
+ with open(log_path, "a") as f:
426
+ f.write(f"Epoch {epoch+1}, Step {step}, AvgLoss {avg_loss:.6f}\n")
427
+
428
+ dist.barrier()
429
+ # torch.cuda.empty_cache()
430
+ # gc.collect()
431
+
432
+ if self.is_master:
433
+ print("Training complete.")
434
+
435
+ def main():
436
+ if mp.get_start_method(allow_none=True) != 'spawn':
437
+ mp.set_start_method('spawn', force=True)
438
+
439
+ rank, local_rank, world_size = setup_distributed()
440
+ torch.manual_seed(42)
441
+ np.random.seed(42)
442
+
443
+ # Path to your config
444
+ config_path = "/root/Trisf/config_slat_flow_128to1024_pointnet_head.yaml"
445
+ with open(config_path) as f:
446
+ cfg = yaml.safe_load(f)
447
+
448
+ # Initialize Flow Model (on CPU first)
449
+ diffusion_model = SLatFlowModel(
450
+ resolution=cfg['flow']['resolution'],
451
+ in_channels=cfg['flow']['in_channels'],
452
+ out_channels=cfg['flow']['out_channels'],
453
+ model_channels=cfg['flow']['model_channels'],
454
+ cond_channels=cfg['flow']['cond_channels'],
455
+ num_blocks=cfg['flow']['num_blocks'],
456
+ num_heads=cfg['flow']['num_heads'],
457
+ mlp_ratio=cfg['flow']['mlp_ratio'],
458
+ patch_size=cfg['flow']['patch_size'],
459
+ num_io_res_blocks=cfg['flow']['num_io_res_blocks'],
460
+ io_block_channels=cfg['flow']['io_block_channels'],
461
+ pe_mode=cfg['flow']['pe_mode'],
462
+ qk_rms_norm=cfg['flow']['qk_rms_norm'],
463
+ qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'],
464
+ use_fp16=cfg['flow'].get('use_fp16', False),
465
+ )
466
+
467
+ torch.manual_seed(42 + rank)
468
+ np.random.seed(42 + rank)
469
+
470
+ trainer = SLatFlowMatchingTrainer(
471
+ denoiser=diffusion_model,
472
+ t_schedule=cfg['t_schedule'],
473
+ sigma_min=cfg['sigma_min'],
474
+ cfg=cfg,
475
+ rank=rank,
476
+ local_rank=local_rank,
477
+ world_size=world_size,
478
+ )
479
+
480
+ trainer.train()
481
+ cleanup_distributed()
482
+
483
+ if __name__ == '__main__':
484
+ main()
train_slat_vae_512_128to1024_pointnet.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import os
3
+ # os.environ['ATTN_BACKEND'] = 'xformers'
4
+ import yaml
5
+ import torch
6
+ import time
7
+ from datetime import datetime
8
+ from torch.utils.data import DataLoader
9
+ from functools import partial
10
+ from triposf.modules.sparse.basic import SparseTensor
11
+ import torch.nn.functional as F
12
+ from torch.optim import AdamW
13
+ from torch.cuda.amp import GradScaler, autocast
14
+
15
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder import VoxelVAE
16
+
17
+ from vertex_encoder import VoxelFeatureEncoder_active_pointnet
18
+
19
+
20
+ from dataset_triposf import VoxelVertexDataset_edge, collate_fn_pointnet
21
+
22
+ from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss
23
+
24
+ import torch.distributed as dist
25
+ from torch.nn.parallel import DistributedDataParallel as DDP
26
+ from torch.utils.data.distributed import DistributedSampler
27
+
28
+ from transformers import get_cosine_schedule_with_warmup
29
+ import math
30
+
31
+ def flatten_coords_4d(coords_4d: torch.Tensor):
32
+ coords_4d_long = coords_4d.long()
33
+
34
+ base_x = 1024
35
+ base_y = 1024 * 1024
36
+ base_z = 1024 * 1024 * 1024
37
+
38
+ flat_coords = coords_4d_long[:, 0] * base_z + \
39
+ coords_4d_long[:, 1] * base_y + \
40
+ coords_4d_long[:, 2] * base_x + \
41
+ coords_4d_long[:, 3]
42
+ return flat_coords
43
+
44
+ def downsample_voxels(
45
+ voxels: torch.Tensor,
46
+ input_resolution: int,
47
+ output_resolution: int
48
+ ) -> torch.Tensor:
49
+ if input_resolution % output_resolution != 0:
50
+ raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
51
+ f"by output_resolution ({output_resolution}).")
52
+
53
+ factor = input_resolution // output_resolution
54
+
55
+ downsampled_voxels = voxels.clone().to(torch.long)
56
+
57
+ downsampled_voxels[:, 1:] //= factor
58
+
59
+ unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
60
+ return unique_downsampled_voxels
61
+
62
+ class Trainer:
63
+ def __init__(self, config_path, rank, world_size, local_rank):
64
+ self.rank = rank
65
+ self.world_size = world_size
66
+ self.local_rank = local_rank
67
+ self.is_master = self.rank == 0
68
+
69
+ self.load_config(config_path)
70
+ self.accum_steps = max(1, 8 // self.cfg['training']['batch_size'])
71
+
72
+ self.config_hash = self.save_config_with_hash()
73
+ self.init_device()
74
+ self.init_dirs()
75
+ self.init_components()
76
+ self.init_training()
77
+
78
+ self.train_loss_history = []
79
+ self.eval_loss_history = []
80
+ self.best_eval_loss = float('inf')
81
+
82
+ def save_config_with_hash(self):
83
+ import hashlib
84
+
85
+ # Serialize config to hash
86
+ config_str = yaml.dump(self.cfg)
87
+ config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
88
+
89
+ # Prepare all flags as string for formatting
90
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
91
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
92
+
93
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
94
+
95
+ # Format save_dir with all placeholders
96
+ self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
97
+ dataset_name=dataset_name,
98
+ config_hash=config_hash,
99
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
100
+ multires=self.cfg['model']['multires'],
101
+ add_block_embed=add_block_embed_flag,
102
+ using_attn=using_attn_flag,
103
+ batch_size=self.cfg['training']['batch_size'],
104
+ )
105
+
106
+ if self.is_master:
107
+ os.makedirs(self.save_dir, exist_ok=True)
108
+ config_path = os.path.join(self.save_dir, "config.yaml")
109
+ with open(config_path, 'w') as f:
110
+ yaml.dump(self.cfg, f)
111
+
112
+ dist.barrier()
113
+ return config_hash
114
+
115
+ def save_checkpoint(self, epoch, avg_loss, batch_idx):
116
+ if not self.is_master:
117
+ return
118
+ checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
119
+ config_path = os.path.join(self.save_dir, "config.yaml")
120
+
121
+ torch.save({
122
+ 'voxel_encoder': self.voxel_encoder.module.state_dict(),
123
+ 'vae': self.vae.module.state_dict(),
124
+ 'epoch': epoch,
125
+ 'loss': avg_loss,
126
+ 'config': self.cfg
127
+ }, checkpoint_path)
128
+
129
+ def quoted_presenter(dumper, data):
130
+ return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
131
+
132
+ yaml.add_representer(str, quoted_presenter)
133
+
134
+ with open(config_path, 'w') as f:
135
+ yaml.dump(self.cfg, f)
136
+
137
+ def load_config(self, config_path):
138
+ with open(config_path) as f:
139
+ self.cfg = yaml.safe_load(f)
140
+
141
+ # Extract and convert flags for formatting
142
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
143
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
144
+
145
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
146
+
147
+ self.save_dir = self.cfg['experiment']['save_dir'].format(
148
+ dataset_name=dataset_name,
149
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
150
+ multires=self.cfg['model']['multires'],
151
+ add_block_embed=add_block_embed_flag,
152
+ using_attn=using_attn_flag,
153
+ batch_size=self.cfg['training']['batch_size'],
154
+ )
155
+
156
+ if self.is_master:
157
+ os.makedirs(self.save_dir, exist_ok=True)
158
+ dist.barrier()
159
+
160
+
161
+ def init_device(self):
162
+ self.device = torch.device(f"cuda:{self.local_rank}")
163
+
164
+ def init_dirs(self):
165
+ self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
166
+ if self.is_master:
167
+ with open(self.log_file, "a") as f:
168
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
169
+ f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
170
+
171
+ def init_components(self):
172
+ # self.dataset = VoxelVertexDataset_edge_shapenet(
173
+ # root_dir=self.cfg['dataset']['path'],
174
+ # file_list_path=self.cfg['dataset']['file_list_path'],
175
+ # map_file_path=self.cfg['dataset']['map_file_path'],
176
+
177
+ # base_resolution=self.cfg['dataset']['base_resolution'],
178
+ # min_resolution=self.cfg['dataset']['min_resolution'],
179
+ # cache_dir=self.cfg['dataset']['cache_dir'],
180
+ # renders_dir=self.cfg['dataset']['renders_dir'],
181
+
182
+ # filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
183
+ # cache_filter_path=self.cfg['dataset']['cache_filter_path'],
184
+ # )
185
+
186
+ self.dataset = VoxelVertexDataset_edge(
187
+ root_dir=self.cfg['dataset']['path'],
188
+ base_resolution=self.cfg['dataset']['base_resolution'],
189
+ min_resolution=self.cfg['dataset']['min_resolution'],
190
+ cache_dir=self.cfg['dataset']['cache_dir'],
191
+ renders_dir=self.cfg['dataset']['renders_dir'],
192
+
193
+ filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
194
+ cache_filter_path=self.cfg['dataset']['cache_filter_path'],
195
+
196
+ active_voxel_res=128,
197
+ pc_sample_number=819200,
198
+
199
+ sample_type=self.cfg['dataset']['sample_type'],
200
+ )
201
+
202
+ self.sampler = DistributedSampler(
203
+ self.dataset,
204
+ num_replicas=self.world_size,
205
+ rank=self.rank,
206
+ shuffle=True,
207
+ )
208
+
209
+ self.dataloader = DataLoader(
210
+ self.dataset,
211
+ batch_size=self.cfg['training']['batch_size'],
212
+ shuffle=False,
213
+ collate_fn=partial(collate_fn_pointnet,),
214
+ num_workers=self.cfg['training']['num_workers'],
215
+ pin_memory=True,
216
+ sampler=self.sampler,
217
+ # prefetch_factor=4,
218
+ persistent_workers=True,
219
+ )
220
+
221
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
222
+ in_channels=15,
223
+ hidden_dim=256,
224
+ out_channels=1024,
225
+ scatter_type='mean',
226
+ n_blocks=5,
227
+ resolution=128,
228
+ add_label=False,
229
+ ).to(self.device)
230
+
231
+ # ablation 3: voxelvae_1volume, have tested
232
+ self.vae = VoxelVAE(
233
+ in_channels=self.cfg['model']['in_channels'],
234
+ latent_dim=self.cfg['model']['latent_dim'],
235
+ encoder_blocks=self.cfg['model']['encoder_blocks'],
236
+ decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
237
+ decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
238
+ num_heads=8,
239
+ num_head_channels=64,
240
+ mlp_ratio=4.0,
241
+ attn_mode="swin",
242
+ window_size=8,
243
+ pe_mode="ape",
244
+ use_fp16=False,
245
+ use_checkpoint=True,
246
+ qk_rms_norm=False,
247
+ using_subdivide=True,
248
+ using_attn=self.cfg['model']['using_attn'],
249
+ attn_first=self.cfg['model'].get('attn_first', True),
250
+ pred_direction=self.cfg['model'].get('pred_direction', False),
251
+ ).to(self.device)
252
+
253
+ if self.cfg['training']['from_pretrained']:
254
+ load_pretrained_woself(
255
+ checkpoint_path=self.cfg['training']['checkpoint_path'],
256
+ voxel_encoder=self.voxel_encoder,
257
+ vae=self.vae,
258
+ optimizer=None,
259
+ )
260
+
261
+ self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
262
+ self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=False)
263
+
264
+ def init_training(self):
265
+ self.optimizer = AdamW(
266
+ list(self.vae.module.parameters()) +
267
+ list(self.voxel_encoder.module.parameters()),
268
+ lr=self.cfg['training']['lr'],
269
+ weight_decay=0.01,
270
+ )
271
+
272
+ num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
273
+ # print('num_update_steps_per_epoch', num_update_steps_per_epoch) # 1305
274
+ max_epochs = self.cfg['training']['max_epochs']
275
+ num_training_steps = max_epochs * num_update_steps_per_epoch
276
+
277
+ num_warmup_steps = 1000
278
+
279
+ # self.scheduler = torch.optim.lr_scheduler.LambdaLR(
280
+ # self.optimizer,
281
+ # lr_lambda=lambda epoch: self.cfg['training']['gamma'] ** (epoch // self.cfg['training']['step_size'])
282
+ # )
283
+
284
+ self.scheduler = get_cosine_schedule_with_warmup(
285
+ self.optimizer,
286
+ num_warmup_steps=num_warmup_steps,
287
+ num_training_steps=num_training_steps
288
+ )
289
+
290
+ self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=10.0).to(self.device)
291
+ self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
292
+ self.asyloss = AsymmetricFocalLoss(
293
+ gamma_pos=0.0,
294
+ gamma_neg=4.0,
295
+ clip=0.05,
296
+ )
297
+
298
+ self.bce_loss = torch.nn.BCEWithLogitsLoss()
299
+
300
+ self.dice_loss = DiceLoss()
301
+ self.scaler = GradScaler()
302
+
303
+ def train_step(self, batch):
304
+ """Modified training step that handles vertex and edge voxels separately after initial prediction."""
305
+ # 1. Retrieve data from batch
306
+ combined_voxels_1024 = batch['combined_voxels_1024'].to(self.device)
307
+ combined_voxel_labels_1024 = batch['combined_voxel_labels_1024'].to(self.device)
308
+ gt_vertex_voxels_1024 = batch['gt_vertex_voxels_1024'].to(self.device)
309
+
310
+ gt_edge_voxels_1024_ = batch['gt_edge_voxels_1024'].to(self.device)
311
+ gt_combined_endpoints_1024 = batch['gt_combined_endpoints_1024'].to(self.device)
312
+ gt_combined_errors_1024 = batch['gt_combined_errors_1024'].to(self.device)
313
+
314
+ edge_mask = (combined_voxel_labels_1024 == 1)
315
+
316
+ gt_edge_endpoints_1024 = gt_combined_endpoints_1024[edge_mask]
317
+ gt_edge_errors_1024 = gt_combined_errors_1024[edge_mask]
318
+ gt_edge_voxels_1024 = combined_voxels_1024[edge_mask].to(self.device)
319
+
320
+ # print('gt_edge_voxels_1024_-gt_edge_voxels_1024.sum()', (gt_edge_voxels_1024_-gt_edge_voxels_1024).sum())
321
+
322
+
323
+ p1 = gt_edge_endpoints_1024[:, 1:4].float()
324
+ p2 = gt_edge_endpoints_1024[:, 4:7].float()
325
+
326
+ mask = ( (p1[:,0] < p2[:,0]) |
327
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
328
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
329
+
330
+ pA = torch.where(mask[:, None], p1, p2) # smaller one
331
+ pB = torch.where(mask[:, None], p2, p1) # larger one
332
+
333
+ d = pB - pA
334
+ dir_gt = F.normalize(d, dim=-1, eps=1e-6)
335
+ vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
336
+ vtx_256 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=256)
337
+ vtx_512 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=512)
338
+ vtx_1024 = gt_vertex_voxels_1024
339
+
340
+ edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
341
+ edge_256 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=256)
342
+ edge_512 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=512)
343
+ edge_1024 = combined_voxels_1024
344
+
345
+ active_coords = batch['active_voxels_128'].to(self.device)
346
+ point_cloud = batch['point_cloud_128'].to(self.device)
347
+
348
+ active_voxel_feats = self.voxel_encoder(
349
+ p=point_cloud,
350
+ sparse_coords=active_coords,
351
+ res=128,
352
+ bbox_size=(-0.5, 0.5),
353
+
354
+ # voxel_label=active_labels,
355
+ )
356
+
357
+ sparse_input = SparseTensor(
358
+ feats=active_voxel_feats,
359
+ coords=active_coords.int()
360
+ )
361
+
362
+ gt_edge_voxels_list = [
363
+ edge_128,
364
+ edge_256,
365
+ edge_512,
366
+ edge_1024,
367
+ ]
368
+
369
+ gt_vertex_voxels_list = [
370
+ vtx_128,
371
+ vtx_256,
372
+ vtx_512,
373
+ vtx_1024,
374
+ ]
375
+
376
+
377
+ results, posterior, latent_128 = self.vae(
378
+ sparse_input,
379
+ gt_vertex_voxels_list=gt_vertex_voxels_list,
380
+ gt_edge_voxels_list=gt_edge_voxels_list,
381
+ training=True,
382
+ sample_ratio=0.,
383
+ )
384
+
385
+ # print("results[-1]['edge']['coords_4d'][1827:1830]", results[-1]['edge']['coords_4d'][1827:1830])
386
+ total_loss = 0.
387
+ prune_loss_total = 0.
388
+ vertex_loss_total = 0.
389
+ edge_loss_total=0.
390
+
391
+ with autocast(dtype=torch.bfloat16):
392
+ initial_result = results[0]
393
+ vertex_mask = initial_result['vertex_mask']
394
+ vtx_logits = initial_result['vtx_feats']
395
+ vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
396
+
397
+ edge_mask = initial_result['edge_mask']
398
+ edge_logits = initial_result['edge_feats']
399
+ edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
400
+
401
+ vertex_loss_total += vertex_loss
402
+ edge_loss_total += edge_loss
403
+
404
+ total_loss += vertex_loss
405
+ total_loss += edge_loss
406
+
407
+ # Process each level's results
408
+ for idx, res_dict in enumerate(results[1:], start=1):
409
+ # Vertex branch losses
410
+ vertex_pred_coords = res_dict['vertex']['occ_coords']
411
+ vertex_occ_probs = res_dict['vertex']['occ_probs']
412
+ vertex_gt_coords = res_dict['vertex']['coords']
413
+
414
+ vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=1024).float()
415
+ # print('vertex_labels.sum()', vertex_labels.sum(), idx)
416
+ vertex_logits = vertex_occ_probs.squeeze()
417
+
418
+ # if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
419
+ vertex_prune_loss = self.focal_loss(vertex_logits, vertex_labels)
420
+ # vertex_prune_loss = self.dice_loss(vertex_logits, vertex_labels)
421
+
422
+ # dilation 1: bce loss
423
+ # vertex_prune_loss = self.bce_loss(vertex_logits, vertex_labels,)
424
+
425
+ prune_loss_total += vertex_prune_loss
426
+ total_loss += vertex_prune_loss
427
+
428
+
429
+ # Edge branch losses
430
+ edge_pred_coords = res_dict['edge']['occ_coords']
431
+ edge_occ_probs = res_dict['edge']['occ_probs']
432
+ edge_gt_coords = res_dict['edge']['coords']
433
+
434
+ edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=1024).float()
435
+ # print('edge_labels.sum()', edge_labels.sum(), idx)
436
+ edge_logits = edge_occ_probs.squeeze()
437
+ # if edge_labels.sum() > 0 and edge_labels.sum() < len(edge_labels):
438
+ edge_prune_loss = self.focal_loss(edge_logits, edge_labels)
439
+
440
+ # dilation 1: bce loss
441
+ # edge_prune_loss = self.bce_loss(edge_logits, edge_labels,)
442
+
443
+ prune_loss_total += edge_prune_loss
444
+ total_loss += edge_prune_loss
445
+
446
+ if idx == 3:
447
+ pred_coords = res_dict['edge']['coords_4d'] # [N,4] (b,x,y,z)
448
+ pred_feats = res_dict['edge']['predicted_offset_feats'] # [N,C]
449
+
450
+ gt_coords = gt_edge_voxels_1024.to(pred_coords.device) # [M,4]
451
+ gt_feats = gt_edge_errors_1024[:, 1:].to(pred_coords.device) # [M,C]
452
+
453
+ pred_keys = flatten_coords_4d(pred_coords)
454
+ gt_keys = flatten_coords_4d(gt_coords)
455
+
456
+ sorted_pred_keys, pred_order = torch.sort(pred_keys)
457
+ pred_coords_sorted = pred_coords[pred_order]
458
+ pred_feats_sorted = pred_feats[pred_order]
459
+
460
+
461
+ sorted_gt_keys, gt_order = torch.sort(gt_keys)
462
+ gt_coords_sorted = gt_coords[gt_order]
463
+ gt_feats_sorted = gt_feats[gt_order]
464
+
465
+
466
+ # pos = torch.searchsorted(sorted_gt_keys, sorted_pred_keys)
467
+ # valid_pos = pos < len(sorted_gt_keys)
468
+ # matched = torch.zeros_like(valid_pos, dtype=torch.bool)
469
+ # matched[valid_pos] = (sorted_gt_keys[pos[valid_pos]] == sorted_pred_keys[valid_pos])
470
+ # valid_mask = valid_pos & matched
471
+
472
+ pos = torch.searchsorted(sorted_gt_keys, sorted_pred_keys)
473
+ valid_mask = (pos < len(sorted_gt_keys)) & (sorted_gt_keys[pos] == sorted_pred_keys)
474
+
475
+ if valid_mask.any():
476
+ matched_pred_feats = pred_feats_sorted[valid_mask]
477
+ matched_gt_feats = gt_feats_sorted[pos[valid_mask]]
478
+
479
+
480
+ mse_loss_feats = self.mse_loss(matched_pred_feats, matched_gt_feats * 2)
481
+ total_loss += mse_loss_feats
482
+
483
+ if self.cfg['model'].get('pred_direction', False):
484
+ pred_dirs = res_dict['edge']['predicted_direction_feats']
485
+ dir_gt_device = dir_gt.to(pred_coords.device)
486
+
487
+ pred_dirs_sorted = pred_dirs[pred_order]
488
+ dir_gt_sorted = dir_gt_device[gt_order]
489
+
490
+ matched_pred_dirs = pred_dirs_sorted[valid_mask]
491
+ matched_gt_dirs = dir_gt_sorted[pos[valid_mask]]
492
+
493
+ mse_loss_dirs = self.mse_loss(matched_pred_dirs, matched_gt_dirs)
494
+ total_loss += mse_loss_dirs
495
+ else:
496
+ mse_loss_feats = torch.tensor(0., device=pred_coords.device)
497
+
498
+ if self.cfg['model'].get('pred_direction', False):
499
+ mse_loss_dirs = torch.tensor(0., device=pred_coords.device)
500
+
501
+
502
+ # KL loss
503
+ kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
504
+ total_loss += kl_loss
505
+
506
+ # Backpropagation
507
+ scaled_total_loss = total_loss / self.accum_steps
508
+ self.scaler.scale(scaled_total_loss).backward()
509
+
510
+ return {
511
+ 'total_loss': total_loss.item(),
512
+ 'kl_loss': kl_loss.item(),
513
+ 'prune_loss': prune_loss_total.item(),
514
+ 'vertex_loss': vertex_loss_total.item(),
515
+ 'edge_loss': edge_loss_total.item(),
516
+ 'offset_loss': mse_loss_feats.item(),
517
+ 'direction_loss': mse_loss_dirs.item(),
518
+ }
519
+
520
+
521
+ def train(self):
522
+ accum_steps = self.accum_steps
523
+ for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
524
+ self.dataloader.sampler.set_epoch(epoch)
525
+ # Initialize metrics
526
+ metrics = {
527
+ 'total_loss': 0.0,
528
+ 'kl_loss': 0.0,
529
+ 'prune_loss': 0.0,
530
+ 'vertex_loss': 0.0,
531
+ 'edge_loss': 0.0,
532
+ 'offset_loss': 0.0,
533
+ 'direction_loss': 0.0,
534
+ }
535
+ num_batches = 0
536
+ self.optimizer.zero_grad(set_to_none=True)
537
+
538
+ for i, batch in enumerate(self.dataloader):
539
+ # Get all losses from train_step
540
+ if batch is None:
541
+ continue
542
+ step_losses = self.train_step(batch)
543
+
544
+ # Accumulate losses
545
+ for key in metrics:
546
+ metrics[key] += step_losses[key]
547
+ num_batches += 1
548
+
549
+ if (i + 1) % accum_steps == 0:
550
+ self.scaler.unscale_(self.optimizer)
551
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
552
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
553
+
554
+ self.scaler.step(self.optimizer)
555
+ self.scaler.update()
556
+ self.optimizer.zero_grad(set_to_none=True)
557
+
558
+ self.scheduler.step()
559
+
560
+ # Print batch-level metrics
561
+ if self.is_master:
562
+ avg_metric = {key: value / num_batches for key, value in metrics.items()}
563
+ print(
564
+ f"[Epoch {epoch}] Batch:{num_batches} "
565
+ f"AvgL:{avg_metric['total_loss']:.4f} "
566
+ f"Loss: {step_losses['total_loss']:.4f}, "
567
+ f"KLL: {step_losses['kl_loss']:.4f}, "
568
+ f"PruneL: {step_losses['prune_loss']:.4f}, "
569
+ f"VertexL: {step_losses['vertex_loss']:.4f}, "
570
+ f"EdgeL: {step_losses['edge_loss']:.4f}, "
571
+ f"OffsetL: {step_losses['offset_loss']:.4f}, "
572
+ f"DireL: {step_losses['direction_loss']:.4f}, "
573
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
574
+ )
575
+
576
+ if i % 2000 == 0 and i != 0:
577
+ self.save_checkpoint(epoch, avg_metric['total_loss'], i)
578
+ with open(self.log_file, "a") as f:
579
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
580
+ log_line = (
581
+ f"Epoch {epoch:05d} | "
582
+ f"Batch {i:05d} | "
583
+ f"Loss: {avg_metric['total_loss']:.6f} "
584
+ f"Avg KLL: {avg_metric['kl_loss']:.4f} "
585
+ f"Avg PruneL: {avg_metric['prune_loss']:.4f} "
586
+ f"Avg VertexL: {avg_metric['vertex_loss']:.4f} "
587
+ f"Avg EdgeL: {avg_metric['edge_loss']:.4f} "
588
+ f"Avg OffsetL: {avg_metric['offset_loss']:.4f} "
589
+ f"Avg DireL: {avg_metric['direction_loss']:.4f} "
590
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
591
+ f"[{current_time}]\n"
592
+ )
593
+ f.write(log_line)
594
+
595
+ if num_batches % accum_steps != 0:
596
+ self.scaler.unscale_(self.optimizer)
597
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
598
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
599
+
600
+ self.scaler.step(self.optimizer)
601
+ self.scaler.update()
602
+ self.optimizer.zero_grad(set_to_none=True)
603
+
604
+ self.scheduler.step()
605
+
606
+ # Calculate epoch averages
607
+ avg_metrics = {key: value / num_batches for key, value in metrics.items()}
608
+ self.train_loss_history.append(avg_metrics['total_loss'])
609
+
610
+
611
+ # Log to file
612
+ if self.is_master:
613
+ with open(self.log_file, "a") as f:
614
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
615
+ log_line = (
616
+ f"Epoch {epoch:05d} | "
617
+ f"Loss: {avg_metrics['total_loss']:.6f} "
618
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
619
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
620
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
621
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
622
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
623
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
624
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
625
+ f"[{current_time}]\n"
626
+ )
627
+ f.write(log_line)
628
+
629
+ # Print epoch summary
630
+ print(
631
+ f"[Epoch {epoch}] "
632
+ f"Avg Loss: {avg_metrics['total_loss']:.4f} "
633
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
634
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
635
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
636
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
637
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
638
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
639
+ f"[{current_time}]\n"
640
+ )
641
+
642
+ # Save checkpoint
643
+ if epoch % self.cfg['training']['save_every'] == 0:
644
+ self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
645
+
646
+ # Update learning rate
647
+ if self.is_master:
648
+ current_lr = self.optimizer.param_groups[0]['lr']
649
+ print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
650
+
651
+ dist.barrier()
652
+
653
+
654
+ def main():
655
+ # Initialize the process group
656
+ dist.init_process_group(backend='nccl')
657
+
658
+ # Get rank and world size from environment variables set by the launcher
659
+ rank = int(os.environ['RANK'])
660
+ world_size = int(os.environ['WORLD_SIZE'])
661
+ local_rank = int(os.environ['LOCAL_RANK'])
662
+
663
+ # Set the device for the current process. This is crucial.
664
+ torch.cuda.set_device(local_rank)
665
+ torch.manual_seed(42+rank)
666
+
667
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
668
+ # Pass the distributed info to the Trainer
669
+ trainer = Trainer(
670
+ config_path="/gemini/user/private/zhaotianhao/Triposf/config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_addhead.yaml",
671
+ rank=rank,
672
+ world_size=world_size,
673
+ local_rank=local_rank
674
+ )
675
+ trainer.train()
676
+
677
+ # Clean up the process group
678
+ dist.destroy_process_group()
679
+
680
+
681
+ if __name__ == '__main__':
682
+ main()
train_slat_vae_512_128to1024_pointnet_addhead.py ADDED
@@ -0,0 +1,788 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import os
3
+ # os.environ['ATTN_BACKEND'] = 'xformers'
4
+ import yaml
5
+ import torch
6
+ import time
7
+ from datetime import datetime
8
+ from torch.utils.data import DataLoader
9
+ from functools import partial
10
+ from triposf.modules.sparse.basic import SparseTensor
11
+ import torch.nn.functional as F
12
+ from torch.optim import AdamW
13
+ from torch.cuda.amp import GradScaler, autocast
14
+
15
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_addhead import VoxelVAE
16
+ from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
17
+ from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
18
+
19
+ from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss, FocalLoss
20
+
21
+ import torch.distributed as dist
22
+ from torch.nn.parallel import DistributedDataParallel as DDP
23
+ from torch.utils.data.distributed import DistributedSampler
24
+
25
+ from transformers import get_cosine_schedule_with_warmup
26
+ import math
27
+
28
+ from torchvision.ops import sigmoid_focal_loss
29
+
30
+ import numpy as np
31
+ import open3d as o3d
32
+
33
+
34
+ def flatten_coords_4d(coords_4d: torch.Tensor):
35
+ coords_4d_long = coords_4d.long()
36
+
37
+ base_x = 1024
38
+ base_y = 1024 * 1024
39
+ base_z = 1024 * 1024 * 1024
40
+
41
+ flat_coords = coords_4d_long[:, 0] * base_z + \
42
+ coords_4d_long[:, 1] * base_y + \
43
+ coords_4d_long[:, 2] * base_x + \
44
+ coords_4d_long[:, 3]
45
+ return flat_coords
46
+
47
+ def downsample_voxels(
48
+ voxels: torch.Tensor,
49
+ input_resolution: int,
50
+ output_resolution: int
51
+ ) -> torch.Tensor:
52
+ if input_resolution % output_resolution != 0:
53
+ raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
54
+ f"by output_resolution ({output_resolution}).")
55
+
56
+ factor = input_resolution // output_resolution
57
+
58
+ downsampled_voxels = voxels.clone().to(torch.long)
59
+
60
+ downsampled_voxels[:, 1:] //= factor
61
+
62
+ unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
63
+ return unique_downsampled_voxels
64
+
65
+ class Trainer:
66
+ def __init__(self, config_path, rank, world_size, local_rank):
67
+ self.rank = rank
68
+ self.world_size = world_size
69
+ self.local_rank = local_rank
70
+ self.is_master = self.rank == 0
71
+
72
+ self.load_config(config_path)
73
+ self.accum_steps = max(1, 8 // self.cfg['training']['batch_size'])
74
+
75
+ self.config_hash = self.save_config_with_hash()
76
+ self.init_device()
77
+ self.init_dirs()
78
+ self.init_components()
79
+ self.init_training()
80
+
81
+ self.train_loss_history = []
82
+ self.eval_loss_history = []
83
+ self.best_eval_loss = float('inf')
84
+
85
+ def save_config_with_hash(self):
86
+ import hashlib
87
+
88
+ # Serialize config to hash
89
+ config_str = yaml.dump(self.cfg)
90
+ config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
91
+
92
+ # Prepare all flags as string for formatting
93
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
94
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
95
+
96
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
97
+
98
+ # Format save_dir with all placeholders
99
+ self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
100
+ dataset_name=dataset_name,
101
+ config_hash=config_hash,
102
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
103
+ multires=self.cfg['model']['multires'],
104
+ add_block_embed=add_block_embed_flag,
105
+ using_attn=using_attn_flag,
106
+ batch_size=self.cfg['training']['batch_size'],
107
+ )
108
+
109
+ if self.is_master:
110
+ os.makedirs(self.save_dir, exist_ok=True)
111
+ config_path = os.path.join(self.save_dir, "config.yaml")
112
+ with open(config_path, 'w') as f:
113
+ yaml.dump(self.cfg, f)
114
+
115
+ dist.barrier()
116
+ return config_hash
117
+
118
+ def save_checkpoint(self, epoch, avg_loss, batch_idx):
119
+ if not self.is_master:
120
+ return
121
+ checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
122
+ config_path = os.path.join(self.save_dir, "config.yaml")
123
+
124
+ torch.save({
125
+ 'voxel_encoder': self.voxel_encoder.module.state_dict(),
126
+ 'vae': self.vae.module.state_dict(),
127
+ 'connection_head': self.connection_head.module.state_dict(),
128
+ 'epoch': epoch,
129
+ 'loss': avg_loss,
130
+ 'config': self.cfg
131
+ }, checkpoint_path)
132
+
133
+ def quoted_presenter(dumper, data):
134
+ return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
135
+
136
+ yaml.add_representer(str, quoted_presenter)
137
+
138
+ with open(config_path, 'w') as f:
139
+ yaml.dump(self.cfg, f)
140
+
141
+ def load_config(self, config_path):
142
+ with open(config_path) as f:
143
+ self.cfg = yaml.safe_load(f)
144
+
145
+ # Extract and convert flags for formatting
146
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
147
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
148
+
149
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
150
+
151
+ self.save_dir = self.cfg['experiment']['save_dir'].format(
152
+ dataset_name=dataset_name,
153
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
154
+ multires=self.cfg['model']['multires'],
155
+ add_block_embed=add_block_embed_flag,
156
+ using_attn=using_attn_flag,
157
+ batch_size=self.cfg['training']['batch_size'],
158
+ )
159
+
160
+ if self.is_master:
161
+ os.makedirs(self.save_dir, exist_ok=True)
162
+ dist.barrier()
163
+
164
+
165
+ def init_device(self):
166
+ self.device = torch.device(f"cuda:{self.local_rank}")
167
+
168
+ def init_dirs(self):
169
+ self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
170
+ if self.is_master:
171
+ with open(self.log_file, "a") as f:
172
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
173
+ f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
174
+
175
+ def init_components(self):
176
+ self.dataset = VoxelVertexDataset_edge(
177
+ root_dir=self.cfg['dataset']['path'],
178
+ base_resolution=self.cfg['dataset']['base_resolution'],
179
+ min_resolution=self.cfg['dataset']['min_resolution'],
180
+ cache_dir=self.cfg['dataset']['cache_dir'],
181
+ renders_dir=self.cfg['dataset']['renders_dir'],
182
+
183
+ filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
184
+ cache_filter_path=self.cfg['dataset']['cache_filter_path'],
185
+
186
+ active_voxel_res=128,
187
+ pc_sample_number=819200,
188
+
189
+ sample_type=self.cfg['dataset']['sample_type'],
190
+ )
191
+
192
+ self.sampler = DistributedSampler(
193
+ self.dataset,
194
+ num_replicas=self.world_size,
195
+ rank=self.rank,
196
+ shuffle=True,
197
+ )
198
+
199
+ self.dataloader = DataLoader(
200
+ self.dataset,
201
+ batch_size=self.cfg['training']['batch_size'],
202
+ shuffle=False,
203
+ collate_fn=partial(collate_fn_pointnet,),
204
+ num_workers=self.cfg['training']['num_workers'],
205
+ pin_memory=True,
206
+ sampler=self.sampler,
207
+ prefetch_factor=4,
208
+ persistent_workers=True,
209
+ )
210
+
211
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
212
+ in_channels=15,
213
+ hidden_dim=256,
214
+ out_channels=1024,
215
+ scatter_type='mean',
216
+ n_blocks=5,
217
+ resolution=128,
218
+ add_label=False,
219
+ ).to(self.device)
220
+
221
+ self.connection_head = ConnectionHead(
222
+ channels=32 * 2,
223
+ out_channels=1,
224
+ mlp_ratio=16,
225
+ ).to(self.device)
226
+
227
+
228
+ # ablation 3: voxelvae_1volume, have tested
229
+ self.vae = VoxelVAE(
230
+ in_channels=self.cfg['model']['in_channels'],
231
+ latent_dim=self.cfg['model']['latent_dim'],
232
+ encoder_blocks=self.cfg['model']['encoder_blocks'],
233
+ decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
234
+ decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
235
+ num_heads=8,
236
+ num_head_channels=64,
237
+ mlp_ratio=4.0,
238
+ attn_mode="swin",
239
+ window_size=8,
240
+ pe_mode="ape",
241
+ use_fp16=False,
242
+ use_checkpoint=True,
243
+ qk_rms_norm=False,
244
+ using_subdivide=True,
245
+ using_attn=self.cfg['model']['using_attn'],
246
+ attn_first=self.cfg['model'].get('attn_first', True),
247
+ pred_direction=self.cfg['model'].get('pred_direction', False),
248
+ ).to(self.device)
249
+
250
+ if self.cfg['training']['from_pretrained']:
251
+ load_pretrained_woself(
252
+ checkpoint_path=self.cfg['training']['checkpoint_path'],
253
+ voxel_encoder=self.voxel_encoder,
254
+ vae=self.vae,
255
+ connection_head=self.connection_head,
256
+ optimizer=None,
257
+ )
258
+
259
+ self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
260
+ self.connection_head = DDP(self.connection_head, device_ids=[self.local_rank], find_unused_parameters=False)
261
+ self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=True)
262
+
263
+ def init_training(self):
264
+ self.optimizer = AdamW(
265
+ list(self.vae.module.parameters()) +
266
+ list(self.voxel_encoder.module.parameters()) +
267
+ list(self.connection_head.module.parameters()),
268
+ lr=self.cfg['training']['lr'],
269
+ weight_decay=0.01,
270
+ )
271
+
272
+ num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
273
+ max_epochs = self.cfg['training']['max_epochs']
274
+ num_training_steps = max_epochs * num_update_steps_per_epoch
275
+
276
+ num_warmup_steps = 400
277
+
278
+ self.scheduler = get_cosine_schedule_with_warmup(
279
+ self.optimizer,
280
+ num_warmup_steps=num_warmup_steps,
281
+ num_training_steps=num_training_steps
282
+ )
283
+
284
+ self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=10.0).to(self.device)
285
+ self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
286
+ self.asyloss = AsymmetricFocalLoss(
287
+ gamma_pos=0.0,
288
+ gamma_neg=4.0,
289
+ clip=0.05,
290
+ )
291
+
292
+ self.bce_loss = torch.nn.BCEWithLogitsLoss()
293
+
294
+ self.dice_loss = DiceLoss()
295
+ self.scaler = GradScaler()
296
+
297
+ def train_step(self, batch, b_idx):
298
+ """Modified training step that handles vertex and edge voxels separately after initial prediction."""
299
+ # 1. Retrieve data from batch
300
+ combined_voxels_1024 = batch['combined_voxels_1024'].to(self.device)
301
+ combined_voxel_labels_1024 = batch['combined_voxel_labels_1024'].to(self.device)
302
+ gt_vertex_voxels_1024 = batch['gt_vertex_voxels_1024'].to(self.device)
303
+
304
+ # gt_edge_voxels_1024 = batch['gt_edge_voxels_1024'].to(self.device)
305
+ # gt_combined_endpoints_1024 = batch['gt_combined_endpoints_1024'].to(self.device)
306
+ # gt_combined_errors_1024 = batch['gt_combined_errors_1024'].to(self.device)
307
+
308
+ # gt_edges = batch['gt_vertex_edge_indices_256'].to(self.device)
309
+ gt_edges = batch['gt_vertex_edge_indices_1024'].to(self.device)
310
+
311
+ edge_mask = (combined_voxel_labels_1024 == 1)
312
+
313
+ # gt_edge_endpoints_1024 = gt_combined_endpoints_1024[edge_mask]
314
+ # gt_edge_errors_1024 = gt_combined_errors_1024[edge_mask]
315
+
316
+ # p1 = gt_edge_endpoints_1024[:, 1:4].float()
317
+ # p2 = gt_edge_endpoints_1024[:, 4:7].float()
318
+
319
+ # mask = ( (p1[:,0] < p2[:,0]) |
320
+ # ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
321
+ # ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
322
+
323
+ # pA = torch.where(mask[:, None], p1, p2) # smaller one
324
+ # pB = torch.where(mask[:, None], p2, p1) # larger one
325
+
326
+ # d = pB - pA
327
+ # dir_gt = F.normalize(d, dim=-1, eps=1e-6)
328
+ vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
329
+ vtx_256 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=256)
330
+ vtx_512 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=512)
331
+ vtx_1024 = gt_vertex_voxels_1024
332
+
333
+ edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
334
+ edge_256 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=256)
335
+ edge_512 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=512)
336
+ edge_1024 = combined_voxels_1024
337
+
338
+ active_coords = batch['active_voxels_128'].to(self.device)
339
+ point_cloud = batch['point_cloud_128'].to(self.device)
340
+
341
+ with autocast(dtype=torch.bfloat16):
342
+ active_voxel_feats = self.voxel_encoder(
343
+ p=point_cloud,
344
+ sparse_coords=active_coords,
345
+ res=128,
346
+ bbox_size=(-0.5, 0.5),
347
+ )
348
+
349
+ sparse_input = SparseTensor(
350
+ feats=active_voxel_feats,
351
+ coords=active_coords.int()
352
+ )
353
+
354
+ gt_edge_voxels_list = [
355
+ edge_128,
356
+ edge_256,
357
+ edge_512,
358
+ edge_1024,
359
+ ]
360
+
361
+ gt_vertex_voxels_list = [
362
+ vtx_128,
363
+ vtx_256,
364
+ vtx_512,
365
+ vtx_1024,
366
+ ]
367
+
368
+ results, posterior, latent_128 = self.vae(
369
+ sparse_input,
370
+ gt_vertex_voxels_list=gt_vertex_voxels_list,
371
+ gt_edge_voxels_list=gt_edge_voxels_list,
372
+ training=True,
373
+ sample_ratio=0.,
374
+ )
375
+
376
+ total_loss = 0.
377
+ prune_loss_total = 0.
378
+ vertex_loss_total = 0.
379
+ edge_loss_total=0.
380
+
381
+ initial_result = results[0]
382
+ vertex_mask = initial_result['vertex_mask']
383
+ vtx_logits = initial_result['vtx_feats']
384
+ vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
385
+
386
+ edge_mask = initial_result['edge_mask']
387
+ edge_logits = initial_result['edge_feats']
388
+ edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
389
+ # edge_loss = self.bce_loss(edge_logits.squeeze(-1), edge_mask.float())
390
+
391
+ vertex_loss_total += vertex_loss
392
+ edge_loss_total += edge_loss
393
+
394
+ total_loss += vertex_loss
395
+ total_loss += edge_loss
396
+
397
+ # Process each level's results
398
+ for idx, res_dict in enumerate(results[1:], start=1):
399
+ # Vertex branch losses
400
+ vertex_pred_coords = res_dict['vertex']['occ_coords']
401
+ vertex_occ_probs = res_dict['vertex']['occ_probs']
402
+ vertex_gt_coords = res_dict['vertex']['coords']
403
+
404
+ vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=1024).float()
405
+ # print('vertex_labels.sum()', vertex_labels.sum(), idx)
406
+ vertex_logits = vertex_occ_probs.squeeze()
407
+
408
+ # if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
409
+ vertex_prune_loss = self.focal_loss(vertex_logits, vertex_labels)
410
+
411
+ prune_loss_total += vertex_prune_loss
412
+ total_loss += vertex_prune_loss
413
+
414
+
415
+ # Edge branch losses
416
+ edge_pred_coords = res_dict['edge']['occ_coords']
417
+ edge_occ_probs = res_dict['edge']['occ_probs']
418
+ edge_gt_coords = res_dict['edge']['coords']
419
+
420
+ edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=1024).float()
421
+ edge_logits = edge_occ_probs.squeeze()
422
+ edge_prune_loss = self.focal_loss(edge_logits, edge_labels)
423
+
424
+ prune_loss_total += edge_prune_loss
425
+ total_loss += edge_prune_loss
426
+
427
+ if idx == 3:
428
+ mse_loss_feats = torch.tensor(0., device=self.device)
429
+ mse_loss_dirs = torch.tensor(0., device=self.device)
430
+ # connection_loss = torch.tensor(0., device=self.device)
431
+
432
+ # --- Vertex Branch (Connection Loss 核心) ---
433
+ vtx_pred_coords = res_dict['vertex']['coords_4d'] # [N, 4]
434
+ vtx_pred_feats = res_dict['vertex']['feats'] # [N, C]
435
+
436
+ # 1.1 排序 (既用于匹配 GT,也用于快速寻找空间邻居)
437
+ vtx_pred_keys = flatten_coords_4d(vtx_pred_coords)
438
+ vtx_pred_keys_sorted, vtx_pred_order = torch.sort(vtx_pred_keys)
439
+
440
+ # 1.2 匹配 GT
441
+ vtx_gt_keys = flatten_coords_4d(gt_vertex_voxels_1024.to(self.device))
442
+ vtx_pos = torch.searchsorted(vtx_pred_keys_sorted, vtx_gt_keys)
443
+ vtx_pos = vtx_pos.clamp(max=len(vtx_pred_keys_sorted) - 1)
444
+ vtx_match_mask = (vtx_pred_keys_sorted[vtx_pos] == vtx_gt_keys)
445
+
446
+ gt_to_pred_mapping = torch.full((len(vtx_gt_keys),), -1, device=self.device, dtype=torch.long)
447
+ matched_pred_indices = vtx_pred_order[vtx_pos[vtx_match_mask]]
448
+ gt_to_pred_mapping[vtx_match_mask] = matched_pred_indices
449
+
450
+ # ====================================================
451
+ # 2. 构建核心数据:正样本 Hash 集合
452
+ # ====================================================
453
+ # 这里的 pos_u/pos_v 仅用于构建 "什么是真连接" 的查询表
454
+ u_gt, v_gt = gt_edges[:, 0], gt_edges[:, 1]
455
+ u_pred = gt_to_pred_mapping[u_gt]
456
+ v_pred = gt_to_pred_mapping[v_gt]
457
+
458
+ valid_edge_mask = (u_pred != -1) & (v_pred != -1)
459
+ real_pos_u = u_pred[valid_edge_mask]
460
+ real_pos_v = v_pred[valid_edge_mask]
461
+
462
+ num_real_pos = real_pos_u.shape[0]
463
+ num_total_nodes = vtx_pred_coords.shape[0]
464
+
465
+ if num_real_pos > 0:
466
+ # 2. 构建候选样本 (Candidate Generation)
467
+ # ====================================================
468
+ cand_u_list = []
469
+ cand_v_list = []
470
+
471
+ batch_ids = vtx_pred_coords[:, 0]
472
+ unique_batches = torch.unique(batch_ids)
473
+
474
+ RADIUS = 64
475
+ MAX_PTS_FOR_DIST = 12000
476
+ K_RANDOM = 32
477
+
478
+ for b_id in unique_batches:
479
+ mask_b = (batch_ids == b_id)
480
+ indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
481
+ coords_b = vtx_pred_coords[mask_b, 1:4].float() # (x,y,z)
482
+ num_b = coords_b.shape[0]
483
+
484
+ if num_b < 2: continue
485
+
486
+ # --- A. Radius Graph (Hard Negatives) ---
487
+ if num_b <= MAX_PTS_FOR_DIST:
488
+ # 计算距离矩阵 [M, M]
489
+ # 注意:autocast 下 float16 的 cdist 可能精度不够,建议转 float32
490
+ dist_mat = torch.cdist(coords_b.float(), coords_b.float())
491
+
492
+ # 找到距离小于 Radius 的点对 (排除自环)
493
+ adj_mat = (dist_mat < RADIUS) & (dist_mat > 1e-6)
494
+
495
+ # 提取索引 (local indices in batch)
496
+ src_local, dst_local = torch.nonzero(adj_mat, as_tuple=True)
497
+
498
+ # 映射回全局索引
499
+ cand_u_list.append(indices_b[src_local])
500
+ cand_v_list.append(indices_b[dst_local])
501
+ else:
502
+ print('num_b is big!')
503
+ pass
504
+
505
+ # --- B. Random Sampling (Easy Negatives) ---
506
+ # 随机生成 num_b * K 对
507
+ n_rand = num_b * K_RANDOM
508
+ rand_src_local = torch.randint(0, num_b, (n_rand,), device=self.device)
509
+ rand_dst_local = torch.randint(0, num_b, (n_rand,), device=self.device)
510
+
511
+ # 映射回全局索引
512
+ cand_u_list.append(indices_b[rand_src_local])
513
+ cand_v_list.append(indices_b[rand_dst_local])
514
+
515
+ # 合并所有来源 (GT + Radius + Random)
516
+ # 注意:我们把 real_pos 也加进来,确保正样本一定在列表里
517
+ all_u = torch.cat([real_pos_u] + cand_u_list)
518
+ all_v = torch.cat([real_pos_v] + cand_v_list)
519
+
520
+
521
+
522
+ # 3. 去重与 Labeling (Deduplication & Labeling)
523
+ # ====================================================
524
+ # 构造无向边 Hash: min * N + max
525
+ # 确保 MAX_NODES 足够大,比如 1000000 或 num_total_nodes
526
+ HASH_BASE = num_total_nodes + 100
527
+
528
+ p_min = torch.min(all_u, all_v)
529
+ p_max = torch.max(all_u, all_v)
530
+
531
+ # 过滤掉自环 (u==v)
532
+ valid_pair = (p_min != p_max)
533
+ p_min = p_min[valid_pair]
534
+ p_max = p_max[valid_pair]
535
+
536
+
537
+ all_hashes = p_min.long() * HASH_BASE + p_max.long()
538
+
539
+ # --- 核心:去重 ---
540
+ unique_hashes = torch.unique(all_hashes)
541
+
542
+ # 解码回 u, v
543
+ final_u = unique_hashes // HASH_BASE
544
+ final_v = unique_hashes % HASH_BASE
545
+
546
+ # --- Labeling ---
547
+ # 构建 GT 的 Hash 表用于查询
548
+ gt_min = torch.min(real_pos_u, real_pos_v)
549
+ gt_max = torch.max(real_pos_u, real_pos_v)
550
+ gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
551
+ gt_hashes = torch.unique(gt_hashes) # GT 也去重一下保险
552
+ gt_hashes_sorted, _ = torch.sort(gt_hashes)
553
+
554
+ # 查询 unique_hashes 是否在 gt_hashes 中
555
+ # 使用 searchsorted
556
+ idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
557
+ idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
558
+ is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
559
+
560
+ targets = is_connected.float().unsqueeze(-1) # [N_pairs, 1]
561
+
562
+ # ============================================================
563
+ # 4. 前向传播与 Loss
564
+ # ====================================================
565
+ feat_u = vtx_pred_feats[final_u]
566
+ feat_v = vtx_pred_feats[final_v]
567
+
568
+ # 对称特征融合
569
+ feat_uv = torch.cat([feat_u, feat_v], dim=-1)
570
+ feat_vu = torch.cat([feat_v, feat_u], dim=-1)
571
+
572
+ logits_uv = self.connection_head(feat_uv)
573
+ logits_vu = self.connection_head(feat_vu)
574
+ logits = (logits_uv + logits_vu) / 2.
575
+
576
+ connection_loss = self.asyloss(logits, targets)
577
+ total_loss += connection_loss
578
+ else:
579
+ connection_loss = torch.tensor(0., device=self.device)
580
+
581
+
582
+ # KL loss
583
+ kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
584
+ total_loss += kl_loss
585
+
586
+ # Backpropagation
587
+ scaled_total_loss = total_loss / self.accum_steps
588
+ # self.scaler.scale(scaled_total_loss).backward()
589
+ scaled_total_loss.backward()
590
+
591
+ return {
592
+ 'total_loss': total_loss.item(),
593
+ 'kl_loss': kl_loss.item(),
594
+ 'prune_loss': prune_loss_total.item(),
595
+ 'vertex_loss': vertex_loss_total.item(),
596
+ 'edge_loss': edge_loss_total.item(),
597
+ 'offset_loss': mse_loss_feats.item(),
598
+ 'direction_loss': mse_loss_dirs.item(),
599
+ 'connection_loss': connection_loss.item(),
600
+ }
601
+
602
+
603
+ def train(self):
604
+ accum_steps = self.accum_steps
605
+ for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
606
+ self.dataloader.sampler.set_epoch(epoch)
607
+ # Initialize metrics
608
+ metrics = {
609
+ 'total_loss': 0.0,
610
+ 'kl_loss': 0.0,
611
+ 'prune_loss': 0.0,
612
+ 'vertex_loss': 0.0,
613
+ 'edge_loss': 0.0,
614
+ 'offset_loss': 0.0,
615
+ 'direction_loss': 0.0,
616
+ 'connection_loss': 0.0,
617
+ }
618
+ num_batches = 0
619
+ self.optimizer.zero_grad(set_to_none=True)
620
+
621
+ for i, batch in enumerate(self.dataloader):
622
+ # Get all losses from train_step
623
+ if batch is None:
624
+ continue
625
+ step_losses = self.train_step(batch, i)
626
+
627
+ # Accumulate losses
628
+ for key in metrics:
629
+ metrics[key] += step_losses[key]
630
+
631
+ num_batches += 1
632
+
633
+ if (i + 1) % accum_steps == 0:
634
+ # self.scaler.unscale_(self.optimizer)
635
+ # torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
636
+ # torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
637
+ # torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
638
+
639
+ # self.scaler.step(self.optimizer)
640
+ # self.scaler.update()
641
+ # self.optimizer.zero_grad(set_to_none=True)
642
+
643
+ # self.scheduler.step()
644
+
645
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
646
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
647
+ torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
648
+ self.optimizer.step()
649
+ self.optimizer.zero_grad(set_to_none=True)
650
+ self.scheduler.step()
651
+
652
+
653
+ # Print batch-level metrics
654
+ if self.is_master:
655
+ avg_metric = {key: value / num_batches for key, value in metrics.items()}
656
+ print(
657
+ f"[Epoch {epoch}] Batch:{num_batches} "
658
+ f"AvgL:{avg_metric['total_loss']:.4f} "
659
+ f"Loss: {step_losses['total_loss']:.4f}, "
660
+ f"KLL: {step_losses['kl_loss']:.4f}, "
661
+ f"PruneL: {step_losses['prune_loss']:.4f}, "
662
+ f"VertexL: {step_losses['vertex_loss']:.4f}, "
663
+ f"EdgeL: {step_losses['edge_loss']:.4f}, "
664
+ f"OffsetL: {step_losses['offset_loss']:.4f}, "
665
+ f"DireL: {step_losses['direction_loss']:.4f}, "
666
+ f"ConL: {step_losses['connection_loss']:.4f}, "
667
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
668
+ )
669
+
670
+ if i % 2000 == 0 and i != 0:
671
+ self.save_checkpoint(epoch, avg_metric['total_loss'], i)
672
+ with open(self.log_file, "a") as f:
673
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
674
+ log_line = (
675
+ f"Epoch {epoch:05d} | "
676
+ f"Batch {i:05d} | "
677
+ f"Loss: {avg_metric['total_loss']:.6f} "
678
+ f"Avg KLL: {avg_metric['kl_loss']:.4f} "
679
+ f"Avg PruneL: {avg_metric['prune_loss']:.4f} "
680
+ f"Avg VertexL: {avg_metric['vertex_loss']:.4f} "
681
+ f"Avg EdgeL: {avg_metric['edge_loss']:.4f} "
682
+ f"Avg OffsetL: {avg_metric['offset_loss']:.4f} "
683
+ f"Avg DireL: {avg_metric['direction_loss']:.4f} "
684
+ f"Avg ConL: {avg_metric['connection_loss']:.4f} "
685
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
686
+ f"[{current_time}]\n"
687
+ )
688
+ f.write(log_line)
689
+
690
+ if num_batches % accum_steps != 0:
691
+ # self.scaler.unscale_(self.optimizer)
692
+ # torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
693
+ # torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
694
+ # torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
695
+
696
+ # self.scaler.step(self.optimizer)
697
+ # self.scaler.update()
698
+ # self.optimizer.zero_grad(set_to_none=True)
699
+
700
+ # self.scheduler.step()
701
+
702
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
703
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
704
+ torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
705
+
706
+ self.optimizer.step()
707
+ self.optimizer.zero_grad(set_to_none=True)
708
+ self.scheduler.step()
709
+
710
+ # Calculate epoch averages
711
+ avg_metrics = {key: value / num_batches for key, value in metrics.items()}
712
+ self.train_loss_history.append(avg_metrics['total_loss'])
713
+
714
+
715
+ # Log to file
716
+ if self.is_master:
717
+ with open(self.log_file, "a") as f:
718
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
719
+ log_line = (
720
+ f"Epoch {epoch:05d} | "
721
+ f"Loss: {avg_metrics['total_loss']:.6f} "
722
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
723
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
724
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
725
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
726
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
727
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
728
+ f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
729
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
730
+ f"[{current_time}]\n"
731
+ )
732
+ f.write(log_line)
733
+
734
+ # Print epoch summary
735
+ print(
736
+ f"[Epoch {epoch}] "
737
+ f"Avg Loss: {avg_metrics['total_loss']:.4f} "
738
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
739
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
740
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
741
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
742
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
743
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
744
+ f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
745
+ f"[{current_time}]\n"
746
+ )
747
+
748
+ # Save checkpoint
749
+ if epoch % self.cfg['training']['save_every'] == 0:
750
+ self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
751
+
752
+ # Update learning rate
753
+ if self.is_master:
754
+ current_lr = self.optimizer.param_groups[0]['lr']
755
+ print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
756
+
757
+ dist.barrier()
758
+
759
+
760
+ def main():
761
+ # Initialize the process group
762
+ dist.init_process_group(backend='nccl')
763
+
764
+ # Get rank and world size from environment variables set by the launcher
765
+ rank = int(os.environ['RANK'])
766
+ world_size = int(os.environ['WORLD_SIZE'])
767
+ local_rank = int(os.environ['LOCAL_RANK'])
768
+
769
+ # Set the device for the current process. This is crucial.
770
+ torch.cuda.set_device(local_rank)
771
+ torch.manual_seed(42+rank)
772
+
773
+ # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
774
+ # Pass the distributed info to the Trainer
775
+ trainer = Trainer(
776
+ config_path="/gemini/user/private/zhaotianhao/Triposf/config_edge_1024_error_8enc_8dec_woself_finetune_128to1024_addhead.yaml",
777
+ rank=rank,
778
+ world_size=world_size,
779
+ local_rank=local_rank
780
+ )
781
+ trainer.train()
782
+
783
+ # Clean up the process group
784
+ dist.destroy_process_group()
785
+
786
+
787
+ if __name__ == '__main__':
788
+ main()
train_slat_vae_512_128to1024_pointnet_head.py ADDED
@@ -0,0 +1,930 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import os
3
+ # os.environ['ATTN_BACKEND'] = 'xformers'
4
+ import yaml
5
+ import torch
6
+ import time
7
+ from datetime import datetime
8
+ from torch.utils.data import DataLoader
9
+ from functools import partial
10
+ from triposf.modules.sparse.basic import SparseTensor
11
+ import torch.nn.functional as F
12
+ from torch.optim import AdamW
13
+ from torch.cuda.amp import GradScaler, autocast
14
+
15
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
16
+ from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
17
+ from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
18
+
19
+ from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss, FocalLoss
20
+
21
+ import torch.distributed as dist
22
+ from torch.nn.parallel import DistributedDataParallel as DDP
23
+ from torch.utils.data.distributed import DistributedSampler
24
+
25
+ from transformers import get_cosine_schedule_with_warmup
26
+ import math
27
+
28
+ from torchvision.ops import sigmoid_focal_loss
29
+
30
+ import numpy as np
31
+ import open3d as o3d
32
+
33
+ def export_sampled_edges(coords, u, v, labels, step_idx, save_dir="debug_viz", batch_idx_to_viz=0):
34
+ """
35
+ 导出采样边为 PLY 文件。
36
+
37
+ Args:
38
+ coords: [N, 4] Tensor (batch_idx, x, y, z)
39
+ u: [E] Tensor, 起点索引 (global index)
40
+ v: [E] Tensor, 终点索引 (global index)
41
+ labels: [E, 1] Tensor, 1.0 为正样本, 0.0 为负样本
42
+ step_idx: 当前步数或 epoch,用于文件名
43
+ save_dir: 保存目录
44
+ batch_idx_to_viz: 只可视化哪个 batch 的数据 (防止多个 batch 叠加在一起看不清)
45
+ """
46
+ os.makedirs(save_dir, exist_ok=True)
47
+
48
+ # 1. 转为 CPU numpy
49
+ coords_np = coords.detach().cpu().numpy()
50
+ u_np = u.detach().cpu().numpy()
51
+ v_np = v.detach().cpu().numpy()
52
+ labels_np = labels.detach().cpu().numpy().reshape(-1)
53
+
54
+ # 2. 筛选特定 Batch (通常只看 Batch 0 比较清晰)
55
+ # coords 的第0列是 batch index
56
+ batch_mask = (coords_np[:, 0] == batch_idx_to_viz)
57
+
58
+ # 获取属于该 batch 的全局索引范围
59
+ # 注意:u 和 v 是针对所有 coords 的全局索引。
60
+ # 我们需要判断一条边的两个端点是否都在这个 batch 内。
61
+
62
+ # 快速检查端点是否在当前 batch
63
+ valid_u_in_batch = batch_mask[u_np]
64
+ valid_v_in_batch = batch_mask[v_np]
65
+ edge_batch_mask = valid_u_in_batch & valid_v_in_batch
66
+
67
+ if edge_batch_mask.sum() == 0:
68
+ print(f"Warning: No edges found for batch {batch_idx_to_viz}")
69
+ return
70
+
71
+ # 应用 Batch 筛选
72
+ u_b = u_np[edge_batch_mask]
73
+ v_b = v_np[edge_batch_mask]
74
+ labels_b = labels_np[edge_batch_mask]
75
+
76
+ # 3. 提取该 Batch 的顶点坐标 (去掉 batch_idx 维度)
77
+ # 此时我们需要重新映射 u, v 的索引,因为我们要只保存该 batch 的点
78
+ batch_indices_global = np.where(batch_mask)[0]
79
+ # 创建全局索引到局部索引的映射表
80
+ global_to_local = {gid: lid for lid, gid in enumerate(batch_indices_global)}
81
+
82
+ points_xyz = coords_np[batch_indices_global, 1:4] # [M, 3]
83
+
84
+ # 转换 u, v 为局部索引
85
+ try:
86
+ u_local = np.array([global_to_local[idx] for idx in u_b])
87
+ v_local = np.array([global_to_local[idx] for idx in v_b])
88
+ except KeyError:
89
+ print("Error in index mapping. Edge endpoints might cross batches.")
90
+ return
91
+
92
+ # 4. 分离正负样本
93
+ pos_mask = labels_b > 0.5
94
+ neg_mask = ~pos_mask
95
+
96
+ # 内部函数:写 PLY
97
+ def write_ply(filename, points, edges_u, edges_v, color_rgb):
98
+ num_verts = len(points)
99
+ num_edges = len(edges_u)
100
+
101
+ with open(filename, 'w') as f:
102
+ f.write("ply\n")
103
+ f.write("format ascii 1.0\n")
104
+ f.write(f"element vertex {num_verts}\n")
105
+ f.write("property float x\n")
106
+ f.write("property float y\n")
107
+ f.write("property float z\n")
108
+ f.write("property uchar red\n")
109
+ f.write("property uchar green\n")
110
+ f.write("property uchar blue\n")
111
+ f.write(f"element edge {num_edges}\n")
112
+ f.write("property int vertex1\n")
113
+ f.write("property int vertex2\n")
114
+ f.write("end_header\n")
115
+
116
+ # Write Vertices with Color
117
+ # 为了让可视化更清楚,我们将所有点染成指定颜色
118
+ for i in range(num_verts):
119
+ x, y, z = points[i]
120
+ f.write(f"{x:.4f} {y:.4f} {z:.4f} {color_rgb[0]} {color_rgb[1]} {color_rgb[2]}\n")
121
+
122
+ # Write Edges
123
+ for i in range(num_edges):
124
+ f.write(f"{edges_u[i]} {edges_v[i]}\n")
125
+
126
+ print(f"Saved: {filename} (Edges: {num_edges})")
127
+
128
+ # 5. 保存正样本 (绿色)
129
+ if pos_mask.sum() > 0:
130
+ write_ply(
131
+ os.path.join(save_dir, f"step_{step_idx}_pos_edges.ply"),
132
+ points_xyz, # 使用所有点,或者优化为只使用涉及的点(这里为了坐标统一简单起见使用所有点)
133
+ u_local[pos_mask],
134
+ v_local[pos_mask],
135
+ color_rgb=(0, 255, 0) # Green
136
+ )
137
+
138
+ # 6. 保存负样本 (红色)
139
+ if neg_mask.sum() > 0:
140
+ # 为了避免文件太大或太乱,如果负样本特别多,可以考虑随机采样一部分保存
141
+ # 这里默认全部保存
142
+ write_ply(
143
+ os.path.join(save_dir, f"step_{step_idx}_neg_edges.ply"),
144
+ points_xyz,
145
+ u_local[neg_mask],
146
+ v_local[neg_mask],
147
+ color_rgb=(255, 0, 0) # Red
148
+ )
149
+
150
+
151
+ def flatten_coords_4d(coords_4d: torch.Tensor):
152
+ coords_4d_long = coords_4d.long()
153
+
154
+ base_x = 1024
155
+ base_y = 1024 * 1024
156
+ base_z = 1024 * 1024 * 1024
157
+
158
+ flat_coords = coords_4d_long[:, 0] * base_z + \
159
+ coords_4d_long[:, 1] * base_y + \
160
+ coords_4d_long[:, 2] * base_x + \
161
+ coords_4d_long[:, 3]
162
+ return flat_coords
163
+
164
+ def downsample_voxels(
165
+ voxels: torch.Tensor,
166
+ input_resolution: int,
167
+ output_resolution: int
168
+ ) -> torch.Tensor:
169
+ if input_resolution % output_resolution != 0:
170
+ raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
171
+ f"by output_resolution ({output_resolution}).")
172
+
173
+ factor = input_resolution // output_resolution
174
+
175
+ downsampled_voxels = voxels.clone().to(torch.long)
176
+
177
+ downsampled_voxels[:, 1:] //= factor
178
+
179
+ unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
180
+ return unique_downsampled_voxels
181
+
182
+ class Trainer:
183
+ def __init__(self, config_path, rank, world_size, local_rank):
184
+ self.rank = rank
185
+ self.world_size = world_size
186
+ self.local_rank = local_rank
187
+ self.is_master = self.rank == 0
188
+
189
+ self.load_config(config_path)
190
+ self.accum_steps = max(1, 4 // self.cfg['training']['batch_size'])
191
+
192
+ self.config_hash = self.save_config_with_hash()
193
+ self.init_device()
194
+ self.init_dirs()
195
+ self.init_components()
196
+ self.init_training()
197
+
198
+ self.train_loss_history = []
199
+ self.eval_loss_history = []
200
+ self.best_eval_loss = float('inf')
201
+
202
+ def save_config_with_hash(self):
203
+ import hashlib
204
+
205
+ # Serialize config to hash
206
+ config_str = yaml.dump(self.cfg)
207
+ config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
208
+
209
+ # Prepare all flags as string for formatting
210
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
211
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
212
+
213
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
214
+
215
+ # Format save_dir with all placeholders
216
+ self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
217
+ dataset_name=dataset_name,
218
+ config_hash=config_hash,
219
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
220
+ multires=self.cfg['model']['multires'],
221
+ add_block_embed=add_block_embed_flag,
222
+ using_attn=using_attn_flag,
223
+ batch_size=self.cfg['training']['batch_size'],
224
+ )
225
+
226
+ if self.is_master:
227
+ os.makedirs(self.save_dir, exist_ok=True)
228
+ config_path = os.path.join(self.save_dir, "config.yaml")
229
+ with open(config_path, 'w') as f:
230
+ yaml.dump(self.cfg, f)
231
+
232
+ dist.barrier()
233
+ return config_hash
234
+
235
+ def save_checkpoint(self, epoch, avg_loss, batch_idx):
236
+ if not self.is_master:
237
+ return
238
+ checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
239
+ config_path = os.path.join(self.save_dir, "config.yaml")
240
+
241
+ torch.save({
242
+ 'voxel_encoder': self.voxel_encoder.module.state_dict(),
243
+ 'vae': self.vae.module.state_dict(),
244
+ 'connection_head': self.connection_head.module.state_dict(),
245
+ 'epoch': epoch,
246
+ 'loss': avg_loss,
247
+ 'config': self.cfg
248
+ }, checkpoint_path)
249
+
250
+ def quoted_presenter(dumper, data):
251
+ return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
252
+
253
+ yaml.add_representer(str, quoted_presenter)
254
+
255
+ with open(config_path, 'w') as f:
256
+ yaml.dump(self.cfg, f)
257
+
258
+ def load_config(self, config_path):
259
+ with open(config_path) as f:
260
+ self.cfg = yaml.safe_load(f)
261
+
262
+ # Extract and convert flags for formatting
263
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
264
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
265
+
266
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
267
+
268
+ self.save_dir = self.cfg['experiment']['save_dir'].format(
269
+ dataset_name=dataset_name,
270
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
271
+ multires=self.cfg['model']['multires'],
272
+ add_block_embed=add_block_embed_flag,
273
+ using_attn=using_attn_flag,
274
+ batch_size=self.cfg['training']['batch_size'],
275
+ )
276
+
277
+ if self.is_master:
278
+ os.makedirs(self.save_dir, exist_ok=True)
279
+ dist.barrier()
280
+
281
+
282
+ def init_device(self):
283
+ self.device = torch.device(f"cuda:{self.local_rank}")
284
+
285
+ def init_dirs(self):
286
+ self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
287
+ if self.is_master:
288
+ with open(self.log_file, "a") as f:
289
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
290
+ f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
291
+
292
+ def init_components(self):
293
+ self.dataset = VoxelVertexDataset_edge(
294
+ root_dir=self.cfg['dataset']['path'],
295
+ base_resolution=self.cfg['dataset']['base_resolution'],
296
+ min_resolution=self.cfg['dataset']['min_resolution'],
297
+ cache_dir=self.cfg['dataset']['cache_dir'],
298
+ renders_dir=self.cfg['dataset']['renders_dir'],
299
+
300
+ filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
301
+ cache_filter_path=self.cfg['dataset']['cache_filter_path'],
302
+
303
+ active_voxel_res=128,
304
+ pc_sample_number=819200,
305
+
306
+ sample_type=self.cfg['dataset']['sample_type'],
307
+ )
308
+
309
+ self.sampler = DistributedSampler(
310
+ self.dataset,
311
+ num_replicas=self.world_size,
312
+ rank=self.rank,
313
+ shuffle=True,
314
+ )
315
+
316
+ self.dataloader = DataLoader(
317
+ self.dataset,
318
+ batch_size=self.cfg['training']['batch_size'],
319
+ shuffle=False,
320
+ collate_fn=partial(collate_fn_pointnet,),
321
+ num_workers=self.cfg['training']['num_workers'],
322
+ pin_memory=True,
323
+ sampler=self.sampler,
324
+ prefetch_factor=4,
325
+ persistent_workers=True,
326
+ )
327
+
328
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
329
+ in_channels=15,
330
+ hidden_dim=256,
331
+ out_channels=1024,
332
+ scatter_type='mean',
333
+ n_blocks=5,
334
+ resolution=128,
335
+ add_label=False,
336
+ ).to(self.device)
337
+
338
+ self.connection_head = ConnectionHead(
339
+ channels=32 * 2,
340
+ out_channels=1,
341
+ mlp_ratio=16,
342
+ ).to(self.device)
343
+
344
+ # self.connection_head = ConnectionHead(
345
+ # channels=64 * 2,
346
+ # out_channels=1,
347
+ # mlp_ratio=8,
348
+ # ).to(self.device)
349
+
350
+ # ablation 3: voxelvae_1volume, have tested
351
+ self.vae = VoxelVAE(
352
+ in_channels=self.cfg['model']['in_channels'],
353
+ latent_dim=self.cfg['model']['latent_dim'],
354
+ encoder_blocks=self.cfg['model']['encoder_blocks'],
355
+ decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
356
+ decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
357
+ num_heads=8,
358
+ num_head_channels=64,
359
+ mlp_ratio=4.0,
360
+ attn_mode="swin",
361
+ window_size=8,
362
+ pe_mode="ape",
363
+ use_fp16=False,
364
+ use_checkpoint=True,
365
+ qk_rms_norm=False,
366
+ using_subdivide=True,
367
+ using_attn=self.cfg['model']['using_attn'],
368
+ attn_first=self.cfg['model'].get('attn_first', True),
369
+ pred_direction=self.cfg['model'].get('pred_direction', False),
370
+ ).to(self.device)
371
+
372
+ if self.cfg['training']['from_pretrained']:
373
+ load_pretrained_woself(
374
+ checkpoint_path=self.cfg['training']['checkpoint_path'],
375
+ voxel_encoder=self.voxel_encoder,
376
+ vae=self.vae,
377
+ connection_head=self.connection_head,
378
+ optimizer=None,
379
+ )
380
+
381
+ self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
382
+ self.connection_head = DDP(self.connection_head, device_ids=[self.local_rank], find_unused_parameters=False)
383
+ self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=False)
384
+
385
+ def init_training(self):
386
+ self.optimizer = AdamW(
387
+ list(self.vae.module.parameters()) +
388
+ list(self.voxel_encoder.module.parameters()) +
389
+ list(self.connection_head.module.parameters()),
390
+ lr=self.cfg['training']['lr'],
391
+ weight_decay=0.01,
392
+ )
393
+
394
+ num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
395
+ max_epochs = self.cfg['training']['max_epochs']
396
+ num_training_steps = max_epochs * num_update_steps_per_epoch
397
+
398
+ num_warmup_steps = 200
399
+
400
+ self.scheduler = get_cosine_schedule_with_warmup(
401
+ self.optimizer,
402
+ num_warmup_steps=num_warmup_steps,
403
+ num_training_steps=num_training_steps
404
+ )
405
+
406
+ self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=10.0).to(self.device)
407
+ # self.focal_loss = FocalLoss(gamma=2, alpha=0.6)
408
+ self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
409
+ self.asyloss = AsymmetricFocalLoss(
410
+ gamma_pos=0.0,
411
+ gamma_neg=4.0,
412
+ clip=0.05,
413
+ )
414
+
415
+ self.bce_loss = torch.nn.BCEWithLogitsLoss()
416
+
417
+ self.dice_loss = DiceLoss()
418
+ self.scaler = GradScaler()
419
+
420
+ def train_step(self, batch, b_idx):
421
+ """Modified training step that handles vertex and edge voxels separately after initial prediction."""
422
+ # 1. Retrieve data from batch
423
+ combined_voxels_1024 = batch['combined_voxels_1024'].to(self.device)
424
+ combined_voxel_labels_1024 = batch['combined_voxel_labels_1024'].to(self.device)
425
+ gt_vertex_voxels_1024 = batch['gt_vertex_voxels_1024'].to(self.device)
426
+
427
+ # gt_edge_voxels_1024 = batch['gt_edge_voxels_1024'].to(self.device)
428
+ # gt_combined_endpoints_1024 = batch['gt_combined_endpoints_1024'].to(self.device)
429
+ # gt_combined_errors_1024 = batch['gt_combined_errors_1024'].to(self.device)
430
+
431
+ # gt_edges = batch['gt_vertex_edge_indices_256'].to(self.device)
432
+ gt_edges = batch['gt_vertex_edge_indices_1024'].to(self.device)
433
+
434
+ edge_mask = (combined_voxel_labels_1024 == 1)
435
+
436
+ # gt_edge_endpoints_1024 = gt_combined_endpoints_1024[edge_mask]
437
+ # gt_edge_errors_1024 = gt_combined_errors_1024[edge_mask]
438
+
439
+ # p1 = gt_edge_endpoints_1024[:, 1:4].float()
440
+ # p2 = gt_edge_endpoints_1024[:, 4:7].float()
441
+
442
+ # mask = ( (p1[:,0] < p2[:,0]) |
443
+ # ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
444
+ # ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
445
+
446
+ # pA = torch.where(mask[:, None], p1, p2) # smaller one
447
+ # pB = torch.where(mask[:, None], p2, p1) # larger one
448
+
449
+ # d = pB - pA
450
+ # dir_gt = F.normalize(d, dim=-1, eps=1e-6)
451
+ vtx_128 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=128)
452
+ vtx_256 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=256)
453
+ vtx_512 = downsample_voxels(gt_vertex_voxels_1024, input_resolution=1024, output_resolution=512)
454
+ vtx_1024 = gt_vertex_voxels_1024
455
+
456
+ edge_128 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=128)
457
+ edge_256 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=256)
458
+ edge_512 = downsample_voxels(combined_voxels_1024, input_resolution=1024, output_resolution=512)
459
+ edge_1024 = combined_voxels_1024
460
+
461
+ active_coords = batch['active_voxels_128'].to(self.device)
462
+ point_cloud = batch['point_cloud_128'].to(self.device)
463
+
464
+ with autocast(dtype=torch.bfloat16):
465
+ active_voxel_feats = self.voxel_encoder(
466
+ p=point_cloud,
467
+ sparse_coords=active_coords,
468
+ res=128,
469
+ bbox_size=(-0.5, 0.5),
470
+ )
471
+
472
+ sparse_input = SparseTensor(
473
+ feats=active_voxel_feats,
474
+ coords=active_coords.int()
475
+ )
476
+
477
+ gt_edge_voxels_list = [
478
+ edge_128,
479
+ edge_256,
480
+ edge_512,
481
+ edge_1024,
482
+ ]
483
+
484
+ gt_vertex_voxels_list = [
485
+ vtx_128,
486
+ vtx_256,
487
+ vtx_512,
488
+ vtx_1024,
489
+ ]
490
+
491
+ results, posterior, latent_128 = self.vae(
492
+ sparse_input,
493
+ gt_vertex_voxels_list=gt_vertex_voxels_list,
494
+ gt_edge_voxels_list=gt_edge_voxels_list,
495
+ training=True,
496
+ sample_ratio=0.,
497
+ )
498
+
499
+ # print("results[-1]['edge']['coords_4d'][1827:1830]", results[-1]['edge']['coords_4d'][1827:1830])
500
+ total_loss = 0.
501
+ prune_loss_total = 0.
502
+ vertex_loss_total = 0.
503
+ edge_loss_total=0.
504
+
505
+ initial_result = results[0]
506
+ vertex_mask = initial_result['vertex_mask']
507
+ vtx_logits = initial_result['vtx_feats']
508
+ vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
509
+
510
+ edge_mask = initial_result['edge_mask']
511
+ edge_logits = initial_result['edge_feats']
512
+ edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
513
+ # edge_loss = self.bce_loss(edge_logits.squeeze(-1), edge_mask.float())
514
+
515
+ vertex_loss_total += vertex_loss
516
+ edge_loss_total += edge_loss
517
+
518
+ total_loss += vertex_loss
519
+ total_loss += edge_loss
520
+
521
+ # Process each level's results
522
+ for idx, res_dict in enumerate(results[1:], start=1):
523
+ # Vertex branch losses
524
+ vertex_pred_coords = res_dict['vertex']['occ_coords']
525
+ vertex_occ_probs = res_dict['vertex']['occ_probs']
526
+ vertex_gt_coords = res_dict['vertex']['coords']
527
+
528
+ vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=1024).float()
529
+ # print('vertex_labels.sum()', vertex_labels.sum(), idx)
530
+ vertex_logits = vertex_occ_probs.squeeze()
531
+
532
+ # if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
533
+ vertex_prune_loss = self.asyloss(vertex_logits, vertex_labels)
534
+
535
+ prune_loss_total += vertex_prune_loss
536
+ total_loss += vertex_prune_loss
537
+
538
+
539
+ # Edge branch losses
540
+ edge_pred_coords = res_dict['edge']['occ_coords']
541
+ edge_occ_probs = res_dict['edge']['occ_probs']
542
+ edge_gt_coords = res_dict['edge']['coords']
543
+
544
+ edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=1024).float()
545
+ edge_logits = edge_occ_probs.squeeze()
546
+ edge_prune_loss = self.asyloss(edge_logits, edge_labels)
547
+
548
+ prune_loss_total += edge_prune_loss
549
+ total_loss += edge_prune_loss
550
+
551
+ if idx == 3:
552
+ mse_loss_feats = torch.tensor(0., device=self.device)
553
+ mse_loss_dirs = torch.tensor(0., device=self.device)
554
+ # connection_loss = torch.tensor(0., device=self.device)
555
+
556
+ # --- Vertex Branch (Connection Loss 核心) ---
557
+ vtx_pred_coords = res_dict['vertex']['coords_4d'] # [N, 4]
558
+ vtx_pred_feats = res_dict['vertex']['feats'] # [N, C]
559
+
560
+ # 1.1 排序 (既用于匹配 GT,也用于快速寻找空间邻居)
561
+ vtx_pred_keys = flatten_coords_4d(vtx_pred_coords)
562
+ vtx_pred_keys_sorted, vtx_pred_order = torch.sort(vtx_pred_keys)
563
+
564
+ # 1.2 匹配 GT
565
+ vtx_gt_keys = flatten_coords_4d(gt_vertex_voxels_1024.to(self.device))
566
+ vtx_pos = torch.searchsorted(vtx_pred_keys_sorted, vtx_gt_keys)
567
+ vtx_pos = vtx_pos.clamp(max=len(vtx_pred_keys_sorted) - 1)
568
+ vtx_match_mask = (vtx_pred_keys_sorted[vtx_pos] == vtx_gt_keys)
569
+
570
+ gt_to_pred_mapping = torch.full((len(vtx_gt_keys),), -1, device=self.device, dtype=torch.long)
571
+ matched_pred_indices = vtx_pred_order[vtx_pos[vtx_match_mask]]
572
+ gt_to_pred_mapping[vtx_match_mask] = matched_pred_indices
573
+
574
+ # ====================================================
575
+ # 2. 构建核心数据:正样本 Hash 集合
576
+ # ====================================================
577
+ # 这里的 pos_u/pos_v 仅用于构建 "什么是真连接" 的查询表
578
+ u_gt, v_gt = gt_edges[:, 0], gt_edges[:, 1]
579
+ u_pred = gt_to_pred_mapping[u_gt]
580
+ v_pred = gt_to_pred_mapping[v_gt]
581
+
582
+ valid_edge_mask = (u_pred != -1) & (v_pred != -1)
583
+ real_pos_u = u_pred[valid_edge_mask]
584
+ real_pos_v = v_pred[valid_edge_mask]
585
+
586
+ num_real_pos = real_pos_u.shape[0]
587
+ num_total_nodes = vtx_pred_coords.shape[0]
588
+
589
+ if num_real_pos > 0:
590
+ # 2. 构建候选样本 (Candidate Generation)
591
+ # ====================================================
592
+ cand_u_list = []
593
+ cand_v_list = []
594
+
595
+ batch_ids = vtx_pred_coords[:, 0]
596
+ unique_batches = torch.unique(batch_ids)
597
+
598
+ RADIUS = 64
599
+ MAX_PTS_FOR_DIST = 12000
600
+ K_RANDOM = 32
601
+
602
+ for b_id in unique_batches:
603
+ mask_b = (batch_ids == b_id)
604
+ indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
605
+ coords_b = vtx_pred_coords[mask_b, 1:4].float() # (x,y,z)
606
+ num_b = coords_b.shape[0]
607
+
608
+ if num_b < 2: continue
609
+
610
+ # --- A. Radius Graph (Hard Negatives) ---
611
+ if num_b <= MAX_PTS_FOR_DIST:
612
+ # 计算距离矩阵 [M, M]
613
+ # 注意:autocast 下 float16 ��� cdist 可能精度不够,建议转 float32
614
+ dist_mat = torch.cdist(coords_b.float(), coords_b.float())
615
+
616
+ # 找到距离小于 Radius 的点对 (排除自环)
617
+ adj_mat = (dist_mat < RADIUS) & (dist_mat > 1e-6)
618
+
619
+ # 提取索引 (local indices in batch)
620
+ src_local, dst_local = torch.nonzero(adj_mat, as_tuple=True)
621
+
622
+ # 映射回全局索引
623
+ cand_u_list.append(indices_b[src_local])
624
+ cand_v_list.append(indices_b[dst_local])
625
+ else:
626
+ print('num_b is big!')
627
+ pass
628
+
629
+ # --- B. Random Sampling (Easy Negatives) ---
630
+ # 随机生成 num_b * K 对
631
+ n_rand = num_b * K_RANDOM
632
+ rand_src_local = torch.randint(0, num_b, (n_rand,), device=self.device)
633
+ rand_dst_local = torch.randint(0, num_b, (n_rand,), device=self.device)
634
+
635
+ # 映射回全局索引
636
+ cand_u_list.append(indices_b[rand_src_local])
637
+ cand_v_list.append(indices_b[rand_dst_local])
638
+
639
+ # 合并所有来源 (GT + Radius + Random)
640
+ # 注意:我们把 real_pos 也加进来,确保正样本一定在列表里
641
+ all_u = torch.cat([real_pos_u] + cand_u_list)
642
+ all_v = torch.cat([real_pos_v] + cand_v_list)
643
+
644
+
645
+
646
+ # 3. 去重与 Labeling (Deduplication & Labeling)
647
+ # ====================================================
648
+ # 构造无向边 Hash: min * N + max
649
+ # 确保 MAX_NODES 足够大,比如 1000000 或 num_total_nodes
650
+ HASH_BASE = num_total_nodes + 100
651
+
652
+ p_min = torch.min(all_u, all_v)
653
+ p_max = torch.max(all_u, all_v)
654
+
655
+ # 过滤掉自环 (u==v)
656
+ valid_pair = (p_min != p_max)
657
+ p_min = p_min[valid_pair]
658
+ p_max = p_max[valid_pair]
659
+
660
+
661
+ all_hashes = p_min.long() * HASH_BASE + p_max.long()
662
+
663
+ # --- 核心:去重 ---
664
+ unique_hashes = torch.unique(all_hashes)
665
+
666
+ # 解码回 u, v
667
+ final_u = unique_hashes // HASH_BASE
668
+ final_v = unique_hashes % HASH_BASE
669
+
670
+ # --- Labeling ---
671
+ # 构建 GT 的 Hash 表用于查询
672
+ gt_min = torch.min(real_pos_u, real_pos_v)
673
+ gt_max = torch.max(real_pos_u, real_pos_v)
674
+ gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
675
+ gt_hashes = torch.unique(gt_hashes) # GT 也去重一下保险
676
+ gt_hashes_sorted, _ = torch.sort(gt_hashes)
677
+
678
+ # 查询 unique_hashes 是否在 gt_hashes 中
679
+ # 使用 searchsorted
680
+ idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
681
+ idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
682
+ is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
683
+
684
+ targets = is_connected.float().unsqueeze(-1) # [N_pairs, 1]
685
+
686
+ # ============================================================
687
+ # 4. 前向传播与 Loss
688
+ # ====================================================
689
+ feat_u = vtx_pred_feats[final_u]
690
+ feat_v = vtx_pred_feats[final_v]
691
+
692
+ # 对称特征融合
693
+ feat_uv = torch.cat([feat_u, feat_v], dim=-1)
694
+ feat_vu = torch.cat([feat_v, feat_u], dim=-1)
695
+
696
+ logits_uv = self.connection_head(feat_uv)
697
+ logits_vu = self.connection_head(feat_vu)
698
+ logits = (logits_uv + logits_vu) / 2.
699
+
700
+ # export_sampled_edges(
701
+ # coords=vtx_pred_coords, # [N, 4]
702
+ # u=final_u, # [E]
703
+ # v=final_v, # [E]
704
+ # labels=targets, # [E, 1]
705
+ # step_idx=b_idx,
706
+ # )
707
+ # export_sampled_edges(
708
+ # coords=vtx_pred_coords, # [N, 4]
709
+ # u=final_u, # [E]
710
+ # v=final_v, # [E]
711
+ # labels=targets, # [E, 1]
712
+ # step_idx=b_idx,
713
+ # batch_idx_to_viz=1,
714
+ # save_dir="debug_viz2"
715
+ # )
716
+ # exit()
717
+
718
+ connection_loss = self.asyloss(logits, targets)
719
+ total_loss += connection_loss
720
+ else:
721
+ connection_loss = torch.tensor(0., device=self.device)
722
+
723
+
724
+ # KL loss
725
+ kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
726
+ total_loss += kl_loss
727
+
728
+ # Backpropagation
729
+ scaled_total_loss = total_loss / self.accum_steps
730
+ # self.scaler.scale(scaled_total_loss).backward()
731
+ scaled_total_loss.backward()
732
+
733
+ return {
734
+ 'total_loss': total_loss.item(),
735
+ 'kl_loss': kl_loss.item(),
736
+ 'prune_loss': prune_loss_total.item(),
737
+ 'vertex_loss': vertex_loss_total.item(),
738
+ 'edge_loss': edge_loss_total.item(),
739
+ 'offset_loss': mse_loss_feats.item(),
740
+ 'direction_loss': mse_loss_dirs.item(),
741
+ 'connection_loss': connection_loss.item(),
742
+ }
743
+
744
+
745
+ def train(self):
746
+ accum_steps = self.accum_steps
747
+ for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
748
+ self.dataloader.sampler.set_epoch(epoch)
749
+ # Initialize metrics
750
+ metrics = {
751
+ 'total_loss': 0.0,
752
+ 'kl_loss': 0.0,
753
+ 'prune_loss': 0.0,
754
+ 'vertex_loss': 0.0,
755
+ 'edge_loss': 0.0,
756
+ 'offset_loss': 0.0,
757
+ 'direction_loss': 0.0,
758
+ 'connection_loss': 0.0,
759
+ }
760
+ num_batches = 0
761
+ self.optimizer.zero_grad(set_to_none=True)
762
+
763
+ for i, batch in enumerate(self.dataloader):
764
+ # Get all losses from train_step
765
+ if batch is None:
766
+ continue
767
+ step_losses = self.train_step(batch, i)
768
+
769
+ # Accumulate losses
770
+ for key in metrics:
771
+ metrics[key] += step_losses[key]
772
+
773
+ num_batches += 1
774
+
775
+ if (i + 1) % accum_steps == 0:
776
+ # self.scaler.unscale_(self.optimizer)
777
+ # torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
778
+ # torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
779
+ # torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
780
+
781
+ # self.scaler.step(self.optimizer)
782
+ # self.scaler.update()
783
+ # self.optimizer.zero_grad(set_to_none=True)
784
+
785
+ # self.scheduler.step()
786
+
787
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
788
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
789
+ torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
790
+ self.optimizer.step()
791
+ self.optimizer.zero_grad(set_to_none=True)
792
+ self.scheduler.step()
793
+
794
+
795
+ # Print batch-level metrics
796
+ if self.is_master:
797
+ avg_metric = {key: value / num_batches for key, value in metrics.items()}
798
+ print(
799
+ f"[Epoch {epoch}] Batch:{num_batches} "
800
+ f"AvgL:{avg_metric['total_loss']:.4f} "
801
+ f"Loss: {step_losses['total_loss']:.4f}, "
802
+ f"KLL: {step_losses['kl_loss']:.4f}, "
803
+ f"PruneL: {step_losses['prune_loss']:.4f}, "
804
+ f"VertexL: {step_losses['vertex_loss']:.4f}, "
805
+ f"EdgeL: {step_losses['edge_loss']:.4f}, "
806
+ f"OffsetL: {step_losses['offset_loss']:.4f}, "
807
+ f"DireL: {step_losses['direction_loss']:.4f}, "
808
+ f"ConL: {step_losses['connection_loss']:.4f}, "
809
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
810
+ )
811
+
812
+ if i % 2000 == 0 and i != 0:
813
+ self.save_checkpoint(epoch, avg_metric['total_loss'], i)
814
+ with open(self.log_file, "a") as f:
815
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
816
+ log_line = (
817
+ f"Epoch {epoch:05d} | "
818
+ f"Batch {i:05d} | "
819
+ f"Loss: {avg_metric['total_loss']:.6f} "
820
+ f"Avg KLL: {avg_metric['kl_loss']:.4f} "
821
+ f"Avg PruneL: {avg_metric['prune_loss']:.4f} "
822
+ f"Avg VertexL: {avg_metric['vertex_loss']:.4f} "
823
+ f"Avg EdgeL: {avg_metric['edge_loss']:.4f} "
824
+ f"Avg OffsetL: {avg_metric['offset_loss']:.4f} "
825
+ f"Avg DireL: {avg_metric['direction_loss']:.4f} "
826
+ f"Avg ConL: {avg_metric['connection_loss']:.4f} "
827
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
828
+ f"[{current_time}]\n"
829
+ )
830
+ f.write(log_line)
831
+
832
+ if num_batches % accum_steps != 0:
833
+ # self.scaler.unscale_(self.optimizer)
834
+ # torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
835
+ # torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
836
+ # torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
837
+
838
+ # self.scaler.step(self.optimizer)
839
+ # self.scaler.update()
840
+ # self.optimizer.zero_grad(set_to_none=True)
841
+
842
+ # self.scheduler.step()
843
+
844
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=0.5)
845
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=0.5)
846
+ torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=0.5)
847
+
848
+ self.optimizer.step()
849
+ self.optimizer.zero_grad(set_to_none=True)
850
+ self.scheduler.step()
851
+
852
+ # Calculate epoch averages
853
+ avg_metrics = {key: value / num_batches for key, value in metrics.items()}
854
+ self.train_loss_history.append(avg_metrics['total_loss'])
855
+
856
+
857
+ # Log to file
858
+ if self.is_master:
859
+ with open(self.log_file, "a") as f:
860
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
861
+ log_line = (
862
+ f"Epoch {epoch:05d} | "
863
+ f"Loss: {avg_metrics['total_loss']:.6f} "
864
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
865
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
866
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
867
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
868
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
869
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
870
+ f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
871
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
872
+ f"[{current_time}]\n"
873
+ )
874
+ f.write(log_line)
875
+
876
+ # Print epoch summary
877
+ print(
878
+ f"[Epoch {epoch}] "
879
+ f"Avg Loss: {avg_metrics['total_loss']:.4f} "
880
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
881
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
882
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
883
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
884
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
885
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
886
+ f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
887
+ f"[{current_time}]\n"
888
+ )
889
+
890
+ # Save checkpoint
891
+ if epoch % self.cfg['training']['save_every'] == 0:
892
+ self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
893
+
894
+ # Update learning rate
895
+ if self.is_master:
896
+ current_lr = self.optimizer.param_groups[0]['lr']
897
+ print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
898
+
899
+ dist.barrier()
900
+
901
+
902
+ def main():
903
+ # Initialize the process group
904
+ dist.init_process_group(backend='nccl')
905
+
906
+ # Get rank and world size from environment variables set by the launcher
907
+ rank = int(os.environ['RANK'])
908
+ world_size = int(os.environ['WORLD_SIZE'])
909
+ local_rank = int(os.environ['LOCAL_RANK'])
910
+
911
+ # Set the device for the current process. This is crucial.
912
+ torch.cuda.set_device(local_rank)
913
+ torch.manual_seed(42)
914
+
915
+ # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
916
+ # Pass the distributed info to the Trainer
917
+ trainer = Trainer(
918
+ config_path="/root/Trisf/config_edge_1024_error_8enc_8dec_woself_finetune_128to1024.yaml",
919
+ rank=rank,
920
+ world_size=world_size,
921
+ local_rank=local_rank
922
+ )
923
+ trainer.train()
924
+
925
+ # Clean up the process group
926
+ dist.destroy_process_group()
927
+
928
+
929
+ if __name__ == '__main__':
930
+ main()
train_slat_vae_512_128to256_pointnet_head.py ADDED
@@ -0,0 +1,917 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import os
3
+ # os.environ['ATTN_BACKEND'] = 'xformers'
4
+ import yaml
5
+ import torch
6
+ import time
7
+ from datetime import datetime
8
+ from torch.utils.data import DataLoader
9
+ from functools import partial
10
+ from triposf.modules.sparse.basic import SparseTensor
11
+ import torch.nn.functional as F
12
+ from torch.optim import AdamW
13
+ from torch.cuda.amp import GradScaler, autocast
14
+
15
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
16
+ from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
17
+ from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
18
+
19
+ from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss
20
+
21
+ import torch.distributed as dist
22
+ from torch.nn.parallel import DistributedDataParallel as DDP
23
+ from torch.utils.data.distributed import DistributedSampler
24
+
25
+ from transformers import get_cosine_schedule_with_warmup
26
+ import math
27
+
28
+ import numpy as np
29
+ import open3d as o3d
30
+
31
+ def export_sampled_edges(coords, u, v, labels, step_idx, save_dir="debug_viz", batch_idx_to_viz=0):
32
+ """
33
+ 导出采样边为 PLY 文件。
34
+
35
+ Args:
36
+ coords: [N, 4] Tensor (batch_idx, x, y, z)
37
+ u: [E] Tensor, 起点索引 (global index)
38
+ v: [E] Tensor, 终点索引 (global index)
39
+ labels: [E, 1] Tensor, 1.0 为正样本, 0.0 为负样本
40
+ step_idx: 当前步数或 epoch,用于文件名
41
+ save_dir: 保存目录
42
+ batch_idx_to_viz: 只可视化哪个 batch 的数据 (防止多个 batch 叠加在一起看不清)
43
+ """
44
+ os.makedirs(save_dir, exist_ok=True)
45
+
46
+ # 1. 转为 CPU numpy
47
+ coords_np = coords.detach().cpu().numpy()
48
+ u_np = u.detach().cpu().numpy()
49
+ v_np = v.detach().cpu().numpy()
50
+ labels_np = labels.detach().cpu().numpy().reshape(-1)
51
+
52
+ # 2. 筛选特定 Batch (通常只看 Batch 0 比较清晰)
53
+ # coords 的第0列是 batch index
54
+ batch_mask = (coords_np[:, 0] == batch_idx_to_viz)
55
+
56
+ # 获取属于该 batch 的全局索引范围
57
+ # 注意:u 和 v 是针对所有 coords 的全局索引。
58
+ # 我们需要判断一条边的两个端点是否都在这个 batch 内。
59
+
60
+ # 快速检查端点是否在当前 batch
61
+ valid_u_in_batch = batch_mask[u_np]
62
+ valid_v_in_batch = batch_mask[v_np]
63
+ edge_batch_mask = valid_u_in_batch & valid_v_in_batch
64
+
65
+ if edge_batch_mask.sum() == 0:
66
+ print(f"Warning: No edges found for batch {batch_idx_to_viz}")
67
+ return
68
+
69
+ # 应用 Batch 筛选
70
+ u_b = u_np[edge_batch_mask]
71
+ v_b = v_np[edge_batch_mask]
72
+ labels_b = labels_np[edge_batch_mask]
73
+
74
+ # 3. 提取该 Batch 的顶点坐标 (去掉 batch_idx 维度)
75
+ # 此时我们需要重新映射 u, v 的索引,因为我们要只保存该 batch 的点
76
+ batch_indices_global = np.where(batch_mask)[0]
77
+ # 创建全局索引到局部索引的映射表
78
+ global_to_local = {gid: lid for lid, gid in enumerate(batch_indices_global)}
79
+
80
+ points_xyz = coords_np[batch_indices_global, 1:4] # [M, 3]
81
+
82
+ # 转换 u, v 为局部索引
83
+ try:
84
+ u_local = np.array([global_to_local[idx] for idx in u_b])
85
+ v_local = np.array([global_to_local[idx] for idx in v_b])
86
+ except KeyError:
87
+ print("Error in index mapping. Edge endpoints might cross batches.")
88
+ return
89
+
90
+ # 4. 分离正负样本
91
+ pos_mask = labels_b > 0.5
92
+ neg_mask = ~pos_mask
93
+
94
+ # 内部函数:写 PLY
95
+ def write_ply(filename, points, edges_u, edges_v, color_rgb):
96
+ num_verts = len(points)
97
+ num_edges = len(edges_u)
98
+
99
+ with open(filename, 'w') as f:
100
+ f.write("ply\n")
101
+ f.write("format ascii 1.0\n")
102
+ f.write(f"element vertex {num_verts}\n")
103
+ f.write("property float x\n")
104
+ f.write("property float y\n")
105
+ f.write("property float z\n")
106
+ f.write("property uchar red\n")
107
+ f.write("property uchar green\n")
108
+ f.write("property uchar blue\n")
109
+ f.write(f"element edge {num_edges}\n")
110
+ f.write("property int vertex1\n")
111
+ f.write("property int vertex2\n")
112
+ f.write("end_header\n")
113
+
114
+ # Write Vertices with Color
115
+ # 为了让可视化更清楚,我们将所有点染成指定颜色
116
+ for i in range(num_verts):
117
+ x, y, z = points[i]
118
+ f.write(f"{x:.4f} {y:.4f} {z:.4f} {color_rgb[0]} {color_rgb[1]} {color_rgb[2]}\n")
119
+
120
+ # Write Edges
121
+ for i in range(num_edges):
122
+ f.write(f"{edges_u[i]} {edges_v[i]}\n")
123
+
124
+ print(f"Saved: {filename} (Edges: {num_edges})")
125
+
126
+ # 5. 保存正样本 (绿色)
127
+ if pos_mask.sum() > 0:
128
+ write_ply(
129
+ os.path.join(save_dir, f"step_{step_idx}_pos_edges.ply"),
130
+ points_xyz, # 使用所有点,或者优化为只使用涉及的点(这里为了坐标统一���单起见使用所有点)
131
+ u_local[pos_mask],
132
+ v_local[pos_mask],
133
+ color_rgb=(0, 255, 0) # Green
134
+ )
135
+
136
+ # 6. 保存负样本 (红色)
137
+ if neg_mask.sum() > 0:
138
+ # 为了避免文件太大或太乱,如果负样本特别多,可以考虑随机采样一部分保存
139
+ # 这里默认全部保存
140
+ write_ply(
141
+ os.path.join(save_dir, f"step_{step_idx}_neg_edges.ply"),
142
+ points_xyz,
143
+ u_local[neg_mask],
144
+ v_local[neg_mask],
145
+ color_rgb=(255, 0, 0) # Red
146
+ )
147
+
148
+ def flatten_coords_4d(coords_4d: torch.Tensor):
149
+ coords_4d_long = coords_4d.long()
150
+
151
+ base_x = 256
152
+ base_y = 256 * 256
153
+ base_z = 256 * 256 * 256
154
+
155
+ flat_coords = coords_4d_long[:, 0] * base_z + \
156
+ coords_4d_long[:, 1] * base_y + \
157
+ coords_4d_long[:, 2] * base_x + \
158
+ coords_4d_long[:, 3]
159
+ return flat_coords
160
+
161
+ def downsample_voxels(
162
+ voxels: torch.Tensor,
163
+ input_resolution: int,
164
+ output_resolution: int
165
+ ) -> torch.Tensor:
166
+ if input_resolution % output_resolution != 0:
167
+ raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
168
+ f"by output_resolution ({output_resolution}).")
169
+
170
+ factor = input_resolution // output_resolution
171
+
172
+ downsampled_voxels = voxels.clone().to(torch.long)
173
+
174
+ downsampled_voxels[:, 1:] //= factor
175
+
176
+ unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
177
+ return unique_downsampled_voxels
178
+
179
+ class Trainer:
180
+ def __init__(self, config_path, rank, world_size, local_rank):
181
+ self.rank = rank
182
+ self.world_size = world_size
183
+ self.local_rank = local_rank
184
+ self.is_master = self.rank == 0
185
+
186
+ self.load_config(config_path)
187
+ self.accum_steps = max(1, 4 // self.cfg['training']['batch_size'])
188
+
189
+ self.config_hash = self.save_config_with_hash()
190
+ self.init_device()
191
+ self.init_dirs()
192
+ self.init_components()
193
+ self.init_training()
194
+
195
+ self.train_loss_history = []
196
+ self.eval_loss_history = []
197
+ self.best_eval_loss = float('inf')
198
+
199
+ def save_config_with_hash(self):
200
+ import hashlib
201
+
202
+ # Serialize config to hash
203
+ config_str = yaml.dump(self.cfg)
204
+ config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
205
+
206
+ # Prepare all flags as string for formatting
207
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
208
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
209
+
210
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
211
+
212
+ # Format save_dir with all placeholders
213
+ self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
214
+ dataset_name=dataset_name,
215
+ config_hash=config_hash,
216
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
217
+ multires=self.cfg['model']['multires'],
218
+ add_block_embed=add_block_embed_flag,
219
+ using_attn=using_attn_flag,
220
+ batch_size=self.cfg['training']['batch_size'],
221
+ )
222
+
223
+ if self.is_master:
224
+ os.makedirs(self.save_dir, exist_ok=True)
225
+ config_path = os.path.join(self.save_dir, "config.yaml")
226
+ with open(config_path, 'w') as f:
227
+ yaml.dump(self.cfg, f)
228
+
229
+ dist.barrier()
230
+ return config_hash
231
+
232
+ def save_checkpoint(self, epoch, avg_loss, batch_idx):
233
+ if not self.is_master:
234
+ return
235
+ checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
236
+ config_path = os.path.join(self.save_dir, "config.yaml")
237
+
238
+ torch.save({
239
+ 'voxel_encoder': self.voxel_encoder.module.state_dict(),
240
+ 'vae': self.vae.module.state_dict(),
241
+ 'connection_head': self.connection_head.module.state_dict(),
242
+ 'epoch': epoch,
243
+ 'loss': avg_loss,
244
+ 'config': self.cfg
245
+ }, checkpoint_path)
246
+
247
+ def quoted_presenter(dumper, data):
248
+ return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
249
+
250
+ yaml.add_representer(str, quoted_presenter)
251
+
252
+ with open(config_path, 'w') as f:
253
+ yaml.dump(self.cfg, f)
254
+
255
+ def load_config(self, config_path):
256
+ with open(config_path) as f:
257
+ self.cfg = yaml.safe_load(f)
258
+
259
+ # Extract and convert flags for formatting
260
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
261
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
262
+
263
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
264
+
265
+ self.save_dir = self.cfg['experiment']['save_dir'].format(
266
+ dataset_name=dataset_name,
267
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
268
+ multires=self.cfg['model']['multires'],
269
+ add_block_embed=add_block_embed_flag,
270
+ using_attn=using_attn_flag,
271
+ batch_size=self.cfg['training']['batch_size'],
272
+ )
273
+
274
+ if self.is_master:
275
+ os.makedirs(self.save_dir, exist_ok=True)
276
+ dist.barrier()
277
+
278
+
279
+ def init_device(self):
280
+ self.device = torch.device(f"cuda:{self.local_rank}")
281
+
282
+ def init_dirs(self):
283
+ self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
284
+ if self.is_master:
285
+ with open(self.log_file, "a") as f:
286
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
287
+ f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
288
+
289
+ def init_components(self):
290
+ self.dataset = VoxelVertexDataset_edge(
291
+ root_dir=self.cfg['dataset']['path'],
292
+ base_resolution=self.cfg['dataset']['base_resolution'],
293
+ min_resolution=self.cfg['dataset']['min_resolution'],
294
+ cache_dir=self.cfg['dataset']['cache_dir'],
295
+ renders_dir=self.cfg['dataset']['renders_dir'],
296
+
297
+ filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
298
+ cache_filter_path=self.cfg['dataset']['cache_filter_path'],
299
+
300
+ active_voxel_res=128,
301
+ pc_sample_number=819200,
302
+
303
+ sample_type=self.cfg['dataset']['sample_type'],
304
+ )
305
+
306
+ self.sampler = DistributedSampler(
307
+ self.dataset,
308
+ num_replicas=self.world_size,
309
+ rank=self.rank,
310
+ shuffle=True,
311
+ )
312
+
313
+ self.dataloader = DataLoader(
314
+ self.dataset,
315
+ batch_size=self.cfg['training']['batch_size'],
316
+ shuffle=False,
317
+ collate_fn=partial(collate_fn_pointnet,),
318
+ num_workers=self.cfg['training']['num_workers'],
319
+ pin_memory=True,
320
+ sampler=self.sampler,
321
+ # prefetch_factor=4,
322
+ persistent_workers=True,
323
+ )
324
+
325
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
326
+ in_channels=15,
327
+ hidden_dim=256,
328
+ out_channels=1024,
329
+ scatter_type='mean',
330
+ n_blocks=5,
331
+ resolution=128,
332
+ add_label=False,
333
+ ).to(self.device)
334
+
335
+ self.connection_head = ConnectionHead(
336
+ channels=128 * 2,
337
+ out_channels=1,
338
+ mlp_ratio=4,
339
+ ).to(self.device)
340
+
341
+ # ablation 3: voxelvae_1volume, have tested
342
+ self.vae = VoxelVAE(
343
+ in_channels=self.cfg['model']['in_channels'],
344
+ latent_dim=self.cfg['model']['latent_dim'],
345
+ encoder_blocks=self.cfg['model']['encoder_blocks'],
346
+ decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
347
+ decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
348
+ num_heads=8,
349
+ num_head_channels=64,
350
+ mlp_ratio=4.0,
351
+ attn_mode="swin",
352
+ window_size=8,
353
+ pe_mode="ape",
354
+ use_fp16=False,
355
+ use_checkpoint=True,
356
+ qk_rms_norm=False,
357
+ using_subdivide=True,
358
+ using_attn=self.cfg['model']['using_attn'],
359
+ attn_first=self.cfg['model'].get('attn_first', True),
360
+ pred_direction=self.cfg['model'].get('pred_direction', False),
361
+ ).to(self.device)
362
+
363
+ if self.cfg['training']['from_pretrained']:
364
+ load_pretrained_woself(
365
+ checkpoint_path=self.cfg['training']['checkpoint_path'],
366
+ voxel_encoder=self.voxel_encoder,
367
+ vae=self.vae,
368
+ connection_head=self.connection_head,
369
+ optimizer=None,
370
+ )
371
+
372
+ self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
373
+ self.connection_head = DDP(self.connection_head, device_ids=[self.local_rank], find_unused_parameters=False)
374
+ self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=False)
375
+
376
+ def init_training(self):
377
+ self.optimizer = AdamW(
378
+ list(self.vae.module.parameters()) +
379
+ list(self.voxel_encoder.module.parameters()) +
380
+ list(self.connection_head.module.parameters()),
381
+ lr=self.cfg['training']['lr'],
382
+ weight_decay=0.01,
383
+ )
384
+
385
+ num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
386
+ max_epochs = self.cfg['training']['max_epochs']
387
+ num_training_steps = max_epochs * num_update_steps_per_epoch
388
+
389
+ num_warmup_steps = 100
390
+
391
+ self.scheduler = get_cosine_schedule_with_warmup(
392
+ self.optimizer,
393
+ num_warmup_steps=num_warmup_steps,
394
+ num_training_steps=num_training_steps
395
+ )
396
+
397
+ self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=100.0).to(self.device)
398
+ self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
399
+ self.asyloss = AsymmetricFocalLoss(
400
+ gamma_pos=0.0,
401
+ gamma_neg=4.0,
402
+ clip=0.05,
403
+ )
404
+
405
+ self.bce_loss = torch.nn.BCEWithLogitsLoss()
406
+
407
+ self.dice_loss = DiceLoss()
408
+ self.scaler = GradScaler()
409
+
410
+ def train_step(self, batch):
411
+ """Modified training step that handles vertex and edge voxels separately after initial prediction."""
412
+ # 1. Retrieve data from batch
413
+ combined_voxels_256 = batch['combined_voxels_256'].to(self.device)
414
+ combined_voxel_labels_256 = batch['combined_voxel_labels_256'].to(self.device)
415
+ gt_vertex_voxels_256 = batch['gt_vertex_voxels_256'].to(self.device)
416
+
417
+ gt_edge_voxels_256 = batch['gt_edge_voxels_256'].to(self.device)
418
+ gt_combined_endpoints_256 = batch['gt_combined_endpoints_256'].to(self.device)
419
+ gt_combined_errors_256 = batch['gt_combined_errors_256'].to(self.device)
420
+
421
+ gt_edges = batch['gt_vertex_edge_indices_256'].to(self.device)
422
+
423
+ edge_mask = (combined_voxel_labels_256 == 1)
424
+
425
+ gt_edge_endpoints_256 = gt_combined_endpoints_256[edge_mask]
426
+ gt_edge_errors_256 = gt_combined_errors_256[edge_mask]
427
+
428
+ p1 = gt_edge_endpoints_256[:, 1:4].float()
429
+ p2 = gt_edge_endpoints_256[:, 4:7].float()
430
+
431
+ mask = ( (p1[:,0] < p2[:,0]) |
432
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
433
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
434
+
435
+ pA = torch.where(mask[:, None], p1, p2) # smaller one
436
+ pB = torch.where(mask[:, None], p2, p1) # larger one
437
+
438
+ d = pB - pA
439
+ dir_gt = F.normalize(d, dim=-1, eps=1e-6)
440
+ vtx_128 = downsample_voxels(gt_vertex_voxels_256, input_resolution=256, output_resolution=128)
441
+ vtx_256 = gt_vertex_voxels_256
442
+
443
+ edge_128 = downsample_voxels(combined_voxels_256, input_resolution=256, output_resolution=128)
444
+ edge_256 = combined_voxels_256
445
+
446
+ active_coords = batch['active_voxels_128'].to(self.device)
447
+ point_cloud = batch['point_cloud_128'].to(self.device)
448
+
449
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
450
+ active_voxel_feats = self.voxel_encoder(
451
+ p=point_cloud,
452
+ sparse_coords=active_coords,
453
+ res=128,
454
+ bbox_size=(-0.5, 0.5),
455
+ )
456
+
457
+ sparse_input = SparseTensor(
458
+ feats=active_voxel_feats,
459
+ coords=active_coords.int()
460
+ )
461
+
462
+ gt_edge_voxels_list = [
463
+ edge_128,
464
+ edge_256,
465
+ ]
466
+
467
+ gt_vertex_voxels_list = [
468
+ vtx_128,
469
+ vtx_256,
470
+ ]
471
+
472
+ results, posterior, latent_128 = self.vae(
473
+ sparse_input,
474
+ gt_vertex_voxels_list=gt_vertex_voxels_list,
475
+ gt_edge_voxels_list=gt_edge_voxels_list,
476
+ training=True,
477
+ sample_ratio=0.,
478
+ )
479
+
480
+ # print("results[-1]['edge']['coords_4d'][1827:1830]", results[-1]['edge']['coords_4d'][1827:1830])
481
+ total_loss = 0.
482
+ prune_loss_total = 0.
483
+ vertex_loss_total = 0.
484
+ edge_loss_total=0.
485
+
486
+ with autocast(dtype=torch.bfloat16):
487
+ initial_result = results[0]
488
+ vertex_mask = initial_result['vertex_mask']
489
+ vtx_logits = initial_result['vtx_feats']
490
+ vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
491
+
492
+ edge_mask = initial_result['edge_mask']
493
+ edge_logits = initial_result['edge_feats']
494
+ edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
495
+
496
+ vertex_loss_total += vertex_loss
497
+ edge_loss_total += edge_loss
498
+
499
+ total_loss += vertex_loss
500
+ total_loss += edge_loss
501
+
502
+ # Process each level's results
503
+ for idx, res_dict in enumerate(results[1:], start=1):
504
+ # Vertex branch losses
505
+ vertex_pred_coords = res_dict['vertex']['occ_coords']
506
+ vertex_occ_probs = res_dict['vertex']['occ_probs']
507
+ vertex_gt_coords = res_dict['vertex']['coords']
508
+
509
+ vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=256).float()
510
+ # print('vertex_labels.sum()', vertex_labels.sum(), idx)
511
+ vertex_logits = vertex_occ_probs.squeeze()
512
+
513
+ # if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
514
+ vertex_prune_loss = self.focal_loss(vertex_logits, vertex_labels)
515
+ # vertex_prune_loss = self.dice_loss(vertex_logits, vertex_labels)
516
+
517
+ # dilation 1: bce loss
518
+ # vertex_prune_loss = self.bce_loss(vertex_logits, vertex_labels,)
519
+
520
+ prune_loss_total += vertex_prune_loss
521
+ total_loss += vertex_prune_loss
522
+
523
+
524
+ # Edge branch losses
525
+ edge_pred_coords = res_dict['edge']['occ_coords']
526
+ edge_occ_probs = res_dict['edge']['occ_probs']
527
+ edge_gt_coords = res_dict['edge']['coords']
528
+
529
+ edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=256).float()
530
+ # print('edge_labels.sum()', edge_labels.sum(), idx)
531
+ edge_logits = edge_occ_probs.squeeze()
532
+ # if edge_labels.sum() > 0 and edge_labels.sum() < len(edge_labels):
533
+ edge_prune_loss = self.focal_loss(edge_logits, edge_labels)
534
+
535
+ # dilation 1: bce loss
536
+ # edge_prune_loss = self.bce_loss(edge_logits, edge_labels,)
537
+
538
+ prune_loss_total += edge_prune_loss
539
+ total_loss += edge_prune_loss
540
+
541
+ if idx == 1:
542
+ pred_coords = res_dict['edge']['coords_4d'] # [N,4] (b,x,y,z)
543
+ pred_feats = res_dict['edge']['predicted_offset_feats'] # [N,C]
544
+
545
+ gt_coords = gt_edge_voxels_256.to(pred_coords.device) # [M,4]
546
+ gt_feats = gt_edge_errors_256[:, 1:].to(pred_coords.device) # [M,C]
547
+
548
+ pred_keys = flatten_coords_4d(pred_coords)
549
+ gt_keys = flatten_coords_4d(gt_coords)
550
+
551
+ sorted_pred_keys, pred_order = torch.sort(pred_keys)
552
+ pred_coords_sorted = pred_coords[pred_order]
553
+ pred_feats_sorted = pred_feats[pred_order]
554
+
555
+ sorted_gt_keys, gt_order = torch.sort(gt_keys)
556
+ gt_coords_sorted = gt_coords[gt_order]
557
+ gt_feats_sorted = gt_feats[gt_order]
558
+
559
+ pos = torch.searchsorted(sorted_gt_keys, sorted_pred_keys)
560
+ valid_mask = (pos < len(sorted_gt_keys)) & (sorted_gt_keys[pos] == sorted_pred_keys)
561
+
562
+ if valid_mask.any():
563
+ # print('valid_mask.sum()', valid_mask.sum())
564
+ matched_pred_feats = pred_feats_sorted[valid_mask]
565
+ matched_gt_feats = gt_feats_sorted[pos[valid_mask]]
566
+ mse_loss_feats = self.mse_loss(matched_pred_feats, matched_gt_feats * 2)
567
+ total_loss += mse_loss_feats * 0.
568
+
569
+ if self.cfg['model'].get('pred_direction', False):
570
+ pred_dirs = res_dict['edge']['predicted_direction_feats']
571
+ dir_gt_device = dir_gt.to(pred_coords.device)
572
+
573
+ pred_dirs_sorted = pred_dirs[pred_order]
574
+ dir_gt_sorted = dir_gt_device[gt_order]
575
+
576
+ matched_pred_dirs = pred_dirs_sorted[valid_mask]
577
+ matched_gt_dirs = dir_gt_sorted[pos[valid_mask]]
578
+
579
+ mse_loss_dirs = self.mse_loss(matched_pred_dirs, matched_gt_dirs)
580
+ total_loss += mse_loss_dirs * 0.
581
+ else:
582
+ mse_loss_feats = torch.tensor(0., device=pred_coords.device)
583
+ if self.cfg['model'].get('pred_direction', False):
584
+ mse_loss_dirs = torch.tensor(0., device=pred_coords.device)
585
+
586
+
587
+ # --- Vertex Branch (Connection Loss 核心) ---
588
+ vtx_pred_coords = res_dict['vertex']['coords_4d'] # [N, 4]
589
+ vtx_pred_feats = res_dict['vertex']['feats'] # [N, C]
590
+
591
+ # 1.1 排序 (既用于匹配 GT,也用于快速寻找空间邻居)
592
+ vtx_pred_keys = flatten_coords_4d(vtx_pred_coords)
593
+ vtx_pred_keys_sorted, vtx_pred_order = torch.sort(vtx_pred_keys)
594
+
595
+ # 1.2 匹配 GT
596
+ vtx_gt_keys = flatten_coords_4d(gt_vertex_voxels_256.to(self.device))
597
+ vtx_pos = torch.searchsorted(vtx_pred_keys_sorted, vtx_gt_keys)
598
+ vtx_pos = vtx_pos.clamp(max=len(vtx_pred_keys_sorted) - 1)
599
+ vtx_match_mask = (vtx_pred_keys_sorted[vtx_pos] == vtx_gt_keys)
600
+
601
+ gt_to_pred_mapping = torch.full((len(vtx_gt_keys),), -1, device=self.device, dtype=torch.long)
602
+ matched_pred_indices = vtx_pred_order[vtx_pos[vtx_match_mask]]
603
+ gt_to_pred_mapping[vtx_match_mask] = matched_pred_indices
604
+
605
+ # ====================================================
606
+ # 2. 构建核心数据:正样本 Hash 集合
607
+ # ====================================================
608
+ # 这里的 pos_u/pos_v 仅用于构建 "什么是真连接" 的查询表
609
+ u_gt, v_gt = gt_edges[:, 0], gt_edges[:, 1]
610
+ u_pred = gt_to_pred_mapping[u_gt]
611
+ v_pred = gt_to_pred_mapping[v_gt]
612
+
613
+ valid_edge_mask = (u_pred != -1) & (v_pred != -1)
614
+ real_pos_u = u_pred[valid_edge_mask]
615
+ real_pos_v = v_pred[valid_edge_mask]
616
+
617
+ num_real_pos = real_pos_u.shape[0]
618
+ num_total_nodes = vtx_pred_coords.shape[0]
619
+
620
+ if num_real_pos > 0:
621
+ # 2. 构建候选样本 (Candidate Generation)
622
+ # ====================================================
623
+ cand_u_list = []
624
+ cand_v_list = []
625
+
626
+ batch_ids = vtx_pred_coords[:, 0]
627
+ unique_batches = torch.unique(batch_ids)
628
+
629
+ RADIUS = 16
630
+ MAX_PTS_FOR_DIST = 12000
631
+ K_RANDOM = 32
632
+
633
+ for b_id in unique_batches:
634
+ mask_b = (batch_ids == b_id)
635
+ indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
636
+ coords_b = vtx_pred_coords[mask_b, 1:4].float() # (x,y,z)
637
+ num_b = coords_b.shape[0]
638
+
639
+ if num_b < 2: continue
640
+
641
+ # --- A. Radius Graph (Hard Negatives) ---
642
+ if num_b <= MAX_PTS_FOR_DIST:
643
+ # 计算距离矩阵 [M, M]
644
+ # 注意:autocast 下 float16 的 cdist 可能精度不够,建议转 float32
645
+ dist_mat = torch.cdist(coords_b.float(), coords_b.float())
646
+
647
+ # 找到距离小于 Radius 的点对 (排除自环)
648
+ adj_mat = (dist_mat < RADIUS) & (dist_mat > 1e-6)
649
+
650
+ # 提取索引 (local indices in batch)
651
+ src_local, dst_local = torch.nonzero(adj_mat, as_tuple=True)
652
+
653
+ # 映射回全局索引
654
+ cand_u_list.append(indices_b[src_local])
655
+ cand_v_list.append(indices_b[dst_local])
656
+ else:
657
+ # 如果点太多,显存不够,退化为随机局部采样或跳过
658
+ # 这里简单处理:跳过 Radius Graph,依赖 Random
659
+ pass
660
+
661
+ # --- B. Random Sampling (Easy Negatives) ---
662
+ # 随机生成 num_b * K 对
663
+ n_rand = num_b * K_RANDOM
664
+ rand_src_local = torch.randint(0, num_b, (n_rand,), device=self.device)
665
+ rand_dst_local = torch.randint(0, num_b, (n_rand,), device=self.device)
666
+
667
+ # 映射回全局索引
668
+ cand_u_list.append(indices_b[rand_src_local])
669
+ cand_v_list.append(indices_b[rand_dst_local])
670
+
671
+ # 合并所有来源 (GT + Radius + Random)
672
+ # 注意:我们把 real_pos 也加进来,确保正样本一定在列表里
673
+ all_u = torch.cat([real_pos_u] + cand_u_list)
674
+ all_v = torch.cat([real_pos_v] + cand_v_list)
675
+
676
+ # 3. 去重与 Labeling (Deduplication & Labeling)
677
+ # ====================================================
678
+ # 构造无向边 Hash: min * N + max
679
+ # 确保 MAX_NODES 足够大,比如 1000000 或 num_total_nodes
680
+ HASH_BASE = num_total_nodes + 100
681
+
682
+ p_min = torch.min(all_u, all_v)
683
+ p_max = torch.max(all_u, all_v)
684
+
685
+ # 过滤掉自环 (u==v)
686
+ valid_pair = (p_min != p_max)
687
+ p_min = p_min[valid_pair]
688
+ p_max = p_max[valid_pair]
689
+
690
+ all_hashes = p_min.long() * HASH_BASE + p_max.long()
691
+
692
+ # --- 核心:去重 ---
693
+ unique_hashes = torch.unique(all_hashes)
694
+
695
+ # 解码回 u, v
696
+ final_u = unique_hashes // HASH_BASE
697
+ final_v = unique_hashes % HASH_BASE
698
+
699
+ # --- Labeling ---
700
+ # 构建 GT 的 Hash 表用于查询
701
+ gt_min = torch.min(real_pos_u, real_pos_v)
702
+ gt_max = torch.max(real_pos_u, real_pos_v)
703
+ gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
704
+ gt_hashes = torch.unique(gt_hashes) # GT 也去重一下保险
705
+ gt_hashes_sorted, _ = torch.sort(gt_hashes)
706
+
707
+ # 查询 unique_hashes 是否在 gt_hashes 中
708
+ # 使用 searchsorted
709
+ idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
710
+ idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
711
+ is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
712
+
713
+ targets = is_connected.float().unsqueeze(-1) # [N_pairs, 1]
714
+
715
+ # 4. 前向传播与 Loss
716
+ # ====================================================
717
+ feat_u = vtx_pred_feats[final_u]
718
+ feat_v = vtx_pred_feats[final_v]
719
+
720
+ # 对称特征融合
721
+ feat_uv = torch.cat([feat_u, feat_v], dim=-1)
722
+ feat_vu = torch.cat([feat_v, feat_u], dim=-1)
723
+
724
+ logits_uv = self.connection_head(feat_uv)
725
+ logits_vu = self.connection_head(feat_vu)
726
+ logits = logits_uv + logits_vu
727
+
728
+ # print('targets.sum()', targets.sum())
729
+ # print('targets.shape', targets.shape)
730
+
731
+ # export_sampled_edges(
732
+ # coords=vtx_pred_coords, # [N, 4]
733
+ # u=final_u, # [E]
734
+ # v=final_v, # [E]
735
+ # labels=targets, # [E, 1]
736
+ # step_idx=0,
737
+ # )
738
+ # exit()
739
+
740
+ # Focal Loss
741
+ connection_loss = self.asyloss(logits, targets)
742
+ total_loss += connection_loss
743
+
744
+ else:
745
+ connection_loss = torch.tensor(0., device=self.device)
746
+
747
+
748
+
749
+ # KL loss
750
+ kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
751
+ total_loss += kl_loss
752
+
753
+ # Backpropagation
754
+ scaled_total_loss = total_loss / self.accum_steps
755
+ self.scaler.scale(scaled_total_loss).backward()
756
+
757
+ return {
758
+ 'total_loss': total_loss.item(),
759
+ 'kl_loss': kl_loss.item(),
760
+ 'prune_loss': prune_loss_total.item(),
761
+ 'vertex_loss': vertex_loss_total.item(),
762
+ 'edge_loss': edge_loss_total.item(),
763
+ 'offset_loss': mse_loss_feats.item(),
764
+ 'direction_loss': mse_loss_dirs.item(),
765
+ 'connection_loss': connection_loss.item(),
766
+ }
767
+
768
+
769
+ def train(self):
770
+ accum_steps = self.accum_steps
771
+ for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
772
+ self.dataloader.sampler.set_epoch(epoch)
773
+ # Initialize metrics
774
+ metrics = {
775
+ 'total_loss': 0.0,
776
+ 'kl_loss': 0.0,
777
+ 'prune_loss': 0.0,
778
+ 'vertex_loss': 0.0,
779
+ 'edge_loss': 0.0,
780
+ 'offset_loss': 0.0,
781
+ 'direction_loss': 0.0,
782
+ 'connection_loss': 0.0,
783
+ }
784
+ num_batches = 0
785
+ self.optimizer.zero_grad(set_to_none=True)
786
+
787
+ for i, batch in enumerate(self.dataloader):
788
+ # Get all losses from train_step
789
+ if batch is None:
790
+ continue
791
+ step_losses = self.train_step(batch)
792
+
793
+ # Accumulate losses
794
+ for key in metrics:
795
+ metrics[key] += step_losses[key]
796
+ num_batches += 1
797
+
798
+ if (i + 1) % accum_steps == 0:
799
+ self.scaler.unscale_(self.optimizer)
800
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
801
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
802
+ torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=1.0)
803
+
804
+ self.scaler.step(self.optimizer)
805
+ self.scaler.update()
806
+ self.optimizer.zero_grad(set_to_none=True)
807
+
808
+ self.scheduler.step()
809
+
810
+ # Print batch-level metrics
811
+ if self.is_master:
812
+ print(
813
+ f"[Epoch {epoch}] Batch:{num_batches} "
814
+ f"Loss: {step_losses['total_loss']:.4f}, "
815
+ f"KLL: {step_losses['kl_loss']:.4f}, "
816
+ f"PruneL: {step_losses['prune_loss']:.4f}, "
817
+ f"VertexL: {step_losses['vertex_loss']:.4f}, "
818
+ f"EdgeL: {step_losses['edge_loss']:.4f}, "
819
+ f"OffsetL: {step_losses['offset_loss']:.4f}, "
820
+ f"DireL: {step_losses['direction_loss']:.4f}, "
821
+ f"ConL: {step_losses['connection_loss']:.4f}, "
822
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
823
+ )
824
+
825
+ # if i % 2000 == 0 and i != 0:
826
+ # self.save_checkpoint(epoch, step_losses['total_loss'], i)
827
+
828
+ if num_batches % accum_steps != 0:
829
+ self.scaler.unscale_(self.optimizer)
830
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
831
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
832
+ torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=1.0)
833
+
834
+ self.scaler.step(self.optimizer)
835
+ self.scaler.update()
836
+ self.optimizer.zero_grad(set_to_none=True)
837
+
838
+ self.scheduler.step()
839
+
840
+ # Calculate epoch averages
841
+ avg_metrics = {key: value / num_batches for key, value in metrics.items()}
842
+ self.train_loss_history.append(avg_metrics['total_loss'])
843
+
844
+
845
+ # Log to file
846
+ if self.is_master:
847
+ with open(self.log_file, "a") as f:
848
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
849
+ log_line = (
850
+ f"Epoch {epoch:05d} | "
851
+ f"Loss: {avg_metrics['total_loss']:.6f} "
852
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
853
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
854
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
855
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
856
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
857
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
858
+ f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
859
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
860
+ f"[{current_time}]\n"
861
+ )
862
+ f.write(log_line)
863
+
864
+ # Print epoch summary
865
+ print(
866
+ f"[Epoch {epoch}] "
867
+ f"Avg Loss: {avg_metrics['total_loss']:.4f} "
868
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
869
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
870
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
871
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
872
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
873
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
874
+ f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
875
+ f"[{current_time}]\n"
876
+ )
877
+
878
+ # Save checkpoint
879
+ if epoch % self.cfg['training']['save_every'] == 0:
880
+ self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
881
+
882
+ # Update learning rate
883
+ if self.is_master:
884
+ current_lr = self.optimizer.param_groups[0]['lr']
885
+ print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
886
+
887
+ dist.barrier()
888
+
889
+ def main():
890
+ # Initialize the process group
891
+ dist.init_process_group(backend='nccl')
892
+
893
+ # Get rank and world size from environment variables set by the launcher
894
+ rank = int(os.environ['RANK'])
895
+ world_size = int(os.environ['WORLD_SIZE'])
896
+ local_rank = int(os.environ['LOCAL_RANK'])
897
+
898
+ # Set the device for the current process. This is crucial.
899
+ torch.cuda.set_device(local_rank)
900
+ torch.manual_seed(42)
901
+
902
+ # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
903
+ # Pass the distributed info to the Trainer
904
+ trainer = Trainer(
905
+ config_path="/root/Trisf/config_edge_1024_error_8enc_8dec_woself_finetune_128to256.yaml",
906
+ rank=rank,
907
+ world_size=world_size,
908
+ local_rank=local_rank
909
+ )
910
+ trainer.train()
911
+
912
+ # Clean up the process group
913
+ dist.destroy_process_group()
914
+
915
+
916
+ if __name__ == '__main__':
917
+ main()
train_slat_vae_512_128to512_pointnet_head.py ADDED
@@ -0,0 +1,1090 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import os
3
+ # os.environ['ATTN_BACKEND'] = 'xformers'
4
+ import yaml
5
+ import torch
6
+ import time
7
+ from datetime import datetime
8
+ from torch.utils.data import DataLoader
9
+ from functools import partial
10
+ from triposf.modules.sparse.basic import SparseTensor
11
+ import torch.nn.functional as F
12
+ from torch.optim import AdamW
13
+ from torch.cuda.amp import GradScaler, autocast
14
+
15
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
16
+ from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
17
+ from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
18
+
19
+ from utils import load_pretrained_woself, AdaptiveFocalLoss, fast_isin, AsymmetricFocalLoss, DiceLoss
20
+
21
+ import torch.distributed as dist
22
+ from torch.nn.parallel import DistributedDataParallel as DDP
23
+ from torch.utils.data.distributed import DistributedSampler
24
+
25
+ from transformers import get_cosine_schedule_with_warmup
26
+ import math
27
+
28
+ import numpy as np
29
+ import open3d as o3d
30
+
31
+ def export_sampled_edges(coords, u, v, labels, edge_voxels=None, step_idx=0, save_dir="debug_viz"):
32
+ """
33
+ 导出顶点、采样的边以及背景边缘体素用于可视化 (PLY格式)
34
+ """
35
+ os.makedirs(save_dir, exist_ok=True)
36
+
37
+ # 转为 CPU numpy
38
+ coords_np = coords.detach().cpu().numpy() # [N, 4] (b, x, y, z)
39
+ u_np = u.detach().cpu().numpy()
40
+ v_np = v.detach().cpu().numpy()
41
+ labels_np = labels.detach().cpu().numpy().flatten()
42
+
43
+ edge_voxels_np = None
44
+ if edge_voxels is not None:
45
+ edge_voxels_np = edge_voxels.detach().cpu().numpy() # [M, 4]
46
+
47
+ # 按 Batch 处理 (因为 visualization 最好单个物体看)
48
+ batch_ids = coords_np[:, 0]
49
+ unique_batches = np.unique(batch_ids)
50
+
51
+ for b_id in unique_batches:
52
+ # 1. 筛选当前 Batch 的顶点
53
+ mask_b = (batch_ids == b_id)
54
+ # 获取局部坐标 (x,y,z)
55
+ curr_verts = coords_np[mask_b, 1:4]
56
+
57
+ # 建立 全局索引 -> 局部索引 的映射
58
+ # u 和 v 是基于整个 batch 的全局索引,需要转换
59
+ global_indices = np.where(mask_b)[0]
60
+ # 创建一个大的映射数组 (假设索引范围不超过 max_len)
61
+ max_idx = np.max(global_indices) + 1
62
+ global_to_local = np.full(max_idx, -1)
63
+ global_to_local[global_indices] = np.arange(len(global_indices))
64
+
65
+ # 2. 筛选当前 Batch 相关的边
66
+ # 边的两个端点都必须属于当前 batch
67
+ # 只要检查 u 是否在当前 batch 范围内即可 (假设跨 batch 不连线)
68
+ batch_edge_mask = np.isin(u_np, global_indices)
69
+
70
+ curr_u_global = u_np[batch_edge_mask]
71
+ curr_v_global = v_np[batch_edge_mask]
72
+ curr_labels = labels_np[batch_edge_mask]
73
+
74
+ # 转为局部索引 (用于构建 mesh)
75
+ curr_u = global_to_local[curr_u_global]
76
+ curr_v = global_to_local[curr_v_global]
77
+
78
+ # 3. 筛选当前 Batch 的 Edge Voxels
79
+ curr_edge_voxels = None
80
+ if edge_voxels_np is not None:
81
+ mask_ev = (edge_voxels_np[:, 0] == b_id)
82
+ curr_edge_voxels = edge_voxels_np[mask_ev, 1:4]
83
+
84
+ # 4. 写入 PLY 文件
85
+ filename = os.path.join(save_dir, f"step_{step_idx}_batch_{int(b_id)}.ply")
86
+ write_ply_with_edges_and_voxels(
87
+ filename,
88
+ curr_verts,
89
+ curr_u,
90
+ curr_v,
91
+ curr_labels,
92
+ curr_edge_voxels
93
+ )
94
+ print(f"Saved visualization to {filename}")
95
+
96
+ def write_ply_with_edges_and_voxels(filename, verts, u, v, labels, edge_voxels=None):
97
+ """
98
+ 手动写入 PLY 文件,包含顶点、边和背景体素点。
99
+ 为了区分,我们将它们合并到一个文件中,但使用颜色区分。
100
+ """
101
+ # ------------------------------------------------
102
+ # 准备数据
103
+ # ------------------------------------------------
104
+
105
+ # 1. 节点 (Vertices) -> 设为 青色
106
+ num_verts = len(verts)
107
+ vertex_colors = np.tile([0, 255, 255], (num_verts, 1))
108
+
109
+ # 2. 边缘体素 (Edge Voxels) -> 设为 灰色 (作为背景)
110
+ if edge_voxels is not None:
111
+ num_ev = len(edge_voxels)
112
+ # 为了放在同一个 vertex buffer,我们把 edge voxels 追加到 verts 后面
113
+ all_points = np.vstack([verts, edge_voxels])
114
+ ev_colors = np.tile([180, 180, 180], (num_ev, 1))
115
+ all_colors = np.vstack([vertex_colors, ev_colors])
116
+ else:
117
+ all_points = verts
118
+ all_colors = vertex_colors
119
+ num_ev = 0
120
+
121
+ num_total_points = len(all_points)
122
+
123
+ # 3. 边 (Edges)
124
+ # PLY 格式的 edge 索引是基于当前点列表的
125
+ # u, v 已经是基于 verts 的局部索引了,不需要偏移 (因为 verts 在最前面)
126
+
127
+ # 筛选:我们通常只想看 Positive 的边 (labels==1),或者用颜色区分
128
+ # 这里全部写入,用颜色区分
129
+ # 红色 = 负样本 (预测连接但实际上没连/负采样)
130
+ # 绿色 = 正样本 (GT连接)
131
+
132
+ edge_list = np.stack([u, v], axis=1)
133
+ num_edges = len(edge_list)
134
+
135
+ edge_colors = np.zeros((num_edges, 3), dtype=int)
136
+ edge_colors[labels == 1] = [0, 255, 0] # Green for Positive
137
+ edge_colors[labels == 0] = [255, 0, 0] # Red for Negative
138
+
139
+ # ------------------------------------------------
140
+ # 写入文件 Header
141
+ # ------------------------------------------------
142
+ with open(filename, 'w') as f:
143
+ f.write("ply\n")
144
+ f.write("format ascii 1.0\n")
145
+
146
+ # Vertex Element (包含 Nodes 和 EdgeVoxels)
147
+ f.write(f"element vertex {num_total_points}\n")
148
+ f.write("property float x\n")
149
+ f.write("property float y\n")
150
+ f.write("property float z\n")
151
+ f.write("property uchar red\n")
152
+ f.write("property uchar green\n")
153
+ f.write("property uchar blue\n")
154
+
155
+ # Edge Element (包含采样的连接)
156
+ f.write(f"element edge {num_edges}\n")
157
+ f.write("property list uchar int vertex_indices\n")
158
+ f.write("property uchar red\n")
159
+ f.write("property uchar green\n")
160
+ f.write("property uchar blue\n")
161
+
162
+ f.write("end_header\n")
163
+
164
+ # ------------------------------------------------
165
+ # 写入数据 Body
166
+ # ------------------------------------------------
167
+
168
+ # 1. Write Points (Vertices + Edge Voxels)
169
+ for i in range(num_total_points):
170
+ p = all_points[i]
171
+ c = all_colors[i]
172
+ f.write(f"{p[0]:.4f} {p[1]:.4f} {p[2]:.4f} {int(c[0])} {int(c[1])} {int(c[2])}\n")
173
+
174
+ # 2. Write Edges
175
+ for i in range(num_edges):
176
+ e = edge_list[i]
177
+ c = edge_colors[i]
178
+ f.write(f"2 {int(e[0])} {int(e[1])} {int(c[0])} {int(c[1])} {int(c[2])}\n")
179
+
180
+
181
+ def flatten_coords_4d(coords_4d: torch.Tensor):
182
+ coords_4d_long = coords_4d.long()
183
+
184
+ base_x = 512
185
+ base_y = 512 * 512
186
+ base_z = 512 * 512 * 512
187
+
188
+ flat_coords = coords_4d_long[:, 0] * base_z + \
189
+ coords_4d_long[:, 1] * base_y + \
190
+ coords_4d_long[:, 2] * base_x + \
191
+ coords_4d_long[:, 3]
192
+ return flat_coords
193
+
194
+ def downsample_voxels(
195
+ voxels: torch.Tensor,
196
+ input_resolution: int,
197
+ output_resolution: int
198
+ ) -> torch.Tensor:
199
+ if input_resolution % output_resolution != 0:
200
+ raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
201
+ f"by output_resolution ({output_resolution}).")
202
+
203
+ factor = input_resolution // output_resolution
204
+
205
+ downsampled_voxels = voxels.clone().to(torch.long)
206
+
207
+ downsampled_voxels[:, 1:] //= factor
208
+
209
+ unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
210
+ return unique_downsampled_voxels
211
+
212
+ class Trainer:
213
+ def __init__(self, config_path, rank, world_size, local_rank):
214
+ self.rank = rank
215
+ self.world_size = world_size
216
+ self.local_rank = local_rank
217
+ self.is_master = self.rank == 0
218
+
219
+ self.load_config(config_path)
220
+ self.accum_steps = max(1, 4 // self.cfg['training']['batch_size'])
221
+
222
+ self.config_hash = self.save_config_with_hash()
223
+ self.init_device()
224
+ self.init_dirs()
225
+ self.init_components()
226
+ self.init_training()
227
+
228
+ self.train_loss_history = []
229
+ self.eval_loss_history = []
230
+ self.best_eval_loss = float('inf')
231
+
232
+ def save_config_with_hash(self):
233
+ import hashlib
234
+
235
+ # Serialize config to hash
236
+ config_str = yaml.dump(self.cfg)
237
+ config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
238
+
239
+ # Prepare all flags as string for formatting
240
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
241
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
242
+
243
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
244
+
245
+ # Format save_dir with all placeholders
246
+ self.cfg['experiment']['save_dir'] = self.cfg['experiment']['save_dir'].format(
247
+ dataset_name=dataset_name,
248
+ config_hash=config_hash,
249
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
250
+ multires=self.cfg['model']['multires'],
251
+ add_block_embed=add_block_embed_flag,
252
+ using_attn=using_attn_flag,
253
+ batch_size=self.cfg['training']['batch_size'],
254
+ )
255
+
256
+ if self.is_master:
257
+ os.makedirs(self.save_dir, exist_ok=True)
258
+ config_path = os.path.join(self.save_dir, "config.yaml")
259
+ with open(config_path, 'w') as f:
260
+ yaml.dump(self.cfg, f)
261
+
262
+ dist.barrier()
263
+ return config_hash
264
+
265
+ def save_checkpoint(self, epoch, avg_loss, batch_idx):
266
+ if not self.is_master:
267
+ return
268
+ checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}_batch{batch_idx}_loss{avg_loss:.4f}.pt")
269
+ config_path = os.path.join(self.save_dir, "config.yaml")
270
+
271
+ torch.save({
272
+ 'voxel_encoder': self.voxel_encoder.module.state_dict(),
273
+ 'vae': self.vae.module.state_dict(),
274
+ 'connection_head': self.connection_head.module.state_dict(),
275
+ 'epoch': epoch,
276
+ 'loss': avg_loss,
277
+ 'config': self.cfg
278
+ }, checkpoint_path)
279
+
280
+ def quoted_presenter(dumper, data):
281
+ return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"')
282
+
283
+ yaml.add_representer(str, quoted_presenter)
284
+
285
+ with open(config_path, 'w') as f:
286
+ yaml.dump(self.cfg, f)
287
+
288
+ def load_config(self, config_path):
289
+ with open(config_path) as f:
290
+ self.cfg = yaml.safe_load(f)
291
+
292
+ # Extract and convert flags for formatting
293
+ add_block_embed_flag = "True" if self.cfg['model']['add_block_embed'] else "False"
294
+ using_attn_flag = "True" if self.cfg['model']['using_attn'] else "False"
295
+
296
+ dataset_name = os.path.basename(self.cfg['dataset']['path'])
297
+
298
+ self.save_dir = self.cfg['experiment']['save_dir'].format(
299
+ dataset_name=dataset_name,
300
+ n_train_samples=self.cfg['dataset']['n_train_samples'],
301
+ multires=self.cfg['model']['multires'],
302
+ add_block_embed=add_block_embed_flag,
303
+ using_attn=using_attn_flag,
304
+ batch_size=self.cfg['training']['batch_size'],
305
+ )
306
+
307
+ if self.is_master:
308
+ os.makedirs(self.save_dir, exist_ok=True)
309
+ dist.barrier()
310
+
311
+
312
+ def init_device(self):
313
+ self.device = torch.device(f"cuda:{self.local_rank}")
314
+
315
+ def init_dirs(self):
316
+ self.log_file = os.path.join(self.save_dir, f"training_log_{self.cfg['training']['lr']}.txt")
317
+ if self.is_master:
318
+ with open(self.log_file, "a") as f:
319
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
320
+ f.write(f"[{current_time}] Config loaded for distributed training with world size {self.world_size}\n")
321
+
322
+ def init_components(self):
323
+ self.dataset = VoxelVertexDataset_edge(
324
+ root_dir=self.cfg['dataset']['path'],
325
+ base_resolution=self.cfg['dataset']['base_resolution'],
326
+ min_resolution=self.cfg['dataset']['min_resolution'],
327
+ cache_dir=self.cfg['dataset']['cache_dir'],
328
+ renders_dir=self.cfg['dataset']['renders_dir'],
329
+
330
+ filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
331
+ cache_filter_path=self.cfg['dataset']['cache_filter_path'],
332
+
333
+ active_voxel_res=128,
334
+ pc_sample_number=819200,
335
+
336
+ sample_type=self.cfg['dataset']['sample_type'],
337
+ )
338
+
339
+ self.sampler = DistributedSampler(
340
+ self.dataset,
341
+ num_replicas=self.world_size,
342
+ rank=self.rank,
343
+ shuffle=True,
344
+ )
345
+
346
+ self.dataloader = DataLoader(
347
+ self.dataset,
348
+ batch_size=self.cfg['training']['batch_size'],
349
+ shuffle=False,
350
+ collate_fn=partial(collate_fn_pointnet,),
351
+ num_workers=self.cfg['training']['num_workers'],
352
+ pin_memory=True,
353
+ sampler=self.sampler,
354
+ # prefetch_factor=4,
355
+ persistent_workers=True,
356
+ )
357
+
358
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
359
+ in_channels=15,
360
+ hidden_dim=256,
361
+ out_channels=1024,
362
+ scatter_type='mean',
363
+ n_blocks=5,
364
+ resolution=128,
365
+ add_label=False,
366
+ ).to(self.device)
367
+
368
+ self.connection_head = ConnectionHead(
369
+ channels=128 * 2,
370
+ out_channels=1,
371
+ mlp_ratio=4,
372
+ ).to(self.device)
373
+
374
+ # ablation 3: voxelvae_1volume, have tested
375
+ self.vae = VoxelVAE(
376
+ in_channels=self.cfg['model']['in_channels'],
377
+ latent_dim=self.cfg['model']['latent_dim'],
378
+ encoder_blocks=self.cfg['model']['encoder_blocks'],
379
+ decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
380
+ decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
381
+ num_heads=8,
382
+ num_head_channels=64,
383
+ mlp_ratio=4.0,
384
+ attn_mode="swin",
385
+ window_size=8,
386
+ pe_mode="ape",
387
+ use_fp16=False,
388
+ use_checkpoint=True,
389
+ qk_rms_norm=False,
390
+ using_subdivide=True,
391
+ using_attn=self.cfg['model']['using_attn'],
392
+ attn_first=self.cfg['model'].get('attn_first', True),
393
+ pred_direction=self.cfg['model'].get('pred_direction', False),
394
+ ).to(self.device)
395
+
396
+ if self.cfg['training']['from_pretrained']:
397
+ load_pretrained_woself(
398
+ checkpoint_path=self.cfg['training']['checkpoint_path'],
399
+ voxel_encoder=self.voxel_encoder,
400
+ vae=self.vae,
401
+ connection_head=self.connection_head,
402
+ optimizer=None,
403
+ )
404
+
405
+ self.voxel_encoder = DDP(self.voxel_encoder, device_ids=[self.local_rank], find_unused_parameters=False)
406
+ self.connection_head = DDP(self.connection_head, device_ids=[self.local_rank], find_unused_parameters=False)
407
+ self.vae = DDP(self.vae, device_ids=[self.local_rank], find_unused_parameters=False)
408
+
409
+ def init_training(self):
410
+ self.optimizer = AdamW(
411
+ list(self.vae.module.parameters()) +
412
+ list(self.voxel_encoder.module.parameters()) +
413
+ list(self.connection_head.module.parameters()),
414
+ lr=self.cfg['training']['lr'],
415
+ weight_decay=0.01,
416
+ )
417
+
418
+ num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.accum_steps)
419
+ max_epochs = self.cfg['training']['max_epochs']
420
+ num_training_steps = max_epochs * num_update_steps_per_epoch
421
+
422
+ num_warmup_steps = 200
423
+
424
+ self.scheduler = get_cosine_schedule_with_warmup(
425
+ self.optimizer,
426
+ num_warmup_steps=num_warmup_steps,
427
+ num_training_steps=num_training_steps
428
+ )
429
+
430
+ self.focal_loss = AdaptiveFocalLoss(gamma=2.0, max_alpha=10.0).to(self.device)
431
+ self.mse_loss = nn.MSELoss(reduction='mean').to(self.device)
432
+ self.asyloss = AsymmetricFocalLoss(
433
+ gamma_pos=0.0,
434
+ gamma_neg=4.0,
435
+ clip=0.05,
436
+ )
437
+
438
+ self.bce_loss = torch.nn.BCEWithLogitsLoss()
439
+
440
+ self.dice_loss = DiceLoss()
441
+ self.scaler = GradScaler()
442
+
443
+ def train_step(self, batch):
444
+ """Modified training step that handles vertex and edge voxels separately after initial prediction."""
445
+ # 1. Retrieve data from batch
446
+ combined_voxels_512 = batch['combined_voxels_512'].to(self.device)
447
+ combined_voxel_labels_512 = batch['combined_voxel_labels_512'].to(self.device)
448
+ gt_vertex_voxels_512 = batch['gt_vertex_voxels_512'].to(self.device)
449
+
450
+ gt_edges = batch['gt_vertex_edge_indices_512'].to(self.device)
451
+
452
+ edge_mask = (combined_voxel_labels_512 == 1)
453
+
454
+
455
+ vtx_128 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=128)
456
+ vtx_256 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=256)
457
+ vtx_512 = gt_vertex_voxels_512
458
+
459
+ edge_128 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=128)
460
+ edge_256 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=256)
461
+ edge_512 = combined_voxels_512
462
+
463
+ active_coords = batch['active_voxels_128'].to(self.device)
464
+ point_cloud = batch['point_cloud_128'].to(self.device)
465
+
466
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
467
+ active_voxel_feats = self.voxel_encoder(
468
+ p=point_cloud,
469
+ sparse_coords=active_coords,
470
+ res=128,
471
+ bbox_size=(-0.5, 0.5),
472
+ )
473
+
474
+ sparse_input = SparseTensor(
475
+ feats=active_voxel_feats,
476
+ coords=active_coords.int()
477
+ )
478
+
479
+ gt_edge_voxels_list = [
480
+ edge_128,
481
+ edge_256,
482
+ edge_512,
483
+ ]
484
+
485
+ gt_vertex_voxels_list = [
486
+ vtx_128,
487
+ vtx_256,
488
+ vtx_512,
489
+ ]
490
+
491
+ results, posterior, latent_128 = self.vae(
492
+ sparse_input,
493
+ gt_vertex_voxels_list=gt_vertex_voxels_list,
494
+ gt_edge_voxels_list=gt_edge_voxels_list,
495
+ training=True,
496
+ sample_ratio=0.,
497
+ )
498
+
499
+ # print("results[-1]['edge']['coords_4d'][1827:1830]", results[-1]['edge']['coords_4d'][1827:1830])
500
+ total_loss = 0.
501
+ prune_loss_total = 0.
502
+ vertex_loss_total = 0.
503
+ edge_loss_total=0.
504
+
505
+ with autocast(dtype=torch.bfloat16):
506
+ initial_result = results[0]
507
+ vertex_mask = initial_result['vertex_mask']
508
+ vtx_logits = initial_result['vtx_feats']
509
+ vertex_loss = self.asyloss(vtx_logits.squeeze(-1), vertex_mask.float())
510
+
511
+ edge_mask = initial_result['edge_mask']
512
+ edge_logits = initial_result['edge_feats']
513
+ edge_loss = self.asyloss(edge_logits.squeeze(-1), edge_mask.float())
514
+
515
+ vertex_loss_total += vertex_loss
516
+ edge_loss_total += edge_loss
517
+
518
+ total_loss += vertex_loss
519
+ total_loss += edge_loss
520
+
521
+ # Process each level's results
522
+ for idx, res_dict in enumerate(results[1:], start=1):
523
+ # Vertex branch losses
524
+ vertex_pred_coords = res_dict['vertex']['occ_coords']
525
+ vertex_occ_probs = res_dict['vertex']['occ_probs']
526
+ vertex_gt_coords = res_dict['vertex']['coords']
527
+
528
+ vertex_labels = fast_isin(vertex_pred_coords, vertex_gt_coords, resolution=512).float()
529
+ # print('vertex_labels.sum()', vertex_labels.sum(), idx)
530
+ vertex_logits = vertex_occ_probs.squeeze()
531
+
532
+ # if vertex_labels.sum() > 0 and vertex_labels.sum() < len(vertex_labels):
533
+ vertex_prune_loss = self.focal_loss(vertex_logits, vertex_labels)
534
+ # vertex_prune_loss = self.dice_loss(vertex_logits, vertex_labels)
535
+
536
+ # dilation 1: bce loss
537
+ # vertex_prune_loss = self.bce_loss(vertex_logits, vertex_labels,)
538
+
539
+ prune_loss_total += vertex_prune_loss
540
+ total_loss += vertex_prune_loss
541
+
542
+
543
+ # Edge branch losses
544
+ edge_pred_coords = res_dict['edge']['occ_coords']
545
+ edge_occ_probs = res_dict['edge']['occ_probs']
546
+ edge_gt_coords = res_dict['edge']['coords']
547
+
548
+ edge_labels = fast_isin(edge_pred_coords, edge_gt_coords, resolution=512).float()
549
+ # print('edge_labels.sum()', edge_labels.sum(), idx)
550
+ edge_logits = edge_occ_probs.squeeze()
551
+ # if edge_labels.sum() > 0 and edge_labels.sum() < len(edge_labels):
552
+ edge_prune_loss = self.focal_loss(edge_logits, edge_labels)
553
+
554
+ # dilation 1: bce loss
555
+ # edge_prune_loss = self.bce_loss(edge_logits, edge_labels,)
556
+
557
+ prune_loss_total += edge_prune_loss
558
+ total_loss += edge_prune_loss
559
+
560
+ if idx == 2:
561
+ # pred_coords = res_dict['edge']['coords_4d'] # [N,4] (b,x,y,z)
562
+ # pred_feats = res_dict['edge']['predicted_offset_feats'] # [N,C]
563
+
564
+ # gt_coords = gt_edge_voxels_512.to(pred_coords.device) # [M,4]
565
+ # gt_feats = gt_edge_errors_512[:, 1:].to(pred_coords.device) # [M,C]
566
+
567
+ # pred_keys = flatten_coords_4d(pred_coords)
568
+ # gt_keys = flatten_coords_4d(gt_coords)
569
+
570
+ # sorted_pred_keys, pred_order = torch.sort(pred_keys)
571
+ # pred_coords_sorted = pred_coords[pred_order]
572
+ # pred_feats_sorted = pred_feats[pred_order]
573
+
574
+ # sorted_gt_keys, gt_order = torch.sort(gt_keys)
575
+ # gt_coords_sorted = gt_coords[gt_order]
576
+ # gt_feats_sorted = gt_feats[gt_order]
577
+
578
+ # pos = torch.searchsorted(sorted_gt_keys, sorted_pred_keys)
579
+ # valid_mask = (pos < len(sorted_gt_keys)) & (sorted_gt_keys[pos] == sorted_pred_keys)
580
+
581
+ # if valid_mask.any():
582
+ # # print('valid_mask.sum()', valid_mask.sum())
583
+ # matched_pred_feats = pred_feats_sorted[valid_mask]
584
+ # matched_gt_feats = gt_feats_sorted[pos[valid_mask]]
585
+ # mse_loss_feats = self.mse_loss(matched_pred_feats, matched_gt_feats * 2)
586
+ # total_loss += mse_loss_feats * 0.
587
+
588
+ # if self.cfg['model'].get('pred_direction', False):
589
+ # pred_dirs = res_dict['edge']['predicted_direction_feats']
590
+ # dir_gt_device = dir_gt.to(pred_coords.device)
591
+
592
+ # pred_dirs_sorted = pred_dirs[pred_order]
593
+ # dir_gt_sorted = dir_gt_device[gt_order]
594
+
595
+ # matched_pred_dirs = pred_dirs_sorted[valid_mask]
596
+ # matched_gt_dirs = dir_gt_sorted[pos[valid_mask]]
597
+
598
+ # mse_loss_dirs = self.mse_loss(matched_pred_dirs, matched_gt_dirs)
599
+ # total_loss += mse_loss_dirs * 0.
600
+ # else:
601
+ # mse_loss_feats = torch.tensor(0., device=pred_coords.device)
602
+ # if self.cfg['model'].get('pred_direction', False):
603
+ # mse_loss_dirs = torch.tensor(0., device=pred_coords.device)
604
+
605
+ mse_loss_dirs = torch.tensor(0., device=self.device)
606
+ mse_loss_feats = torch.tensor(0., device=self.device)
607
+
608
+ # --- Vertex Branch (Connection Loss 核心) ---
609
+ vtx_pred_coords = res_dict['vertex']['coords_4d'] # [N, 4]
610
+ vtx_pred_feats = res_dict['vertex']['feats'] # [N, C]
611
+
612
+ # 1.1 排序 (既用于��配 GT,也用于快速寻找空间邻居)
613
+ vtx_pred_keys = flatten_coords_4d(vtx_pred_coords)
614
+ vtx_pred_keys_sorted, vtx_pred_order = torch.sort(vtx_pred_keys)
615
+
616
+ # 1.2 匹配 GT
617
+ vtx_gt_keys = flatten_coords_4d(gt_vertex_voxels_512.to(self.device))
618
+ vtx_pos = torch.searchsorted(vtx_pred_keys_sorted, vtx_gt_keys)
619
+ vtx_pos = vtx_pos.clamp(max=len(vtx_pred_keys_sorted) - 1)
620
+ vtx_match_mask = (vtx_pred_keys_sorted[vtx_pos] == vtx_gt_keys)
621
+
622
+ gt_to_pred_mapping = torch.full((len(vtx_gt_keys),), -1, device=self.device, dtype=torch.long)
623
+ matched_pred_indices = vtx_pred_order[vtx_pos[vtx_match_mask]]
624
+ gt_to_pred_mapping[vtx_match_mask] = matched_pred_indices
625
+
626
+ # ====================================================
627
+ # 2. 构建核心数据:正样本 Hash 集合
628
+ # ====================================================
629
+ # 这里的 pos_u/pos_v 仅用于构建 "什么是真连接" 的查询表
630
+ u_gt, v_gt = gt_edges[:, 0], gt_edges[:, 1]
631
+ u_pred = gt_to_pred_mapping[u_gt]
632
+ v_pred = gt_to_pred_mapping[v_gt]
633
+
634
+ valid_edge_mask = (u_pred != -1) & (v_pred != -1)
635
+ real_pos_u = u_pred[valid_edge_mask]
636
+ real_pos_v = v_pred[valid_edge_mask]
637
+
638
+ num_real_pos = real_pos_u.shape[0]
639
+ num_total_nodes = vtx_pred_coords.shape[0]
640
+
641
+ # if num_real_pos > 0:
642
+ # # 2. 构建候选样本 (Candidate Generation)
643
+ # # ====================================================
644
+ # cand_u_list = []
645
+ # cand_v_list = []
646
+
647
+ # batch_ids = vtx_pred_coords[:, 0]
648
+ # unique_batches = torch.unique(batch_ids)
649
+
650
+ # RADIUS = 32
651
+ # MAX_PTS_FOR_DIST = 12000
652
+ # K_RANDOM = 32
653
+
654
+ # for b_id in unique_batches:
655
+ # mask_b = (batch_ids == b_id)
656
+ # indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
657
+ # coords_b = vtx_pred_coords[mask_b, 1:4].float() # (x,y,z)
658
+ # num_b = coords_b.shape[0]
659
+
660
+ # if num_b < 2: continue
661
+
662
+ # # --- A. Radius Graph (Hard Negatives) ---
663
+ # if num_b <= MAX_PTS_FOR_DIST:
664
+ # # 计算距离矩阵 [M, M]
665
+ # # 注意:autocast 下 float16 的 cdist 可能精度不够,建议转 float32
666
+ # dist_mat = torch.cdist(coords_b.float(), coords_b.float())
667
+
668
+ # # 找到距离小于 Radius 的点对 (排除自环)
669
+ # adj_mat = (dist_mat < RADIUS) & (dist_mat > 1e-6)
670
+
671
+ # # 提取索引 (local indices in batch)
672
+ # src_local, dst_local = torch.nonzero(adj_mat, as_tuple=True)
673
+
674
+ # # 映射回全局索引
675
+ # cand_u_list.append(indices_b[src_local])
676
+ # cand_v_list.append(indices_b[dst_local])
677
+ # else:
678
+ # # 如果点太多,显存不够,退化为随机局部采样或跳过
679
+ # # 这里简单处理:跳过 Radius Graph,依赖 Random
680
+ # pass
681
+
682
+ # # --- B. Random Sampling (Easy Negatives) ---
683
+ # # 随机生成 num_b * K 对
684
+ # n_rand = num_b * K_RANDOM
685
+ # rand_src_local = torch.randint(0, num_b, (n_rand,), device=self.device)
686
+ # rand_dst_local = torch.randint(0, num_b, (n_rand,), device=self.device)
687
+
688
+ # # 映射回全局索引
689
+ # cand_u_list.append(indices_b[rand_src_local])
690
+ # cand_v_list.append(indices_b[rand_dst_local])
691
+
692
+ # # 合并所有来源 (GT + Radius + Random)
693
+ # # 注意:我们把 real_pos 也加进来,确保正样本一定在列表里
694
+ # all_u = torch.cat([real_pos_u] + cand_u_list)
695
+ # all_v = torch.cat([real_pos_v] + cand_v_list)
696
+
697
+ # # 3. 去重与 Labeling (Deduplication & Labeling)
698
+ # # ====================================================
699
+ # # 构造无向边 Hash: min * N + max
700
+ # # 确保 MAX_NODES 足够大,比如 1000000 或 num_total_nodes
701
+ # HASH_BASE = num_total_nodes + 100
702
+
703
+ # p_min = torch.min(all_u, all_v)
704
+ # p_max = torch.max(all_u, all_v)
705
+
706
+ # # 过滤掉自环 (u==v)
707
+ # valid_pair = (p_min != p_max)
708
+ # p_min = p_min[valid_pair]
709
+ # p_max = p_max[valid_pair]
710
+
711
+ # all_hashes = p_min.long() * HASH_BASE + p_max.long()
712
+
713
+ # # --- 核心:去重 ---
714
+ # unique_hashes = torch.unique(all_hashes)
715
+
716
+ # # 解码回 u, v
717
+ # final_u = unique_hashes // HASH_BASE
718
+ # final_v = unique_hashes % HASH_BASE
719
+
720
+ # # --- Labeling ---
721
+ # # 构建 GT 的 Hash 表用于查询
722
+ # gt_min = torch.min(real_pos_u, real_pos_v)
723
+ # gt_max = torch.max(real_pos_u, real_pos_v)
724
+ # gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
725
+ # gt_hashes = torch.unique(gt_hashes) # GT 也去重一下保险
726
+ # gt_hashes_sorted, _ = torch.sort(gt_hashes)
727
+
728
+ # # 查询 unique_hashes 是否在 gt_hashes 中
729
+ # # 使用 searchsorted
730
+ # idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
731
+ # idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
732
+ # is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
733
+
734
+ # targets = is_connected.float().unsqueeze(-1) # [N_pairs, 1]
735
+
736
+ # # 4. 前向传播与 Loss
737
+ # # ====================================================
738
+ # feat_u = vtx_pred_feats[final_u]
739
+ # feat_v = vtx_pred_feats[final_v]
740
+
741
+ # # 对称特征融合
742
+ # feat_uv = torch.cat([feat_u, feat_v], dim=-1)
743
+ # feat_vu = torch.cat([feat_v, feat_u], dim=-1)
744
+
745
+ # logits_uv = self.connection_head(feat_uv)
746
+ # logits_vu = self.connection_head(feat_vu)
747
+ # logits = logits_uv + logits_vu
748
+
749
+ # # print('targets.sum()', targets.sum())
750
+ # # print('targets.shape', targets.shape)
751
+
752
+ # # viz_edge_voxels = combined_voxels_512
753
+
754
+ # # export_sampled_edges(
755
+ # # coords=vtx_pred_coords, # [N, 4] (Batch, X, Y, Z) - 顶点
756
+ # # u=final_u, # [E] - 连线起点索引
757
+ # # v=final_v, # [E] - 连线终点索引
758
+ # # labels=targets, # [E, 1] - 连线标签 (1=Pos, 0=Neg)
759
+ # # edge_voxels=viz_edge_voxels, # [M, 4] - 新增:512分辨率边缘体素
760
+ # # step_idx=0, # 或者传入当前的 step/epoch
761
+ # # save_dir="debug_viz" # 建议指定保存路径
762
+ # # )
763
+ # # exit()
764
+
765
+ # # asyloss Loss
766
+ # connection_loss = self.asyloss(logits, targets)
767
+ # total_loss += connection_loss
768
+
769
+ # else:
770
+ # connection_loss = torch.tensor(0., device=self.device)
771
+
772
+ if num_real_pos > 0:
773
+ # ====================================================
774
+ # 2. 构建候选样本 (Candidate Generation) - KNN + Random 版
775
+ # ====================================================
776
+ cand_u_list = []
777
+ cand_v_list = []
778
+
779
+ batch_ids = vtx_pred_coords[:, 0]
780
+ unique_batches = torch.unique(batch_ids)
781
+
782
+ # === 配置 ===
783
+ K_KNN = 64
784
+ K_RANDOM = 32
785
+ # 这样每个点大约产生 32 条边,总显存 = N * 32 * MLP大小,非常稳定
786
+
787
+ # 安全阈值:如果点数太多,cdist矩阵本身会爆,需要限制
788
+ MAX_PTS_FOR_KNN = 12000
789
+
790
+ for b_id in unique_batches:
791
+ mask_b = (batch_ids == b_id)
792
+ indices_b = torch.nonzero(mask_b).squeeze(-1) # Global indices
793
+ coords_b = vtx_pred_coords[mask_b, 1:4].float()
794
+ num_b = coords_b.shape[0]
795
+
796
+ if num_b < 2: continue
797
+
798
+ # --- A. KNN Graph (距离最近的 K 个) ---
799
+ # 只有当点数在可接受范围内时才算矩阵,否则只用随机
800
+ if num_b <= MAX_PTS_FOR_KNN:
801
+ # 1. 计算距离矩阵 [N, N]
802
+ # 注意: 12000个点产生的矩阵约 576MB,非常安全
803
+ dist_mat = torch.cdist(coords_b, coords_b)
804
+
805
+ # 2. 取最近的 K+1 个 (包含自己)
806
+ # largest=False 表示取最小距离
807
+ k_val = min(K_KNN + 1, num_b)
808
+ _, knn_indices_local = torch.topk(dist_mat, k=k_val, dim=1, largest=False)
809
+
810
+ # 3. 去掉自己 (第0个通常是距离为0的自己)
811
+ knn_indices_local = knn_indices_local[:, 1:]
812
+
813
+ # 4. 构建边索引
814
+ # repeat_interleave 生成源点: [0,0,0, 1,1,1...]
815
+ src_local = torch.arange(num_b, device=self.device).repeat_interleave(knn_indices_local.shape[1])
816
+ dst_local = knn_indices_local.flatten()
817
+
818
+ cand_u_list.append(indices_b[src_local])
819
+ cand_v_list.append(indices_b[dst_local])
820
+
821
+ # --- B. Random Sampling (随机 K 个) ---
822
+ # 不管点多点少,都可以做随机
823
+ n_rand = num_b * K_RANDOM
824
+ if n_rand > 0:
825
+ rand_src = torch.randint(0, num_b, (n_rand,), device=self.device)
826
+ rand_dst = torch.randint(0, num_b, (n_rand,), device=self.device)
827
+
828
+ cand_u_list.append(indices_b[rand_src])
829
+ cand_v_list.append(indices_b[rand_dst])
830
+
831
+ if len(cand_u_list) > 0:
832
+ all_u = torch.cat([real_pos_u] + cand_u_list)
833
+ all_v = torch.cat([real_pos_v] + cand_v_list)
834
+ else:
835
+ all_u = real_pos_u
836
+ all_v = real_pos_v
837
+
838
+ # ====================================================
839
+ # 3. 去重与 Labeling (Logic 保持不变)
840
+ # ====================================================
841
+ HASH_BASE = num_total_nodes + 100
842
+
843
+ p_min = torch.min(all_u, all_v)
844
+ p_max = torch.max(all_u, all_v)
845
+ valid_pair = (p_min != p_max)
846
+ p_min = p_min[valid_pair]
847
+ p_max = p_max[valid_pair]
848
+
849
+ all_hashes = p_min.long() * HASH_BASE + p_max.long()
850
+
851
+ # 去重:因为 KNN 和 Random 可能会重复,或者和 GT 重复
852
+ unique_hashes = torch.unique(all_hashes)
853
+
854
+ # 【注意】这里不需要再做 Max Limit 截断了
855
+ # 因为边数严格受控于 (N * K),不会出现数量级爆炸的情况。
856
+
857
+ # 解码回 u, v
858
+ final_u = unique_hashes // HASH_BASE
859
+ final_v = unique_hashes % HASH_BASE
860
+
861
+ # --- Labeling ---
862
+ gt_min = torch.min(real_pos_u, real_pos_v)
863
+ gt_max = torch.max(real_pos_u, real_pos_v)
864
+ gt_hashes = gt_min.long() * HASH_BASE + gt_max.long()
865
+ gt_hashes = torch.unique(gt_hashes)
866
+ gt_hashes_sorted, _ = torch.sort(gt_hashes)
867
+
868
+ idx_search = torch.searchsorted(gt_hashes_sorted, unique_hashes)
869
+ idx_search = idx_search.clamp(max=len(gt_hashes_sorted) - 1)
870
+ is_connected = (gt_hashes_sorted[idx_search] == unique_hashes)
871
+
872
+ targets = is_connected.float().unsqueeze(-1)
873
+
874
+ feat_u = vtx_pred_feats[final_u]
875
+ feat_v = vtx_pred_feats[final_v]
876
+
877
+ feat_uv = torch.cat([feat_u, feat_v], dim=-1)
878
+ feat_vu = torch.cat([feat_v, feat_u], dim=-1)
879
+
880
+ logits_uv = self.connection_head(feat_uv)
881
+ logits_vu = self.connection_head(feat_vu)
882
+ logits = logits_uv + logits_vu
883
+
884
+ # viz_edge_voxels = combined_voxels_512
885
+ # export_sampled_edges(
886
+ # coords=vtx_pred_coords, # [N, 4] (Batch, X, Y, Z) - 顶点
887
+ # u=final_u, # [E] - 连线起点索引
888
+ # v=final_v, # [E] - 连线终点索引
889
+ # labels=targets, # [E, 1] - 连线标签 (1=Pos, 0=Neg)
890
+ # edge_voxels=viz_edge_voxels, # [M, 4] - 新增:512分辨率边缘体素
891
+ # step_idx=0, # 或者传入当前的 step/epoch
892
+ # save_dir="debug_viz" # 建议指定保存路径
893
+ # )
894
+ # exit()
895
+
896
+ connection_loss = self.asyloss(logits, targets)
897
+ total_loss += connection_loss
898
+
899
+ else:
900
+ connection_loss = torch.tensor(0., device=self.device)
901
+
902
+
903
+ # KL loss
904
+ kl_loss = posterior.kl(dims=(1,)).mean() * 1e-3 # 1e-3 before
905
+ total_loss += kl_loss
906
+
907
+ # Backpropagation
908
+ scaled_total_loss = total_loss / self.accum_steps
909
+ self.scaler.scale(scaled_total_loss).backward()
910
+
911
+ return {
912
+ 'total_loss': total_loss.item(),
913
+ 'kl_loss': kl_loss.item(),
914
+ 'prune_loss': prune_loss_total.item(),
915
+ 'vertex_loss': vertex_loss_total.item(),
916
+ 'edge_loss': edge_loss_total.item(),
917
+ 'offset_loss': mse_loss_feats.item(),
918
+ 'direction_loss': mse_loss_dirs.item(),
919
+ 'connection_loss': connection_loss.item(),
920
+ }
921
+
922
+
923
+ def train(self):
924
+ accum_steps = self.accum_steps
925
+ for epoch in range(self.cfg['training']['start_epoch'], self.cfg['training']['max_epochs']):
926
+ self.dataloader.sampler.set_epoch(epoch)
927
+ # Initialize metrics
928
+ metrics = {
929
+ 'total_loss': 0.0,
930
+ 'kl_loss': 0.0,
931
+ 'prune_loss': 0.0,
932
+ 'vertex_loss': 0.0,
933
+ 'edge_loss': 0.0,
934
+ 'offset_loss': 0.0,
935
+ 'direction_loss': 0.0,
936
+ 'connection_loss': 0.0,
937
+ }
938
+ num_batches = 0
939
+ self.optimizer.zero_grad(set_to_none=True)
940
+
941
+ for i, batch in enumerate(self.dataloader):
942
+ # Get all losses from train_step
943
+ if batch is None:
944
+ continue
945
+ step_losses = self.train_step(batch)
946
+
947
+ # Accumulate losses
948
+ for key in metrics:
949
+ metrics[key] += step_losses[key]
950
+ num_batches += 1
951
+
952
+ if (i + 1) % accum_steps == 0:
953
+ self.scaler.unscale_(self.optimizer)
954
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
955
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
956
+ torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=1.0)
957
+
958
+ self.scaler.step(self.optimizer)
959
+ self.scaler.update()
960
+ self.optimizer.zero_grad(set_to_none=True)
961
+
962
+ self.scheduler.step()
963
+
964
+ # Print batch-level metrics
965
+ if self.is_master:
966
+ avg_metric = {key: value / num_batches for key, value in metrics.items()}
967
+ print(
968
+ f"[Epoch {epoch}] Batch:{num_batches} "
969
+ f"AvgL:{avg_metric['total_loss']:.4f} "
970
+ f"Loss: {step_losses['total_loss']:.4f}, "
971
+ f"KLL: {step_losses['kl_loss']:.4f}, "
972
+ f"PruneL: {step_losses['prune_loss']:.4f}, "
973
+ f"VertexL: {step_losses['vertex_loss']:.4f}, "
974
+ f"EdgeL: {step_losses['edge_loss']:.4f}, "
975
+ f"OffsetL: {step_losses['offset_loss']:.4f}, "
976
+ f"DireL: {step_losses['direction_loss']:.4f}, "
977
+ f"ConL: {step_losses['connection_loss']:.4f}, "
978
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
979
+ )
980
+
981
+ if i % 2000 == 0 and i != 0:
982
+ self.save_checkpoint(epoch, avg_metric['total_loss'], i)
983
+ with open(self.log_file, "a") as f:
984
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
985
+ log_line = (
986
+ f"Epoch {epoch:05d} | "
987
+ f"Batch {i:05d} | "
988
+ f"Loss: {avg_metric['total_loss']:.6f} "
989
+ f"Avg KLL: {avg_metric['kl_loss']:.4f} "
990
+ f"Avg PruneL: {avg_metric['prune_loss']:.4f} "
991
+ f"Avg VertexL: {avg_metric['vertex_loss']:.4f} "
992
+ f"Avg EdgeL: {avg_metric['edge_loss']:.4f} "
993
+ f"Avg OffsetL: {avg_metric['offset_loss']:.4f} "
994
+ f"Avg DireL: {avg_metric['direction_loss']:.4f} "
995
+ f"Avg ConL: {avg_metric['connection_loss']:.4f} "
996
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
997
+ f"[{current_time}]\n"
998
+ )
999
+ f.write(log_line)
1000
+
1001
+ if num_batches % accum_steps != 0:
1002
+ self.scaler.unscale_(self.optimizer)
1003
+ torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
1004
+ torch.nn.utils.clip_grad_norm_(self.voxel_encoder.parameters(), max_norm=1.0)
1005
+ torch.nn.utils.clip_grad_norm_(self.connection_head.parameters(), max_norm=1.0)
1006
+
1007
+ self.scaler.step(self.optimizer)
1008
+ self.scaler.update()
1009
+ self.optimizer.zero_grad(set_to_none=True)
1010
+
1011
+ self.scheduler.step()
1012
+
1013
+ # Calculate epoch averages
1014
+ avg_metrics = {key: value / num_batches for key, value in metrics.items()}
1015
+ self.train_loss_history.append(avg_metrics['total_loss'])
1016
+
1017
+
1018
+ # Log to file
1019
+ if self.is_master:
1020
+ with open(self.log_file, "a") as f:
1021
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
1022
+ log_line = (
1023
+ f"Epoch {epoch:05d} | "
1024
+ f"Loss: {avg_metrics['total_loss']:.6f} "
1025
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
1026
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
1027
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
1028
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
1029
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
1030
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
1031
+ f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
1032
+ f"LR: {self.optimizer.param_groups[0]['lr']:.4e} "
1033
+ f"[{current_time}]\n"
1034
+ )
1035
+ f.write(log_line)
1036
+
1037
+ # Print epoch summary
1038
+ print(
1039
+ f"[Epoch {epoch}] "
1040
+ f"Avg Loss: {avg_metrics['total_loss']:.4f} "
1041
+ f"Avg KLL: {avg_metrics['kl_loss']:.4f} "
1042
+ f"Avg PruneL: {avg_metrics['prune_loss']:.4f} "
1043
+ f"Avg VertexL: {avg_metrics['vertex_loss']:.4f} "
1044
+ f"Avg EdgeL: {avg_metrics['edge_loss']:.4f} "
1045
+ f"Avg OffsetL: {avg_metrics['offset_loss']:.4f} "
1046
+ f"Avg DireL: {avg_metrics['direction_loss']:.4f} "
1047
+ f"Avg ConL: {avg_metrics['connection_loss']:.4f} "
1048
+ f"[{current_time}]\n"
1049
+ )
1050
+
1051
+ # Save checkpoint
1052
+ if epoch % self.cfg['training']['save_every'] == 0:
1053
+ self.save_checkpoint(epoch, avg_metrics['total_loss'], i)
1054
+
1055
+ # Update learning rate
1056
+ if self.is_master:
1057
+ current_lr = self.optimizer.param_groups[0]['lr']
1058
+ print(f"Epoch {epoch}: Learning rate updated to {current_lr:.2e}")
1059
+
1060
+ dist.barrier()
1061
+
1062
+ def main():
1063
+ # Initialize the process group
1064
+ dist.init_process_group(backend='nccl')
1065
+
1066
+ # Get rank and world size from environment variables set by the launcher
1067
+ rank = int(os.environ['RANK'])
1068
+ world_size = int(os.environ['WORLD_SIZE'])
1069
+ local_rank = int(os.environ['LOCAL_RANK'])
1070
+
1071
+ # Set the device for the current process. This is crucial.
1072
+ torch.cuda.set_device(local_rank)
1073
+ torch.manual_seed(42+rank)
1074
+
1075
+ # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
1076
+ # Pass the distributed info to the Trainer
1077
+ trainer = Trainer(
1078
+ config_path="/home/tiger/yy/src/Michelangelo-master/config_edge_1024_error_8enc_8dec_woself_finetune_128to512.yaml",
1079
+ rank=rank,
1080
+ world_size=world_size,
1081
+ local_rank=local_rank
1082
+ )
1083
+ trainer.train()
1084
+
1085
+ # Clean up the process group
1086
+ dist.destroy_process_group()
1087
+
1088
+
1089
+ if __name__ == '__main__':
1090
+ main()
trellis/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from . import models
2
+ from . import modules
3
+ from . import pipelines
4
+ from . import renderers
5
+ from . import representations
6
+ from . import utils
trellis/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (344 Bytes). View file
 
trellis/datasets/__init__.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ __attributes = {
4
+ 'SparseStructure': 'sparse_structure',
5
+
6
+ 'SparseFeat2Render': 'sparse_feat2render',
7
+ 'SLat2Render':'structured_latent2render',
8
+ 'Slat2RenderGeo':'structured_latent2render',
9
+
10
+ 'SparseStructureLatent': 'sparse_structure_latent',
11
+ 'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
12
+ 'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
13
+
14
+ 'SLat': 'structured_latent',
15
+ 'TextConditionedSLat': 'structured_latent',
16
+ 'ImageConditionedSLat': 'structured_latent',
17
+ }
18
+
19
+ __submodules = []
20
+
21
+ __all__ = list(__attributes.keys()) + __submodules
22
+
23
+ def __getattr__(name):
24
+ if name not in globals():
25
+ if name in __attributes:
26
+ module_name = __attributes[name]
27
+ module = importlib.import_module(f".{module_name}", __name__)
28
+ globals()[name] = getattr(module, name)
29
+ elif name in __submodules:
30
+ module = importlib.import_module(f".{name}", __name__)
31
+ globals()[name] = module
32
+ else:
33
+ raise AttributeError(f"module {__name__} has no attribute {name}")
34
+ return globals()[name]
35
+
36
+
37
+ # For Pylance
38
+ if __name__ == '__main__':
39
+ from .sparse_structure import SparseStructure
40
+
41
+ from .sparse_feat2render import SparseFeat2Render
42
+ from .structured_latent2render import (
43
+ SLat2Render,
44
+ Slat2RenderGeo,
45
+ )
46
+
47
+ from .sparse_structure_latent import (
48
+ SparseStructureLatent,
49
+ TextConditionedSparseStructureLatent,
50
+ ImageConditionedSparseStructureLatent,
51
+ )
52
+
53
+ from .structured_latent import (
54
+ SLat,
55
+ TextConditionedSLat,
56
+ ImageConditionedSLat,
57
+ )
58
+
trellis/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.24 kB). View file
 
trellis/datasets/__pycache__/components.cpython-310.pyc ADDED
Binary file (5.46 kB). View file
 
trellis/datasets/__pycache__/sparse_structure_latent.cpython-310.pyc ADDED
Binary file (6.94 kB). View file
 
trellis/datasets/components.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from abc import abstractmethod
3
+ import os
4
+ import json
5
+ import torch
6
+ import numpy as np
7
+ import pandas as pd
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+
11
+
12
+ class StandardDatasetBase(Dataset):
13
+ """
14
+ Base class for standard datasets.
15
+
16
+ Args:
17
+ roots (str): paths to the dataset
18
+ """
19
+
20
+ def __init__(self,
21
+ roots: str,
22
+ ):
23
+ super().__init__()
24
+ self.roots = roots.split(',')
25
+ self.instances = []
26
+ self.metadata = pd.DataFrame()
27
+
28
+ self._stats = {}
29
+ for root in self.roots:
30
+ key = os.path.basename(root)
31
+ self._stats[key] = {}
32
+ metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
33
+ self._stats[key]['Total'] = len(metadata)
34
+ metadata, stats = self.filter_metadata(metadata)
35
+ self._stats[key].update(stats)
36
+ self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
37
+ metadata.set_index('sha256', inplace=True)
38
+ self.metadata = pd.concat([self.metadata, metadata])
39
+
40
+ @abstractmethod
41
+ def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
42
+ pass
43
+
44
+ @abstractmethod
45
+ def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
46
+ pass
47
+
48
+ def __len__(self):
49
+ return len(self.instances)
50
+
51
+ def __getitem__(self, index) -> Dict[str, Any]:
52
+ try:
53
+ root, instance = self.instances[index]
54
+ return self.get_instance(root, instance)
55
+ except Exception as e:
56
+ print(e)
57
+ return self.__getitem__(np.random.randint(0, len(self)))
58
+
59
+ def __str__(self):
60
+ lines = []
61
+ lines.append(self.__class__.__name__)
62
+ lines.append(f' - Total instances: {len(self)}')
63
+ lines.append(f' - Sources:')
64
+ for key, stats in self._stats.items():
65
+ lines.append(f' - {key}:')
66
+ for k, v in stats.items():
67
+ lines.append(f' - {k}: {v}')
68
+ return '\n'.join(lines)
69
+
70
+
71
+ class TextConditionedMixin:
72
+ def __init__(self, roots, **kwargs):
73
+ super().__init__(roots, **kwargs)
74
+ self.captions = {}
75
+ for instance in self.instances:
76
+ sha256 = instance[1]
77
+ self.captions[sha256] = json.loads(self.metadata.loc[sha256]['captions'])
78
+
79
+ def filter_metadata(self, metadata):
80
+ metadata, stats = super().filter_metadata(metadata)
81
+ metadata = metadata[metadata['captions'].notna()]
82
+ stats['With captions'] = len(metadata)
83
+ return metadata, stats
84
+
85
+ def get_instance(self, root, instance):
86
+ pack = super().get_instance(root, instance)
87
+ text = np.random.choice(self.captions[instance])
88
+ pack['cond'] = text
89
+ return pack
90
+
91
+
92
+ class ImageConditionedMixin:
93
+ def __init__(self, roots, *, image_size=518, **kwargs):
94
+ self.image_size = image_size
95
+ super().__init__(roots, **kwargs)
96
+
97
+ def filter_metadata(self, metadata):
98
+ metadata, stats = super().filter_metadata(metadata)
99
+ metadata = metadata[metadata[f'cond_rendered']]
100
+ stats['Cond rendered'] = len(metadata)
101
+ return metadata, stats
102
+
103
+ def get_instance(self, root, instance):
104
+ pack = super().get_instance(root, instance)
105
+
106
+ image_root = os.path.join(root, 'renders_cond', instance)
107
+ with open(os.path.join(image_root, 'transforms.json')) as f:
108
+ metadata = json.load(f)
109
+ n_views = len(metadata['frames'])
110
+ view = np.random.randint(n_views)
111
+ metadata = metadata['frames'][view]
112
+
113
+ image_path = os.path.join(image_root, metadata['file_path'])
114
+ image = Image.open(image_path)
115
+
116
+ alpha = np.array(image.getchannel(3))
117
+ bbox = np.array(alpha).nonzero()
118
+ bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
119
+ center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
120
+ hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
121
+ aug_size_ratio = 1.2
122
+ aug_hsize = hsize * aug_size_ratio
123
+ aug_center_offset = [0, 0]
124
+ aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
125
+ aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
126
+ image = image.crop(aug_bbox)
127
+
128
+ image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
129
+ alpha = image.getchannel(3)
130
+ image = image.convert('RGB')
131
+ image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
132
+ alpha = torch.tensor(np.array(alpha)).float() / 255.0
133
+ image = image * alpha.unsqueeze(0)
134
+ pack['cond'] = image
135
+
136
+ return pack
137
+
trellis/datasets/sparse_feat2render.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import json
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ import utils3d.torch
8
+ from ..modules.sparse.basic import SparseTensor
9
+ from .components import StandardDatasetBase
10
+
11
+
12
+ class SparseFeat2Render(StandardDatasetBase):
13
+ """
14
+ SparseFeat2Render dataset.
15
+
16
+ Args:
17
+ roots (str): paths to the dataset
18
+ image_size (int): size of the image
19
+ model (str): model name
20
+ resolution (int): resolution of the data
21
+ min_aesthetic_score (float): minimum aesthetic score
22
+ max_num_voxels (int): maximum number of voxels
23
+ """
24
+ def __init__(
25
+ self,
26
+ roots: str,
27
+ image_size: int,
28
+ model: str = 'dinov2_vitl14_reg',
29
+ resolution: int = 64,
30
+ min_aesthetic_score: float = 5.0,
31
+ max_num_voxels: int = 32768,
32
+ ):
33
+ self.image_size = image_size
34
+ self.model = model
35
+ self.resolution = resolution
36
+ self.min_aesthetic_score = min_aesthetic_score
37
+ self.max_num_voxels = max_num_voxels
38
+ self.value_range = (0, 1)
39
+
40
+ super().__init__(roots)
41
+
42
+ def filter_metadata(self, metadata):
43
+ stats = {}
44
+ metadata = metadata[metadata[f'feature_{self.model}']]
45
+ stats['With features'] = len(metadata)
46
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
47
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
48
+ metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
49
+ stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
50
+ return metadata, stats
51
+
52
+ def _get_image(self, root, instance):
53
+ with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
54
+ metadata = json.load(f)
55
+ n_views = len(metadata['frames'])
56
+ view = np.random.randint(n_views)
57
+ metadata = metadata['frames'][view]
58
+ fov = metadata['camera_angle_x']
59
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
60
+ c2w = torch.tensor(metadata['transform_matrix'])
61
+ c2w[:3, 1:3] *= -1
62
+ extrinsics = torch.inverse(c2w)
63
+
64
+ image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
65
+ image = Image.open(image_path)
66
+ alpha = image.getchannel(3)
67
+ image = image.convert('RGB')
68
+ image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
69
+ alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
70
+ image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
71
+ alpha = torch.tensor(np.array(alpha)).float() / 255.0
72
+
73
+ return {
74
+ 'image': image,
75
+ 'alpha': alpha,
76
+ 'extrinsics': extrinsics,
77
+ 'intrinsics': intrinsics,
78
+ }
79
+
80
+ def _get_feat(self, root, instance):
81
+ DATA_RESOLUTION = 64
82
+ feats_path = os.path.join(root, 'features', self.model, f'{instance}.npz')
83
+ feats = np.load(feats_path, allow_pickle=True)
84
+ coords = torch.tensor(feats['indices']).int()
85
+ feats = torch.tensor(feats['patchtokens']).float()
86
+
87
+ if self.resolution != DATA_RESOLUTION:
88
+ factor = DATA_RESOLUTION // self.resolution
89
+ coords = coords // factor
90
+ coords, idx = coords.unique(return_inverse=True, dim=0)
91
+ feats = torch.scatter_reduce(
92
+ torch.zeros(coords.shape[0], feats.shape[1], device=feats.device),
93
+ dim=0,
94
+ index=idx.unsqueeze(-1).expand(-1, feats.shape[1]),
95
+ src=feats,
96
+ reduce='mean'
97
+ )
98
+
99
+ return {
100
+ 'coords': coords,
101
+ 'feats': feats,
102
+ }
103
+
104
+ @torch.no_grad()
105
+ def visualize_sample(self, sample: dict):
106
+ return sample['image']
107
+
108
+ @staticmethod
109
+ def collate_fn(batch):
110
+ pack = {}
111
+ coords = []
112
+ for i, b in enumerate(batch):
113
+ coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
114
+ coords = torch.cat(coords)
115
+ feats = torch.cat([b['feats'] for b in batch])
116
+ pack['feats'] = SparseTensor(
117
+ coords=coords,
118
+ feats=feats,
119
+ )
120
+
121
+ pack['image'] = torch.stack([b['image'] for b in batch])
122
+ pack['alpha'] = torch.stack([b['alpha'] for b in batch])
123
+ pack['extrinsics'] = torch.stack([b['extrinsics'] for b in batch])
124
+ pack['intrinsics'] = torch.stack([b['intrinsics'] for b in batch])
125
+
126
+ return pack
127
+
128
+ def get_instance(self, root, instance):
129
+ image = self._get_image(root, instance)
130
+ feat = self._get_feat(root, instance)
131
+ return {
132
+ **image,
133
+ **feat,
134
+ }
trellis/datasets/sparse_structure.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import Union
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ import utils3d
9
+ from .components import StandardDatasetBase
10
+ from ..representations.octree import DfsOctree as Octree
11
+ from ..renderers import OctreeRenderer
12
+
13
+
14
+ class SparseStructure(StandardDatasetBase):
15
+ """
16
+ Sparse structure dataset
17
+
18
+ Args:
19
+ roots (str): path to the dataset
20
+ resolution (int): resolution of the voxel grid
21
+ min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
22
+ """
23
+
24
+ def __init__(self,
25
+ roots,
26
+ resolution: int = 64,
27
+ min_aesthetic_score: float = 5.0,
28
+ ):
29
+ self.resolution = resolution
30
+ self.min_aesthetic_score = min_aesthetic_score
31
+ self.value_range = (0, 1)
32
+
33
+ super().__init__(roots)
34
+
35
+ def filter_metadata(self, metadata):
36
+ stats = {}
37
+ metadata = metadata[metadata[f'voxelized']]
38
+ stats['Voxelized'] = len(metadata)
39
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
40
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
41
+ return metadata, stats
42
+
43
+ def get_instance(self, root, instance):
44
+ position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0]
45
+ coords = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
46
+ ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
47
+ ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
48
+ return {'ss': ss}
49
+
50
+ @torch.no_grad()
51
+ def visualize_sample(self, ss: Union[torch.Tensor, dict]):
52
+ ss = ss if isinstance(ss, torch.Tensor) else ss['ss']
53
+
54
+ renderer = OctreeRenderer()
55
+ renderer.rendering_options.resolution = 512
56
+ renderer.rendering_options.near = 0.8
57
+ renderer.rendering_options.far = 1.6
58
+ renderer.rendering_options.bg_color = (0, 0, 0)
59
+ renderer.rendering_options.ssaa = 4
60
+ renderer.pipe.primitive = 'voxel'
61
+
62
+ # Build camera
63
+ yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
64
+ yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
65
+ yaws = [y + yaws_offset for y in yaws]
66
+ pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
67
+
68
+ exts = []
69
+ ints = []
70
+ for yaw, pitch in zip(yaws, pitch):
71
+ orig = torch.tensor([
72
+ np.sin(yaw) * np.cos(pitch),
73
+ np.cos(yaw) * np.cos(pitch),
74
+ np.sin(pitch),
75
+ ]).float().cuda() * 2
76
+ fov = torch.deg2rad(torch.tensor(30)).cuda()
77
+ extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
78
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
79
+ exts.append(extrinsics)
80
+ ints.append(intrinsics)
81
+
82
+ images = []
83
+
84
+ # Build each representation
85
+ ss = ss.cuda()
86
+ for i in range(ss.shape[0]):
87
+ representation = Octree(
88
+ depth=10,
89
+ aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
90
+ device='cuda',
91
+ primitive='voxel',
92
+ sh_degree=0,
93
+ primitive_config={'solid': True},
94
+ )
95
+ coords = torch.nonzero(ss[i, 0], as_tuple=False)
96
+ representation.position = coords.float() / self.resolution
97
+ representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
98
+
99
+ image = torch.zeros(3, 1024, 1024).cuda()
100
+ tile = [2, 2]
101
+ for j, (ext, intr) in enumerate(zip(exts, ints)):
102
+ res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
103
+ image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
104
+ images.append(image)
105
+
106
+ return torch.stack(images)
107
+