Upload folder using huggingface_hub
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- config_slat_flow_128to512_pointnet_head.yaml +122 -0
- config_slat_flow_128to512_pointnet_head_test.yaml +123 -0
- test_slat_flow_128to512_pointnet_head.py +404 -0
- test_slat_flow_128to512_pointnet_head_tomesh.py +1630 -0
- test_slat_vae_128to512_pointnet_vae_head.py +12 -14
- train_slat_flow_128to512_pointnet_head.py +507 -0
- trellis/__init__.py +1 -1
- trellis/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/models/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc +0 -0
- trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc +0 -0
- trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc +0 -0
- trellis/modules/__pycache__/norm.cpython-310.pyc +0 -0
- trellis/modules/__pycache__/spatial.cpython-310.pyc +0 -0
- trellis/modules/__pycache__/utils.cpython-310.pyc +0 -0
- trellis/modules/attention/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc +0 -0
- trellis/modules/attention/__pycache__/modules.cpython-310.pyc +0 -0
- trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/modules/sparse/__pycache__/basic.cpython-310.pyc +0 -0
- trellis/modules/sparse/__pycache__/linear.cpython-310.pyc +0 -0
- trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc +0 -0
- trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc +0 -0
- trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc +0 -0
- trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc +0 -0
- trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc +0 -0
- trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc +0 -0
- trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc +0 -0
- trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc +0 -0
- trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc +0 -0
- trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc +0 -0
- trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc +0 -0
- trellis/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/pipelines/__pycache__/base.cpython-310.pyc +0 -0
- trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc +0 -0
- trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc +0 -0
- trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc +0 -0
- trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc +0 -0
- trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc +0 -0
- trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc +0 -0
- trellis/renderers/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/representations/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc +0 -0
- trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc +0 -0
- trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc +0 -0
config_slat_flow_128to512_pointnet_head.yaml
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
pred_direction: false
|
| 3 |
+
relative_embed: true
|
| 4 |
+
using_attn: false
|
| 5 |
+
add_block_embed: true
|
| 6 |
+
multires: 12
|
| 7 |
+
|
| 8 |
+
embed_dim: 1024
|
| 9 |
+
in_channels: 1024
|
| 10 |
+
model_channels: 384
|
| 11 |
+
latent_dim: 16
|
| 12 |
+
|
| 13 |
+
block_size: 16
|
| 14 |
+
pos_encoding: 'nerf'
|
| 15 |
+
attn_first: false
|
| 16 |
+
|
| 17 |
+
add_edge_glb_feats: true
|
| 18 |
+
add_direction: false
|
| 19 |
+
|
| 20 |
+
encoder_blocks:
|
| 21 |
+
- in_channels: 1024
|
| 22 |
+
model_channels: 512
|
| 23 |
+
num_blocks: 8
|
| 24 |
+
num_heads: 8
|
| 25 |
+
out_channels: 512
|
| 26 |
+
|
| 27 |
+
decoder_blocks_edge:
|
| 28 |
+
- in_channels: 512
|
| 29 |
+
model_channels: 512
|
| 30 |
+
num_blocks: 0
|
| 31 |
+
num_heads: 0
|
| 32 |
+
out_channels: 256
|
| 33 |
+
resolution: 128
|
| 34 |
+
- in_channels: 256
|
| 35 |
+
model_channels: 256
|
| 36 |
+
num_blocks: 0
|
| 37 |
+
num_heads: 0
|
| 38 |
+
out_channels: 128
|
| 39 |
+
resolution: 256
|
| 40 |
+
# - in_channels: 64
|
| 41 |
+
# model_channels: 64
|
| 42 |
+
# num_blocks: 0
|
| 43 |
+
# num_heads: 0
|
| 44 |
+
# out_channels: 32
|
| 45 |
+
# resolution: 512
|
| 46 |
+
|
| 47 |
+
decoder_blocks_vtx:
|
| 48 |
+
- in_channels: 512
|
| 49 |
+
model_channels: 512
|
| 50 |
+
num_blocks: 0
|
| 51 |
+
num_heads: 0
|
| 52 |
+
out_channels: 256
|
| 53 |
+
resolution: 128
|
| 54 |
+
- in_channels: 256
|
| 55 |
+
model_channels: 256
|
| 56 |
+
num_blocks: 0
|
| 57 |
+
num_heads: 0
|
| 58 |
+
out_channels: 128
|
| 59 |
+
resolution: 256
|
| 60 |
+
# - in_channels: 64
|
| 61 |
+
# model_channels: 64
|
| 62 |
+
# num_blocks: 0
|
| 63 |
+
# num_heads: 0
|
| 64 |
+
# out_channels: 32
|
| 65 |
+
# resolution: 512
|
| 66 |
+
|
| 67 |
+
"t_schedule":
|
| 68 |
+
"name": "logitNormal"
|
| 69 |
+
"args":
|
| 70 |
+
"mean": 1.0
|
| 71 |
+
"std": 1.0
|
| 72 |
+
|
| 73 |
+
"sigma_min": 1.e-5
|
| 74 |
+
|
| 75 |
+
training:
|
| 76 |
+
batch_size: 1
|
| 77 |
+
lr: 1.e-4
|
| 78 |
+
step_size: 20
|
| 79 |
+
gamma: 0.95
|
| 80 |
+
save_every: 500
|
| 81 |
+
start_epoch: 0
|
| 82 |
+
max_epochs: 300
|
| 83 |
+
num_workers: 32
|
| 84 |
+
|
| 85 |
+
output_dir: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope
|
| 86 |
+
clip_model_path: None
|
| 87 |
+
dinov2_model_path: None
|
| 88 |
+
|
| 89 |
+
vae_path: /home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch1_batch12000_loss0.1140.pt
|
| 90 |
+
denoiser_checkpoint_path: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512/checkpoint_step143000_loss0_736924.pt
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
dataset:
|
| 94 |
+
path: /home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01
|
| 95 |
+
cache_dir: /home/tiger/yy/src/dataset_cache/unique_files_glb_under6000face_2degree_30ratio_0.01
|
| 96 |
+
|
| 97 |
+
renders_dir: None
|
| 98 |
+
filter_active_voxels: true
|
| 99 |
+
cache_filter_path: /home/tiger/yy/src/40w_2000-100000edge_2000-100000active.txt
|
| 100 |
+
|
| 101 |
+
base_resolution: 1024
|
| 102 |
+
min_resolution: 128
|
| 103 |
+
|
| 104 |
+
n_train_samples: 1024
|
| 105 |
+
sample_type: dora
|
| 106 |
+
|
| 107 |
+
flow:
|
| 108 |
+
"resolution": 128
|
| 109 |
+
"in_channels": 16
|
| 110 |
+
"out_channels": 16
|
| 111 |
+
"model_channels": 768
|
| 112 |
+
"cond_channels": 1024
|
| 113 |
+
"num_blocks": 12
|
| 114 |
+
"num_heads": 12
|
| 115 |
+
"mlp_ratio": 4
|
| 116 |
+
"patch_size": 2
|
| 117 |
+
"num_io_res_blocks": 2
|
| 118 |
+
"io_block_channels": [128]
|
| 119 |
+
"pe_mode": "rope"
|
| 120 |
+
"qk_rms_norm": true
|
| 121 |
+
"qk_rms_norm_cross": false
|
| 122 |
+
"use_fp16": false
|
config_slat_flow_128to512_pointnet_head_test.yaml
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
pred_direction: false
|
| 3 |
+
relative_embed: true
|
| 4 |
+
using_attn: false
|
| 5 |
+
add_block_embed: true
|
| 6 |
+
multires: 12
|
| 7 |
+
|
| 8 |
+
embed_dim: 1024
|
| 9 |
+
in_channels: 1024
|
| 10 |
+
model_channels: 384
|
| 11 |
+
latent_dim: 16
|
| 12 |
+
|
| 13 |
+
block_size: 16
|
| 14 |
+
pos_encoding: 'nerf'
|
| 15 |
+
attn_first: false
|
| 16 |
+
|
| 17 |
+
add_edge_glb_feats: true
|
| 18 |
+
add_direction: false
|
| 19 |
+
|
| 20 |
+
encoder_blocks:
|
| 21 |
+
- in_channels: 1024
|
| 22 |
+
model_channels: 512
|
| 23 |
+
num_blocks: 8
|
| 24 |
+
num_heads: 8
|
| 25 |
+
out_channels: 512
|
| 26 |
+
|
| 27 |
+
decoder_blocks_edge:
|
| 28 |
+
- in_channels: 512
|
| 29 |
+
model_channels: 512
|
| 30 |
+
num_blocks: 0
|
| 31 |
+
num_heads: 0
|
| 32 |
+
out_channels: 256
|
| 33 |
+
resolution: 128
|
| 34 |
+
- in_channels: 256
|
| 35 |
+
model_channels: 256
|
| 36 |
+
num_blocks: 0
|
| 37 |
+
num_heads: 0
|
| 38 |
+
out_channels: 128
|
| 39 |
+
resolution: 256
|
| 40 |
+
# - in_channels: 64
|
| 41 |
+
# model_channels: 64
|
| 42 |
+
# num_blocks: 0
|
| 43 |
+
# num_heads: 0
|
| 44 |
+
# out_channels: 32
|
| 45 |
+
# resolution: 512
|
| 46 |
+
|
| 47 |
+
decoder_blocks_vtx:
|
| 48 |
+
- in_channels: 512
|
| 49 |
+
model_channels: 512
|
| 50 |
+
num_blocks: 0
|
| 51 |
+
num_heads: 0
|
| 52 |
+
out_channels: 256
|
| 53 |
+
resolution: 128
|
| 54 |
+
- in_channels: 256
|
| 55 |
+
model_channels: 256
|
| 56 |
+
num_blocks: 0
|
| 57 |
+
num_heads: 0
|
| 58 |
+
out_channels: 128
|
| 59 |
+
resolution: 256
|
| 60 |
+
# - in_channels: 64
|
| 61 |
+
# model_channels: 64
|
| 62 |
+
# num_blocks: 0
|
| 63 |
+
# num_heads: 0
|
| 64 |
+
# out_channels: 32
|
| 65 |
+
# resolution: 512
|
| 66 |
+
|
| 67 |
+
"t_schedule":
|
| 68 |
+
"name": "logitNormal"
|
| 69 |
+
"args":
|
| 70 |
+
"mean": 1.0
|
| 71 |
+
"std": 1.0
|
| 72 |
+
|
| 73 |
+
"sigma_min": 1.e-5
|
| 74 |
+
|
| 75 |
+
training:
|
| 76 |
+
batch_size: 1
|
| 77 |
+
lr: 1.e-4
|
| 78 |
+
step_size: 20
|
| 79 |
+
gamma: 0.95
|
| 80 |
+
save_every: 500
|
| 81 |
+
start_epoch: 0
|
| 82 |
+
max_epochs: 300
|
| 83 |
+
num_workers: 32
|
| 84 |
+
|
| 85 |
+
output_dir: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope
|
| 86 |
+
clip_model_path: None
|
| 87 |
+
dinov2_model_path: None
|
| 88 |
+
|
| 89 |
+
vae_path: /home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch1_batch12000_loss0.1140.pt
|
| 90 |
+
denoiser_checkpoint_path: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512/checkpoint_step143000_loss0_736924.pt
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
dataset:
  # The original listed `path` twice. Duplicate mapping keys are invalid YAML,
  # and PyYAML's safe_load silently keeps the LAST occurrence, so the
  # mesh_data path below is the one that was actually in effect. The
  # superseded value is kept here as a comment for reference.
  # path: /home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01
  path: /home/tiger/yy/src/trellis_clean_mesh/mesh_data
  cache_dir: /home/tiger/yy/src/dataset_cache/unique_files_glb_under6000face_2degree_30ratio_0.01

  # NOTE(review): YAML parses a bare `None` as the string "None", not null
  # (null is spelled `null` or `~`). Confirm the dataset expects that string.
  renders_dir: None
  filter_active_voxels: false
  cache_filter_path: /home/tiger/yy/src/40w_2000-100000edge_2000-100000active.txt

  base_resolution: 1024
  min_resolution: 128

  n_train_samples: 1024
  sample_type: dora
|
| 107 |
+
|
| 108 |
+
flow:
|
| 109 |
+
"resolution": 128
|
| 110 |
+
"in_channels": 16
|
| 111 |
+
"out_channels": 16
|
| 112 |
+
"model_channels": 768
|
| 113 |
+
"cond_channels": 1024
|
| 114 |
+
"num_blocks": 12
|
| 115 |
+
"num_heads": 12
|
| 116 |
+
"mlp_ratio": 4
|
| 117 |
+
"patch_size": 2
|
| 118 |
+
"num_io_res_blocks": 2
|
| 119 |
+
"io_block_channels": [128]
|
| 120 |
+
"pe_mode": "rope"
|
| 121 |
+
"qk_rms_norm": true
|
| 122 |
+
"qk_rms_norm_cross": false
|
| 123 |
+
"use_fp16": false
|
test_slat_flow_128to512_pointnet_head.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import yaml
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from functools import partial
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.amp import GradScaler, autocast
|
| 12 |
+
from typing import *
|
| 13 |
+
from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig, Dinov2Model, AutoImageProcessor, Dinov2Config
|
| 14 |
+
import torch
|
| 15 |
+
import re
|
| 16 |
+
from utils import load_pretrained_woself
|
| 17 |
+
|
| 18 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 19 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 20 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet
|
| 21 |
+
|
| 22 |
+
from trellis.models.structured_latent_flow import SLatFlowModel
|
| 23 |
+
from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
|
| 24 |
+
|
| 25 |
+
from trellis.pipelines.samplers import FlowEulerSampler
|
| 26 |
+
from safetensors.torch import load_file
|
| 27 |
+
import open3d as o3d
|
| 28 |
+
from PIL import Image
|
| 29 |
+
|
| 30 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 31 |
+
from trellis.modules.sparse.basic import SparseTensor as SparseTensor_trellis
|
| 32 |
+
|
| 33 |
+
from triposf.modules.utils import DiagonalGaussianDistribution
|
| 34 |
+
|
| 35 |
+
from sklearn.decomposition import PCA
|
| 36 |
+
import trimesh
|
| 37 |
+
import torchvision.transforms as transforms
|
| 38 |
+
|
| 39 |
+
# --- Helper Functions ---
|
| 40 |
+
def save_colored_ply(points, colors, filename):
    """Export a colored point cloud to ``filename`` through ``trimesh``.

    Args:
        points: Point positions; when ``len(points) == 0`` nothing is written
            and a warning is printed instead.
        colors: Per-point colors with 3 (RGB) or 4 (RGBA) columns.
            Floating-point input whose maximum is ``<= 1.0`` is treated as
            normalized and scaled by 255; any other input is cast to uint8
            unchanged. An opaque alpha column is appended to 3-column input.
        filename: Destination path passed to ``PointCloud.export``.

    Returns:
        None.
    """
    if len(points) == 0:
        print(f"[Warning] No points to save for {filename}")
        return

    colors = np.asarray(colors)
    # Only floating-point data can be "normalized to [0, 1]". Integer input
    # (such as uint8 colors) is already on the 0-255 scale and must not be
    # multiplied by 255 just because every value happens to be 0 or 1.
    if np.issubdtype(colors.dtype, np.floating) and colors.max() <= 1.0:
        colors = colors * 255
    colors = colors.astype(np.uint8)

    # Add an opaque alpha channel if only RGB was supplied.
    if colors.shape[1] == 3:
        alpha = np.full((len(colors), 1), 255, dtype=np.uint8)
        colors = np.hstack([colors, alpha])

    cloud = trimesh.PointCloud(points, colors=colors)
    cloud.export(filename)
    print(f"Saved colored point cloud to {filename}")
|
| 56 |
+
|
| 57 |
+
def normalize_to_rgb(features_3d):
    """Min-max scale every column of ``features_3d`` to uint8 in 0-255.

    Args:
        features_3d: numpy array; the minimum and maximum are taken along
            axis 0, so each column is scaled independently.

    Returns:
        A uint8 array of the same shape. Columns with zero range map to 0.
    """
    lo = features_3d.min(axis=0)
    span = features_3d.max(axis=0) - lo
    # A constant column has zero range; use 1 so the division is well defined.
    span[span == 0] = 1
    return ((features_3d - lo) / span * 255).astype(np.uint8)
|
| 64 |
+
|
| 65 |
+
class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
    """Evaluation harness that samples latents with a flow-matching denoiser.

    On construction it builds the PointNet voxel encoder, the VoxelVAE, the
    dataset/dataloader and the conditioning model (all put in eval mode).
    :meth:`eval` then encodes ground-truth latents, samples latents on the
    same active voxel coordinates and writes PCA visualisations and tensors.
    """

    def __init__(self, *args, **kwargs):
        """Set up the sampler, fixed evaluation settings and all components.

        Args:
            *args: Forwarded to the parent trainer.
            **kwargs: Forwarded to the parent trainer; must contain ``cfg``,
                the parsed configuration dictionary.
                NOTE(review): ``cfg`` is forwarded to the parent as well;
                presumably the parent accepts it - confirm.

        Raises:
            ValueError: If ``cfg`` is missing or ``None``.
        """
        super().__init__(*args, **kwargs)
        self.cfg = kwargs.get('cfg')
        if self.cfg is None:
            raise ValueError("Configuration dictionary 'cfg' must be provided.")

        self.sampler = FlowEulerSampler(sigma_min=1.e-5)
        self.device = torch.device("cuda")

        # Based on PointNet Encoder setting
        self.resolution = 128

        self.condition_type = 'image'
        self.is_cond = False

        self.img_res = 518
        self.feature_dim = self.cfg['model']['latent_dim']

        training_cfg = self.cfg['training']
        self._init_components(
            clip_model_path=training_cfg.get('clip_model_path', None),
            dinov2_model_path=training_cfg.get('dinov2_model_path', None),
            vae_path=training_cfg['vae_path'],
        )

        # Classifier head removed as it is not part of the Active Voxel pipeline

    def _load_denoiser(self, denoiser_checkpoint_path):
        """Load denoiser weights from a checkpoint file, best effort.

        Missing files, a missing ``'denoiser'`` entry and load failures are
        reported with ``print`` and leave the current weights untouched.

        Args:
            denoiser_checkpoint_path: Path to a checkpoint whose ``'denoiser'``
                entry holds the state dict.
        """
        path = denoiser_checkpoint_path
        if not path or not os.path.isfile(path):
            print("No valid checkpoint path provided for fine-tuning. Starting from scratch.")
            return

        print(f"Loading checkpoint from: {path}")
        checkpoint = torch.load(path, map_location=self.device)

        try:
            denoiser_state_dict = checkpoint['denoiser']
            # Handle DDP prefix. Strip it per key so that keys without the
            # prefix are never truncated, and tolerate an empty dict (which
            # load_state_dict then reports as missing keys).
            if denoiser_state_dict and next(iter(denoiser_state_dict)).startswith('module.'):
                denoiser_state_dict = {
                    (k[7:] if k.startswith('module.') else k): v
                    for k, v in denoiser_state_dict.items()
                }

            self.denoiser.load_state_dict(denoiser_state_dict)
            print("Denoiser weights loaded successfully.")
        except KeyError:
            print("[WARN] 'denoiser' key not found in checkpoint. Skipping.")
        except Exception as e:
            print(f"[ERROR] Failed to load denoiser state_dict: {e}")

    def _init_components(self,
                         clip_model_path=None,
                         dinov2_model_path=None,
                         vae_path=None,
                         ):
        """Build encoder, VAE, dataloader and conditioning model.

        Args:
            clip_model_path: Location of the CLIP text model; only read when
                ``self.condition_type == 'text'``.
            dinov2_model_path: Currently unused; the image branch loads DINOv2
                from paths hard-coded below.
            vae_path: Checkpoint expected to hold ``'vae'`` and
                ``'voxel_encoder'`` state dicts.

        Raises:
            ValueError: If ``self.condition_type`` is neither ``'text'`` nor
                ``'image'``.
        """
        model_cfg = self.cfg['model']
        dataset_cfg = self.cfg['dataset']

        # 1. Initialize PointNet Voxel Encoder (Matches Training)
        self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
            in_channels=15,
            hidden_dim=256,
            out_channels=1024,
            scatter_type='mean',
            n_blocks=5,
            resolution=128,
            add_label=False,
        ).to(self.device)

        # 2. Initialize VAE
        self.vae = VoxelVAE(
            in_channels=model_cfg['in_channels'],
            latent_dim=model_cfg['latent_dim'],
            encoder_blocks=model_cfg['encoder_blocks'],
            decoder_blocks_vtx=model_cfg['decoder_blocks_vtx'],
            decoder_blocks_edge=model_cfg['decoder_blocks_edge'],
            num_heads=8,
            num_head_channels=64,
            mlp_ratio=4.0,
            attn_mode="swin",
            window_size=8,
            pe_mode="ape",
            use_fp16=False,
            use_checkpoint=False,
            qk_rms_norm=False,
            using_subdivide=True,
            using_attn=model_cfg['using_attn'],
            attn_first=model_cfg.get('attn_first', True),
            pred_direction=model_cfg.get('pred_direction', False),
        ).to(self.device)

        # 3. Initialize Dataset with collate_fn_pointnet
        self.dataset = VoxelVertexDataset_edge(
            root_dir=dataset_cfg['path'],
            base_resolution=dataset_cfg['base_resolution'],
            min_resolution=dataset_cfg['min_resolution'],
            cache_dir=dataset_cfg['cache_dir'],
            renders_dir=dataset_cfg['renders_dir'],

            process_img=False,

            active_voxel_res=128,
            filter_active_voxels=dataset_cfg['filter_active_voxels'],
            cache_filter_path=dataset_cfg['cache_filter_path'],
            sample_type=dataset_cfg.get('sample_type', 'dora'),
        )

        self.dataloader = DataLoader(
            self.dataset,
            batch_size=1,
            shuffle=True,
            collate_fn=collate_fn_pointnet,  # Critical Change
            num_workers=0,
            pin_memory=True,
            persistent_workers=False,
        )

        # 4. Load Pretrained Weights
        # Assuming vae_path contains 'voxel_encoder' and 'vae'
        print(f"Loading VAE/Encoder from {vae_path}")
        ckpt = torch.load(vae_path, map_location='cpu')

        # Load VAE
        if 'vae' in ckpt:
            self.vae.load_state_dict(ckpt['vae'], strict=False)
        else:
            self.vae.load_state_dict(ckpt)  # Fallback

        # Load Encoder
        if 'voxel_encoder' in ckpt:
            self.voxel_encoder.load_state_dict(ckpt['voxel_encoder'])
        else:
            print("[WARN] 'voxel_encoder' not found in checkpoint, random init (BAD for inference).")

        self.voxel_encoder.eval()
        self.vae.eval()

        # 5. Initialize Conditioning Model
        if self.condition_type == 'text':
            self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
            self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
        elif self.condition_type == 'image':
            model_name = 'dinov2_vitl14_reg'
            local_repo_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main"
            weights_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"

            dinov2_model = torch.hub.load(
                repo_or_dir=local_repo_path,
                model=model_name,
                source='local',
                pretrained=False
            )
            self.condition_model = dinov2_model
            # Load onto CPU explicitly so the checkpoint opens regardless of
            # the device it was saved from; the model is moved to
            # self.device below.
            self.condition_model.load_state_dict(torch.load(weights_path, map_location='cpu'))

            # NOTE(review): this transform is defined but encode_image never
            # applies it - confirm whether normalization is expected there.
            self.image_cond_model_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        else:
            raise ValueError(f"Unsupported condition type: {self.condition_type}")

        self.condition_model.to(self.device).eval()

    @torch.no_grad()
    def encode_image(self, images) -> torch.Tensor:
        """Encode images into layer-normalized conditioning tokens.

        Args:
            images: Either a tensor batch (moved to ``self.device`` as is) or
                a list of PIL images, which are resized to
                ``self.img_res`` x ``self.img_res``, converted to RGB and
                scaled to ``[0, 1]``.

        Returns:
            ``F.layer_norm`` applied over the last dimension of the
            conditioning model's ``'x_prenorm'`` output.

        Raises:
            ValueError: If ``images`` is neither a tensor nor a list of PIL
                images.
        """
        size = (self.img_res, self.img_res)

        if isinstance(images, torch.Tensor):
            batch_tensor = images.to(self.device)
        elif isinstance(images, list):
            # Explicit raise instead of assert so validation survives -O.
            if not all(isinstance(img, Image.Image) for img in images):
                raise ValueError("Image list must contain only PIL images")
            resized = [img.resize(size, Image.LANCZOS) for img in images]
            arrays = [np.array(img.convert('RGB')).astype(np.float32) / 255 for img in resized]
            # HWC -> CHW
            tensors = [torch.from_numpy(arr).permute(2, 0, 1).float() for arr in arrays]
            batch_tensor = torch.stack(tensors).to(self.device)
        else:
            raise ValueError(f"Unsupported type of image: {type(images)}")

        if batch_tensor.shape[-2:] != size:
            batch_tensor = F.interpolate(batch_tensor, size, mode='bicubic', align_corners=False)

        features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens

    def process_batch(self, batch):
        """Return the image conditioning for ``batch['image']``."""
        preprocessed_images = batch['image']
        cond_ = self.encode_image(preprocessed_images)
        return cond_

    def eval(self, max_batches=51):
        """Sample latents for the first batches and dump visualisations.

        For each batch the ground-truth latent is obtained from the voxel
        encoder and VAE, a latent is sampled by the denoiser on the same
        active coordinates, and PCA-colored point clouds plus raw tensors are
        written next to the checkpoint.

        NOTE(review): the checkpoint path is hard-coded here and the
        ``denoiser_checkpoint_path`` entry of the config is not read.

        Args:
            max_batches: Number of batches to process before stopping. The
                default of 51 matches the previous hard-coded ``i > 50`` cut.

        Raises:
            NotImplementedError: If ``self.is_cond`` is set with a condition
                type other than ``'image'`` (no per-batch conditioning exists
                for it).
        """
        if self.is_cond and self.condition_type != 'image':
            raise NotImplementedError(
                f"Conditional evaluation is only implemented for image conditioning, "
                f"got: {self.condition_type}"
            )

        # Unconditional Setup
        cond_ = None
        if not self.is_cond:
            if self.condition_type == 'text':
                txt = ['']
                encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
                tokens = encoding['input_ids'].to(self.device)
                with torch.no_grad():
                    cond_ = self.condition_model(input_ids=tokens).last_hidden_state
            else:
                blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
                with torch.no_grad():
                    # The blank image is encoded only to obtain the token
                    # shape; the actual unconditional prompt is all zeros.
                    dummy_cond = self.encode_image([blank_img])
                    cond_ = torch.zeros_like(dummy_cond)
                print(f"Generated unconditional image prompt (zero tensor) with shape: {cond_.shape}")

        self.denoiser.eval()

        # Load Denoiser Checkpoint
        # Update this path to your ACTIVE VOXEL trained checkpoint
        checkpoint_path = '/home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope/checkpoint_step143500_loss0_766792.pt'

        self._load_denoiser(checkpoint_path)

        filename = os.path.basename(checkpoint_path)
        match = re.search(r'step(\d+)', filename)
        step_str = match.group(1) if match else "eval"
        save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_trellis")
        # save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_40w_train")
        os.makedirs(save_dir, exist_ok=True)
        print(f"Results will be saved to: {save_dir}")

        for i, batch in enumerate(self.dataloader):
            # Stop by leaving the loop rather than terminating the
            # interpreter from inside a method.
            if i >= max_batches:
                break

            if self.is_cond and self.condition_type == 'image':
                cond_ = self.process_batch(batch)

            if cond_.shape[0] != 1:
                cond_ = cond_.expand(batch['active_voxels_128'].shape[0], -1, -1).contiguous().to(self.device)
            else:
                cond_ = cond_.to(self.device)

            # --- Data Retrieval (Matches collate_fn_pointnet) ---
            point_cloud = batch['point_cloud_128'].to(self.device)
            active_coords = batch['active_voxels_128'].to(self.device)  # [N, 4]

            with autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
                # 1. Encode Ground Truth Latents
                active_voxel_feats = self.voxel_encoder(
                    p=point_cloud,
                    sparse_coords=active_coords,
                    res=128,
                    bbox_size=(-0.5, 0.5),
                )

                sparse_input = SparseTensor(
                    feats=active_voxel_feats,
                    coords=active_coords.int()
                )

                # Encode to get GT distribution (the posterior is not needed here)
                gt_latents, _ = self.vae.encode(sparse_input)

                print(f"Batch {i}: Active voxels: {active_coords.shape[0]}")

                # 2. Generation / Sampling
                # Generate noise on the SAME active coordinates
                noise = SparseTensor_trellis(
                    coords=active_coords.int(),
                    feats=torch.randn(
                        active_coords.shape[0],
                        self.feature_dim,
                        device=self.device,
                    ),
                )

                sample_results = self.sampler.sample(
                    model=self.denoiser.float(),
                    noise=noise.to(self.device).float(),
                    cond=cond_.to(self.device).float(),
                    steps=50,
                    rescale_t=1.0,
                    verbose=True,
                )

                generated_sparse_tensor = sample_results.samples
                generated_features = generated_sparse_tensor.feats

                print('Gen features mean:', generated_features.mean().item(), 'std:', generated_features.std().item())
                print('GT features mean:', gt_latents.feats.mean().item(), 'std:', gt_latents.feats.std().item())
                print('MSE:', F.mse_loss(generated_features, gt_latents.feats).item())

                # --- Visualization (PCA) ---
                # numpy has no bfloat16, which autocast may produce; .float()
                # is a no-op for tensors that are already float32.
                gt_feats_np = gt_latents.feats.detach().float().cpu().numpy()
                gen_feats_np = generated_features.detach().float().cpu().numpy()
                coords_np = active_coords[:, 1:4].detach().cpu().numpy()  # x, y, z

                print("Visualizing features using PCA...")
                pca = PCA(n_components=3)

                # Fit PCA on GT, transform both
                pca.fit(gt_feats_np)
                gt_feats_3d = pca.transform(gt_feats_np)
                gen_feats_3d = pca.transform(gen_feats_np)

                gt_colors = normalize_to_rgb(gt_feats_3d)
                gen_colors = normalize_to_rgb(gen_feats_3d)

                # Save PLYs
                save_colored_ply(coords_np, gt_colors, os.path.join(save_dir, f"batch_{i}_gt_pca.ply"))
                save_colored_ply(coords_np, gen_colors, os.path.join(save_dir, f"batch_{i}_gen_pca.ply"))

                # Save Tensors for further analysis
                torch.save(gt_latents, os.path.join(save_dir, f"gt_latent_{i}.pt"))

                torch.save(batch, os.path.join(save_dir, f"gt_data_batch_{i}.pt"))
                torch.save(sample_results.samples, os.path.join(save_dir, f"sample_latent_{i}.pt"))
|
| 371 |
+
|
| 372 |
+
if __name__ == '__main__':
    # Fixed seed so sampling noise is reproducible between runs.
    torch.manual_seed(42)

    config_path = "/home/tiger/yy/src/Michelangelo-master/config_slat_flow_128to512_pointnet_head_test.yaml"
    with open(config_path) as f:
        cfg = yaml.safe_load(f)

    # Collect the denoiser hyper-parameters from the `flow` section. Every
    # key listed here is mandatory; `use_fp16` alone falls back to False.
    flow_cfg = cfg['flow']
    required_keys = (
        'resolution',
        'in_channels',
        'out_channels',
        'model_channels',
        'cond_channels',
        'num_blocks',
        'num_heads',
        'mlp_ratio',
        'patch_size',
        'num_io_res_blocks',
        'io_block_channels',
        'pe_mode',
        'qk_rms_norm',
        'qk_rms_norm_cross',
    )
    flow_kwargs = {key: flow_cfg[key] for key in required_keys}
    flow_kwargs['use_fp16'] = flow_cfg.get('use_fp16', False)

    # Build the denoiser and place it on the GPU when one is available.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    diffusion_model = SLatFlowModel(**flow_kwargs).to(target_device)

    trainer = SLatFlowMatchingTrainer(
        denoiser=diffusion_model,
        t_schedule=cfg['t_schedule'],
        sigma_min=cfg['sigma_min'],
        cfg=cfg,
    )

    trainer.eval()
|
test_slat_flow_128to512_pointnet_head_tomesh.py
ADDED
|
@@ -0,0 +1,1630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
import plotly.graph_objects as go
|
| 9 |
+
from plotly.subplots import make_subplots
|
| 10 |
+
from torch.utils.data import DataLoader, Subset
|
| 11 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 12 |
+
|
| 13 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 14 |
+
|
| 15 |
+
from vertex_encoder import VoxelFeatureEncoder_edge, VoxelFeatureEncoder_vtx, VoxelFeatureEncoder_active, VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 16 |
+
from utils import load_pretrained_woself
|
| 17 |
+
|
| 18 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
from functools import partial
|
| 22 |
+
import itertools
|
| 23 |
+
from typing import List, Tuple, Set
|
| 24 |
+
from collections import OrderedDict
|
| 25 |
+
from scipy.spatial import cKDTree
|
| 26 |
+
from sklearn.neighbors import KDTree
|
| 27 |
+
|
| 28 |
+
import trimesh
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
import time
|
| 33 |
+
|
| 34 |
+
from sklearn.decomposition import PCA
|
| 35 |
+
import matplotlib.pyplot as plt
|
| 36 |
+
|
| 37 |
+
import networkx as nx
|
| 38 |
+
|
| 39 |
+
def predict_mesh_connectivity(
    connection_head,
    vtx_feats,
    vtx_coords,
    batch_size=10000,
    threshold=0.5,
    k_neighbors=64,  # test only the K nearest neighbours of each vertex; -1 tests all pairs
    device='cuda'
):
    """
    Predict which vertex pairs are joined by an edge.

    Args:
        connection_head: trained pairwise classifier; called on the concatenated
            features of two vertices and expected to return one logit per pair.
        vtx_feats: [N, C] vertex features.
        vtx_coords: [N, 3] vertex coordinates (used only for the KNN candidate search).
        batch_size: number of candidate pairs sent through the head at once.
        threshold: probability above which a pair is reported as connected.
        k_neighbors: number of nearest neighbours to test per vertex. ``None``,
            a non-positive value, or a value >= N tests all N*(N-1)/2 pairs.
        device: device on which the candidate index tensors are created.

    Returns:
        np.ndarray of shape [E, 2] holding the connected pairs as (u, v) with
        u < v. An empty [0, 2] array is returned when fewer than three vertices
        are given or nothing passes the threshold.
    """
    empty_edges = np.empty((0, 2), dtype=np.int64)

    num_verts = vtx_feats.shape[0]
    if num_verts < 3:
        # Too few vertices to form a triangle; keep the return type consistent.
        return empty_edges

    connection_head.eval()

    # --- 1. Candidate edges ---
    if k_neighbors is not None and 0 < k_neighbors < num_verts:
        # Strategy A: local KNN. The dense distance matrix is O(N^2) memory;
        # for very large N a KD-tree based search would be preferable.
        coords = vtx_coords.float()
        dist_mat = torch.cdist(coords, coords)  # [N, N]

        # k + 1 smallest distances; the first column is normally the vertex itself.
        _, indices = torch.topk(dist_mat, k=k_neighbors + 1, dim=1, largest=False)
        neighbor_indices = indices[:, 1:]

        src = torch.arange(num_verts, device=device).unsqueeze(1).repeat(1, k_neighbors).flatten()
        dst = neighbor_indices.flatten()

        # KNN is not symmetric: v may be a neighbour of u without u being a
        # neighbour of v. Canonicalise every pair to (min, max) and deduplicate
        # so such pairs are kept exactly once instead of being discarded.
        lo = torch.minimum(src, dst)
        hi = torch.maximum(src, dst)
        keep = lo < hi  # drops accidental self pairs
        pairs = torch.stack([lo[keep], hi[keep]], dim=1)
        if pairs.shape[0] > 0:
            pairs = torch.unique(pairs, dim=0)
        u_indices, v_indices = pairs[:, 0], pairs[:, 1]
    else:
        # Strategy B: all pairs (O(N^2)); only sensible for small N.
        u_indices, v_indices = torch.triu_indices(num_verts, num_verts, offset=1, device=device)

    num_candidates = u_indices.shape[0]
    if num_candidates == 0:
        return empty_edges

    # --- 2. Batched inference ---
    all_probs = []
    with torch.no_grad():
        for start in range(0, num_candidates, batch_size):
            stop = min(start + batch_size, num_candidates)
            feat_u = vtx_feats[u_indices[start:stop]]
            feat_v = vtx_feats[v_indices[start:stop]]

            # Symmetric forward pass: score both orderings and sum the logits.
            logits_uv = connection_head(torch.cat([feat_u, feat_v], dim=-1))
            logits_vu = connection_head(torch.cat([feat_v, feat_u], dim=-1))
            all_probs.append(torch.sigmoid(logits_uv + logits_vu))

    # reshape(-1) keeps a 1-D result even when there is a single candidate.
    all_probs = torch.cat(all_probs).reshape(-1)  # [M]

    # --- 3. Keep the confident pairs ---
    connected_mask = all_probs > threshold
    final_u = u_indices[connected_mask].cpu().numpy()
    final_v = v_indices[connected_mask].cpu().numpy()

    return np.stack([final_u, final_v], axis=1)  # [E, 2]
|
| 130 |
+
|
| 131 |
+
def build_triangles_from_edges(edges, num_verts):
    """
    Enumerate every triangle (3-clique) of the undirected graph given by ``edges``.

    For each edge (u, v) the common neighbours w of u and v are looked up, and
    only the ordering u < v < w is emitted so each triangle appears once.

    Args:
        edges: sequence or array of (u, v) vertex index pairs. Order within a
            pair does not matter; duplicate pairs and self-loops are ignored.
        num_verts: number of vertices in the graph. Kept for interface
            compatibility; the adjacency is built from ``edges`` alone.

    Returns:
        Integer np.ndarray of shape [T, 3] with rows sorted lexicographically.
        The shape is (0, 3) when there are no edges or no triangles.
    """
    if len(edges) == 0:
        return np.empty((0, 3), dtype=int)

    edge_array = np.asarray(edges, dtype=int).reshape(-1, 2)

    # Undirected adjacency sets plus the set of canonical (u < v) edges.
    adjacency = defaultdict(set)
    unique_edges = set()
    for a, b in edge_array.tolist():
        if a == b:
            continue  # a self-loop can never be part of a triangle
        u, v = (a, b) if a < b else (b, a)
        unique_edges.add((u, v))
        adjacency[u].add(v)
        adjacency[v].add(u)

    triangles = []
    for u, v in sorted(unique_edges):
        for w in sorted(adjacency[u] & adjacency[v]):
            if w > v:  # enforce u < v < w so each triangle is listed once
                triangles.append((u, v, w))

    if not triangles:
        # np.array([]) would be a 1-D float array; keep the (T, 3) int contract.
        return np.empty((0, 3), dtype=int)
    return np.array(triangles, dtype=int)
|
| 167 |
+
|
| 168 |
+
def downsample_voxels(
    voxels: torch.Tensor,
    input_resolution: int,
    output_resolution: int
) -> torch.Tensor:
    """
    Map voxel coordinates onto a coarser grid and drop duplicates.

    Column 0 of ``voxels`` is left untouched; every remaining column is
    floor-divided by ``input_resolution // output_resolution``. The input
    tensor is not modified.

    Raises:
        ValueError: if ``input_resolution`` is not a multiple of ``output_resolution``.

    Returns:
        Long tensor holding the unique rows of the coarsened coordinates.
    """
    remainder = input_resolution % output_resolution
    if remainder != 0:
        raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
                         f"by output_resolution ({output_resolution}).")

    scale = input_resolution // output_resolution

    as_long = voxels.clone().to(torch.long)
    leading = as_long[:, :1]
    coarse = as_long[:, 1:] // scale

    return torch.unique(torch.cat([leading, coarse], dim=1), dim=0)
|
| 185 |
+
|
| 186 |
+
def visualize_colored_points_ply(coords, vectors, filename):
    """
    Export a point cloud to PLY, colouring each point by its direction vector.

    Args:
        coords (torch.Tensor or np.ndarray): 3D positions, shape (N, 3).
        vectors (torch.Tensor or np.ndarray): direction vectors, shape (N, 3).
        filename (str): output path; the cloud is exported in PLY format.
    """
    # Work on numpy arrays from here on.
    if isinstance(coords, torch.Tensor):
        coords = coords.detach().cpu().numpy()
    if isinstance(vectors, torch.Tensor):
        vectors = vectors.detach().cpu().to(torch.float32).numpy()

    # Nothing to export: warn and leave without touching the file system.
    has_points = coords.size != 0 and vectors.size != 0
    if not has_points:
        print(f"警告:输入数据为空,未生成 {filename} 文件。")
        return

    # Components in [-1, 1] are shifted to [0, 2], scaled to [0, 255] and
    # clipped so out-of-range components cannot wrap around in uint8.
    scaled = (vectors + 1) * 127.5
    rgb = np.clip(scaled, 0, 255).astype(np.uint8)

    cloud = trimesh.points.PointCloud(coords, colors=rgb)
    cloud.export(filename, file_type='ply')
    print(f"可视化文件已成功保存为: {filename}")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def compute_vertex_matching(pred_coords, gt_coords, threshold=1.0):
    """
    Greedily match predicted points to ground-truth points, one to one.

    Both point sets are converted to float32 numpy arrays and deduplicated.
    Each predicted point then claims its nearest ground-truth point if that
    point lies within ``threshold`` and has not been claimed already.

    Args:
        pred_coords: tensor of predicted coordinates, one point per row.
        gt_coords: tensor of ground-truth coordinates, one point per row.
        threshold: maximum distance for a prediction to count as a match.

    Returns:
        (matches, match_rate, pred_total, gt_total), where ``match_rate`` is
        ``matches / gt_total`` and the totals count unique points.
    """
    print('len(pred_coords)', len(pred_coords))

    # Deduplicate both sets after casting to float32.
    pred_array = np.unique(pred_coords.detach().to(torch.float32).cpu().numpy(), axis=0)
    gt_array = np.unique(gt_coords.detach().cpu().to(torch.float32).numpy(), axis=0)
    print('len(pred_array)', len(pred_array))
    pred_total = len(pred_array)
    gt_total = len(gt_array)

    # Nothing to match against.
    if pred_total == 0 or gt_total == 0:
        return 0, 0.0, pred_total, gt_total

    # KD-tree over the ground truth; query the nearest neighbour of each prediction.
    tree = KDTree(gt_array)
    dist, indices = tree.query(pred_array, k=1)

    # Flatten with reshape(-1) rather than squeeze(): squeeze() would turn the
    # result for a single prediction into 0-d arrays that cannot be iterated.
    dist = dist.reshape(-1)
    indices = indices.reshape(-1)

    # Greedy one-to-one assignment: the first prediction to reach a ground-truth
    # point claims it.
    matches = 0
    used_gt = set()
    for d, idx in zip(dist, indices):
        if d <= threshold and idx not in used_gt:
            matches += 1
            used_gt.add(idx)

    match_rate = matches / max(gt_total, 1)

    return matches, match_rate, pred_total, gt_total
|
| 253 |
+
|
| 254 |
+
def flatten_coords_4d(coords_4d: torch.Tensor, resolution: int = 512):
    """
    Pack rows of 4 integer coordinates into a single int64 key per row.

    The key is ``c0 * R**3 + c1 * R**2 + c2 * R + c3`` with ``R = resolution``.
    Keys are unique only while columns 1..3 stay in ``[0, resolution)``, so pass
    the real grid size when it exceeds the default of 512.

    Args:
        coords_4d: tensor of shape [N, 4]; values are cast to int64.
        resolution: radix used for packing. The default of 512 reproduces the
            previous hard-coded behaviour.

    Returns:
        int64 tensor of shape [N] with one flattened key per row.
    """
    coords = coords_4d.long()

    # Horner evaluation of the mixed-radix polynomial.
    flat_coords = coords[:, 0]
    for column in (1, 2, 3):
        flat_coords = flat_coords * resolution + coords[:, column]
    return flat_coords
|
| 266 |
+
|
| 267 |
+
class Tester:
|
| 268 |
+
def __init__(self, ckpt_path, config_path=None, dataset_path=None):
    """
    Load config and checkpoint metadata, build models and dataset, create output folders.

    Args:
        ckpt_path: path of the checkpoint file; all output folders are created
            next to it.
        config_path: optional YAML config path, forwarded to ``_load_config``.
        dataset_path: dataset path; its base name (minus ``.npz`` / ``.npy``)
            is embedded in the output folder names.
    """
    has_gpu = torch.cuda.is_available()
    self.device = torch.device("cuda" if has_gpu else "cpu")
    self.ckpt_path = ckpt_path

    self.config = self._load_config(config_path)
    self.dataset_path = dataset_path  # or self.config['dataset']['path']

    # Only the stored epoch number is read from the checkpoint here.
    ckpt_state = torch.load(self.ckpt_path, map_location='cpu')
    self.epoch = ckpt_state.get('epoch', 0)

    self._init_models()
    self._init_dataset()

    ckpt_dir = os.path.dirname(ckpt_path)

    self.result_dir = os.path.join(ckpt_dir, "evaluation_results")
    os.makedirs(self.result_dir, exist_ok=True)

    dataset_file = os.path.basename(self.dataset_path)
    dataset_name_clean = dataset_file.replace('.npz', '').replace('.npy', '')

    self.output_voxel_dir = os.path.join(
        ckpt_dir, f"epoch_{self.epoch}_{dataset_name_clean}_voxels_0_gs"
    )
    os.makedirs(self.output_voxel_dir, exist_ok=True)

    self.output_obj_dir = os.path.join(
        ckpt_dir, f"epoch_{self.epoch}_{dataset_name_clean}_obj_0_gs"
    )
    os.makedirs(self.output_obj_dir, exist_ok=True)
|
| 291 |
+
|
| 292 |
+
def _save_logit_visualization(self, dense_vol, name, sample_name, ply_threshold=0.01):
    """
    Save a 2D max-projection heatmap and a coloured RGBA point cloud for a dense volume.

    Two files are written into ``self.output_voxel_dir``:
    ``{sample_name}_{name}_heatmap.png`` and, when at least one voxel exceeds
    ``ply_threshold``, ``{sample_name}_{name}_logits_colored.ply``. The raw
    volume itself is not saved.

    Args:
        dense_vol: (H, W, D) numpy array, values in [0, 1]
        name: str (e.g., "edge" or "vertex")
        sample_name: str
        ply_threshold: float; only voxels whose value is greater than this are exported.
    """
    # 1. Max-projection heatmaps along each of the three axes.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    image = None
    for axis, (ax, plane) in enumerate(zip(axes, ("YZ", "XZ", "XY"))):
        projection = np.max(dense_vol, axis=axis)
        image = ax.imshow(projection, cmap='turbo', vmin=0, vmax=1, origin='lower')
        ax.set_title(f"{name} Max-Proj ({plane})")

    # All panels share vmin/vmax, so one colour bar serves the whole figure.
    fig.colorbar(image, ax=axes, orientation='vertical', fraction=0.02, pad=0.04)
    fig.suptitle(f"{sample_name} - {name} Occupancy Probability")

    png_path = os.path.join(self.output_voxel_dir, f"{sample_name}_{name}_heatmap.png")
    fig.savefig(png_path, dpi=150)
    plt.close(fig)

    # 2. RGBA point cloud of every voxel above the threshold.
    indices = np.argwhere(dense_vol > ply_threshold)
    if len(indices) == 0:
        return

    values = dense_vol[indices[:, 0], indices[:, 1], indices[:, 2]]

    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9. Calling the colormap maps values to RGBA rows, shape (N, 4).
    colors_float = plt.get_cmap('turbo')(values)

    # Alpha follows the value itself: high values are opaque, low values faint.
    # Clipping keeps the uint8 cast below from wrapping on out-of-range inputs.
    colors_float[:, 3] = np.clip(values, 0.0, 1.0)
    colors_uint8 = (colors_float * 255).astype(np.uint8)

    # Voxel indices are used directly as point positions.
    vertices = indices

    ply_filename = f"{sample_name}_{name}_logits_colored.ply"
    ply_save_path = os.path.join(self.output_voxel_dir, ply_filename)

    try:
        # trimesh accepts (N, 4) RGBA colours.
        pcd = trimesh.points.PointCloud(vertices=vertices, colors=colors_uint8)
        pcd.export(ply_save_path)
        print(f"Saved colored RGBA logit PLY to {ply_save_path}")
    except Exception as e:
        # Best-effort fallback: write an ASCII PLY (with an alpha property) by hand.
        print(f"Failed to save PLY with trimesh: {e}")
        header = [
            "ply\n",
            "format ascii 1.0\n",
            f"element vertex {len(vertices)}\n",
            "property float x\n",
            "property float y\n",
            "property float z\n",
            "property uchar red\n",
            "property uchar green\n",
            "property uchar blue\n",
            "property uchar alpha\n",
            "end_header\n",
        ]
        with open(ply_save_path, 'w') as f:
            f.writelines(header)
            f.writelines(
                f"{v[0]} {v[1]} {v[2]} {c[0]} {c[1]} {c[2]} {c[3]}\n"
                for v, c in zip(vertices, colors_uint8)
            )
|
| 384 |
+
|
| 385 |
+
def _point_line_segment_distance(self, px, py, pz, x1, y1, z1, x2, y2, z2):
|
| 386 |
+
"""
|
| 387 |
+
计算点 (px,py,pz) 到线段 (x1,y1,z1)-(x2,y2,z2) 的最短距离的平方。
|
| 388 |
+
全部输入为 Tensor,支持广播。
|
| 389 |
+
"""
|
| 390 |
+
# 线段向量 AB
|
| 391 |
+
ABx = x2 - x1
|
| 392 |
+
ABy = y2 - y1
|
| 393 |
+
ABz = z2 - z1
|
| 394 |
+
|
| 395 |
+
# 向量 AP
|
| 396 |
+
APx = px - x1
|
| 397 |
+
APy = py - y1
|
| 398 |
+
APz = pz - z1
|
| 399 |
+
|
| 400 |
+
# AB 的长度平方
|
| 401 |
+
AB_sq = ABx**2 + ABy**2 + ABz**2
|
| 402 |
+
|
| 403 |
+
# 避免除以0 (如果两端点重合)
|
| 404 |
+
AB_sq = torch.clamp(AB_sq, min=1e-6)
|
| 405 |
+
|
| 406 |
+
# 投影系数 t = (AP · AB) / |AB|^2
|
| 407 |
+
t = (APx * ABx + APy * ABy + APz * ABz) / AB_sq
|
| 408 |
+
|
| 409 |
+
# 限制 t 在 [0, 1] 之间(线段约束)
|
| 410 |
+
t = torch.clamp(t, 0.0, 1.0)
|
| 411 |
+
|
| 412 |
+
# 最近点 (Projection)
|
| 413 |
+
closestX = x1 + t * ABx
|
| 414 |
+
closestY = y1 + t * ABy
|
| 415 |
+
closestZ = z1 + t * ABz
|
| 416 |
+
|
| 417 |
+
# 距离平方
|
| 418 |
+
dx = px - closestX
|
| 419 |
+
dy = py - closestY
|
| 420 |
+
dz = pz - closestZ
|
| 421 |
+
|
| 422 |
+
return dx**2 + dy**2 + dz**2
|
| 423 |
+
|
| 424 |
+
def _extract_mesh_projection_based(
    self,
    vtx_result: dict,
    edge_result: dict,
    resolution: int = 1024,
    vtx_prob_threshold: float = 0.5,

    # --- Parameters of the projection-based connection test ---
    search_radius: float = 128.0,          # 1. maximum length of a candidate edge
    project_dist_thresh: float = 1.5,      # 2. projection distance threshold (tube radius, in voxels)
    dir_align_threshold: float = 0.6,      # 3. direction similarity threshold (cos theta)
    connect_ratio_threshold: float = 0.4,  # 4. final acceptance threshold (matching voxels / edge length)

    edge_prob_threshold: float = 0.1,      # only used to select the edge voxels that "exist"
):
    """
    Connect predicted vertices by checking each candidate edge against the edge voxels.

    A vertex pair closer than ``search_radius`` is accepted as an edge when
    enough active edge voxels lie within ``project_dist_thresh`` of the segment
    joining the pair AND have a predicted direction aligned with it
    (|cos| > ``dir_align_threshold``). "Enough" means the voxel count divided by
    ``max(segment length, 1)`` exceeds ``connect_ratio_threshold``.

    Args:
        vtx_result: dict with 'occ_probs' and 'coords_4d' for vertex voxels.
            NOTE(review): sigmoid is applied to 'occ_probs'[:, 0], so these are
            presumably logits; column 0 of 'coords_4d' is dropped and is
            presumably a batch index - confirm against the producer.
        edge_result: dict with the same two keys for edge voxels, plus an
            optional 'predicted_direction_feats' entry of per-voxel directions.
        resolution: divisor applied to the returned vertex coordinates.
        vtx_prob_threshold: minimum sigmoid score for a vertex voxel to be kept.
        search_radius, project_dist_thresh, dir_align_threshold,
        connect_ratio_threshold, edge_prob_threshold: see the inline comments
            in the signature.

    Returns:
        ``(vertices, edges)``: ``vertices`` is a numpy array of the kept vertex
        coordinates divided by ``resolution``; ``edges`` is a list of
        ``[i, j]`` index pairs into ``vertices``. Early exits: ``([], [])``
        when no edge voxel is active, and ``(vertices, [])`` when fewer than
        two vertices are kept.
    """
    t_start = time.perf_counter()

    # ---------------------------------------------------------------------
    # 1. Global data: collect all active edge voxels (treated as a point cloud)
    # ---------------------------------------------------------------------
    e_probs = torch.sigmoid(edge_result['occ_probs'][:, 0])
    e_coords = edge_result['coords_4d'][:, 1:].float()  # (N, 3)

    # Per-voxel direction vectors
    if 'predicted_direction_feats' in edge_result:
        e_dirs = edge_result['predicted_direction_feats']  # (N, 3)
        # Normalise to unit length
        e_dirs = F.normalize(e_dirs, p=2, dim=1)
    else:
        # With all-zero directions every dot product below is 0, so no edge
        # can pass the alignment test.
        print("Warning: No direction features, using dummy.")
        e_dirs = torch.zeros_like(e_coords)

    # Keep the valid edge voxels (global point cloud)
    valid_mask = e_probs > edge_prob_threshold

    cloud_coords = e_coords[valid_mask]  # (M, 3)
    cloud_dirs = e_dirs[valid_mask]      # (M, 3)

    num_cloud = cloud_coords.shape[0]
    print(f"[Projection] Global active edge voxels: {num_cloud}")

    if num_cloud == 0:
        return [], []

    # ---------------------------------------------------------------------
    # 2. Vertices and candidate edges
    # ---------------------------------------------------------------------
    v_probs = torch.sigmoid(vtx_result['occ_probs'][:, 0])
    v_coords = vtx_result['coords_4d'][:, 1:].float()
    v_mask = v_probs > vtx_prob_threshold
    valid_v_coords = v_coords[v_mask]  # (V, 3)

    if valid_v_coords.shape[0] < 2:
        return valid_v_coords.cpu().numpy() / resolution, []

    # All candidate pairs (coarse filter by distance); the upper-triangular
    # mask keeps each unordered pair once and excludes self pairs.
    dists = torch.cdist(valid_v_coords, valid_v_coords)
    triu_mask = torch.triu(torch.ones_like(dists), diagonal=1).bool()
    cand_mask = (dists < search_radius) & triu_mask
    cand_indices = torch.nonzero(cand_mask, as_tuple=False)  # (E_cand, 2)

    p1s = valid_v_coords[cand_indices[:, 0]]  # (E, 3)
    p2s = valid_v_coords[cand_indices[:, 1]]  # (E, 3)

    num_candidates = p1s.shape[0]
    print(f"[Projection] Checking {num_candidates} candidate pairs...")

    # ---------------------------------------------------------------------
    # 3. Loop over the candidate edges (bounding-box pre-filter per edge)
    # ---------------------------------------------------------------------
    final_edges = []

    # Direction and length of every candidate edge, computed once up front
    edge_vecs = p2s - p1s
    edge_lengths = torch.norm(edge_vecs, dim=1)
    edge_dirs = F.normalize(edge_vecs, p=2, dim=1)

    # Design note: a dense (edges x cloud points) distance matrix risks running
    # out of memory, so the code loops over edges in Python and restricts each
    # edge to the cloud points inside its bounding box before doing tensor math.

    # Split the cloud into per-axis tensors for the bounding-box test
    cx, cy, cz = cloud_coords[:, 0], cloud_coords[:, 1], cloud_coords[:, 2]

    # Number of candidate edges sliced per outer iteration; the edges inside a
    # slice are still processed one at a time by the inner loop.
    batch_size = 128

    for i in range(0, num_candidates, batch_size):
        end = min(i + batch_size, num_candidates)

        # Edge data of the current slice
        b_p1 = p1s[i:end]             # (B, 3)
        b_p2 = p2s[i:end]             # (B, 3)
        b_dirs = edge_dirs[i:end]     # (B, 3)
        b_lens = edge_lengths[i:end]  # (B,)

        # NOTE(review): assigned but never read below.
        current_edges_indices = cand_indices[i:end]

        for j in range(len(b_p1)):
            # Single edge
            p1 = b_p1[j]
            p2 = b_p2[j]
            e_dir = b_dirs[j]
            e_len = b_lens[j].item()

            # 1. Bounding-box filter: cloud points inside the padded box of the edge
            padding = project_dist_thresh + 2.0
            min_xyz = torch.min(p1, p2) - padding
            max_xyz = torch.max(p1, p2) + padding

            # Boolean masks per axis
            mask_x = (cx >= min_xyz[0]) & (cx <= max_xyz[0])
            mask_y = (cy >= min_xyz[1]) & (cy <= max_xyz[1])
            mask_z = (cz >= min_xyz[2]) & (cz <= max_xyz[2])
            bbox_mask = mask_x & mask_y & mask_z

            subset_coords = cloud_coords[bbox_mask]
            subset_dirs = cloud_dirs[bbox_mask]

            if subset_coords.shape[0] == 0:
                continue

            # 2. Exact projection distance: squared distance of every subset
            #    point to the segment p1-p2
            dist_sq = self._point_line_segment_distance(
                subset_coords[:, 0], subset_coords[:, 1], subset_coords[:, 2],
                p1[0], p1[1], p1[2],
                p2[0], p2[1], p2[2]
            )

            # 3. Distance threshold (keep voxels inside the tube)
            dist_mask = dist_sq < (project_dist_thresh ** 2)

            # Directions of the voxels inside the tube
            tube_dirs = subset_dirs[dist_mask]

            if tube_dirs.shape[0] == 0:
                continue

            # 4. Direction check: dot product (cos theta) between each voxel
            #    direction (K, 3) and the edge direction (3,)
            dot_prod = torch.matmul(tube_dirs, e_dir)

            # abs() treats a direction and its opposite as equally aligned
            dir_sim = torch.abs(dot_prod)

            # Number of voxels whose direction is aligned well enough
            valid_voxel_count = (dir_sim > dir_align_threshold).sum().item()

            # 5. Ratio test: the edge length stands in for the expected voxel
            #    count; the floor of 1 guards against very short edges
            theoretical_count = max(e_len, 1.0)

            ratio = valid_voxel_count / theoretical_count

            if ratio > connect_ratio_threshold:
                # Accepted: record the pair of vertex indices
                global_idx = i + j
                edge_tuple = cand_indices[global_idx].cpu().numpy().tolist()
                final_edges.append(edge_tuple)

    t_end = time.perf_counter()
    print(f"[Projection] Logic finished. Accepted {len(final_edges)} edges. Time={t_end - t_start:.4f}s")

    out_vertices = valid_v_coords.cpu().numpy() / resolution
    return out_vertices, final_edges
|
| 617 |
+
|
| 618 |
+
def _save_voxel_ply(self, coords: torch.Tensor, labels: torch.Tensor, filename: str):
|
| 619 |
+
if coords.numel() == 0:
|
| 620 |
+
return
|
| 621 |
+
|
| 622 |
+
coords_np = coords.cpu().to(torch.float32).numpy()
|
| 623 |
+
labels_np = labels.cpu().to(torch.float32).numpy()
|
| 624 |
+
|
| 625 |
+
colors = np.zeros((coords_np.shape[0], 3), dtype=np.uint8)
|
| 626 |
+
colors[labels_np == 0] = [255, 0, 0]
|
| 627 |
+
colors[labels_np == 1] = [0, 0, 255]
|
| 628 |
+
|
| 629 |
+
try:
|
| 630 |
+
import trimesh
|
| 631 |
+
point_cloud = trimesh.PointCloud(vertices=coords_np, colors=colors)
|
| 632 |
+
ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
|
| 633 |
+
point_cloud.export(ply_path)
|
| 634 |
+
except ImportError:
|
| 635 |
+
ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
|
| 636 |
+
with open(ply_path, 'w') as f:
|
| 637 |
+
f.write("ply\n")
|
| 638 |
+
f.write("format ascii 1.0\n")
|
| 639 |
+
f.write(f"element vertex {coords_np.shape[0]}\n")
|
| 640 |
+
f.write("property float x\n")
|
| 641 |
+
f.write("property float y\n")
|
| 642 |
+
f.write("property float z\n")
|
| 643 |
+
f.write("property uchar red\n")
|
| 644 |
+
f.write("property uchar green\n")
|
| 645 |
+
f.write("property uchar blue\n")
|
| 646 |
+
f.write("end_header\n")
|
| 647 |
+
for i in range(coords_np.shape[0]):
|
| 648 |
+
f.write(f"{coords_np[i,0]} {coords_np[i,1]} {coords_np[i,2]} {colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
|
| 649 |
+
|
| 650 |
+
def _load_config(self, config_path=None):
|
| 651 |
+
if config_path and os.path.exists(config_path):
|
| 652 |
+
with open(config_path) as f:
|
| 653 |
+
return yaml.safe_load(f)
|
| 654 |
+
|
| 655 |
+
ckpt_dir = os.path.dirname(self.ckpt_path)
|
| 656 |
+
possible_configs = [
|
| 657 |
+
os.path.join(ckpt_dir, "config.yaml"),
|
| 658 |
+
os.path.join(os.path.dirname(ckpt_dir), "config.yaml")
|
| 659 |
+
]
|
| 660 |
+
|
| 661 |
+
for config_file in possible_configs:
|
| 662 |
+
if os.path.exists(config_file):
|
| 663 |
+
with open(config_file) as f:
|
| 664 |
+
print(f"Loaded config from: {config_file}")
|
| 665 |
+
return yaml.safe_load(f)
|
| 666 |
+
|
| 667 |
+
checkpoint = torch.load(self.ckpt_path, map_location='cpu')
|
| 668 |
+
if 'config' in checkpoint:
|
| 669 |
+
print("Loaded config from checkpoint")
|
| 670 |
+
return checkpoint['config']
|
| 671 |
+
|
| 672 |
+
raise FileNotFoundError("Could not find config_edge.yaml in checkpoint directory or parent, and config not saved in checkpoint.")
|
| 673 |
+
|
| 674 |
+
def _init_models(self):
    """Build the three networks, restore their weights and switch them to eval mode.

    Creates the voxel feature encoder, the connection head and the VAE on
    ``self.device``, loads weights from ``self.ckpt_path`` through
    ``load_pretrained_woself`` and then scans every parameter for NaN/Inf.

    Raises:
        ValueError: If any loaded model contains a NaN or Inf parameter.
    """
    device = self.device

    encoder_kwargs = dict(
        in_channels=15,
        hidden_dim=256,
        out_channels=1024,
        scatter_type='mean',
        n_blocks=5,
        resolution=128,
    )
    self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(**encoder_kwargs).to(device)

    head_kwargs = dict(channels=128 * 2, out_channels=1, mlp_ratio=4)
    self.connection_head = ConnectionHead(**head_kwargs).to(device)

    # Ablation alternative mentioned by the author: VoxelVAE_1volume_dilation.
    model_cfg = self.config['model']
    vae_kwargs = dict(
        in_channels=model_cfg['in_channels'],
        latent_dim=model_cfg['latent_dim'],
        encoder_blocks=model_cfg['encoder_blocks'],
        decoder_blocks_vtx=model_cfg['decoder_blocks_vtx'],
        decoder_blocks_edge=model_cfg['decoder_blocks_edge'],
        num_heads=8,
        num_head_channels=64,
        mlp_ratio=4.0,
        attn_mode="swin",
        window_size=8,
        pe_mode="ape",
        use_fp16=False,
        use_checkpoint=False,
        qk_rms_norm=False,
        using_subdivide=True,
        using_attn=model_cfg['using_attn'],
        attn_first=model_cfg.get('attn_first', True),
        pred_direction=model_cfg.get('pred_direction', False),
    )
    self.vae = VoxelVAE(**vae_kwargs).to(device)

    load_pretrained_woself(
        checkpoint_path=self.ckpt_path,
        voxel_encoder=self.voxel_encoder,
        connection_head=self.connection_head,
        vae=self.vae,
    )

    # Weight sanity check: scan all three models (no short-circuit) so that
    # every offending parameter is reported before aborting.
    print(f"--- 正在检查权重文件中的 NaN/Inf 值... ---")
    scan_results = [
        self._check_weights_for_nan_inf(module, label)
        for module, label in (
            (self.vae, "VoxelVAE"),
            (self.voxel_encoder, "Vertex Encoder"),
            (self.connection_head, "Connection Head"),
        )
    ]
    if any(scan_results):
        # Evaluation cannot continue with corrupted weights.
        raise ValueError(f"在检查点 '{self.ckpt_path}' 中发现了 NaN 或 Inf 值。请检查导致训练不稳定的权重文件。")
    print("--- 权重检查通过。未发现 NaN/Inf 值。 ---")

    for module in (self.vae, self.voxel_encoder, self.connection_head):
        module.eval()
|
| 741 |
+
|
| 742 |
+
def _init_dataset(self):
    """Create the evaluation dataset and a single-sample, unshuffled DataLoader over it.

    Most dataset options come from ``self.config['dataset']``; the cache
    directory, ``filter_active_voxels``, the active voxel resolution and the
    point-cloud sample count are fixed here.
    """
    dataset_kwargs = {'root_dir': self.dataset_path}
    ds_cfg = self.config['dataset']
    dataset_kwargs.update(
        base_resolution=ds_cfg['base_resolution'],
        min_resolution=ds_cfg['min_resolution'],
        # NOTE(review): hard-coded cache location; the configured
        # ds_cfg['cache_dir'] is deliberately bypassed here.
        cache_dir='/home/tiger/yy/src/dataset_cache/test_15c_dora',
        renders_dir=ds_cfg['renders_dir'],
        # The configured 'filter_active_voxels' value is overridden with False.
        filter_active_voxels=False,
        cache_filter_path=ds_cfg['cache_filter_path'],
        sample_type=ds_cfg['sample_type'],
        active_voxel_res=128,
        pc_sample_number=819200,
    )
    self.dataset = VoxelVertexDataset_edge(**dataset_kwargs)

    loader_kwargs = dict(
        batch_size=1,
        shuffle=False,
        collate_fn=partial(collate_fn_pointnet),
        num_workers=0,
        pin_memory=True,
    )
    self.dataloader = DataLoader(self.dataset, **loader_kwargs)
|
| 770 |
+
|
| 771 |
+
def _check_weights_for_nan_inf(self, model: torch.nn.Module, model_name: str) -> bool:
|
| 772 |
+
"""
|
| 773 |
+
检查模型的所有参数中是否存在 NaN 或 Inf 值。
|
| 774 |
+
|
| 775 |
+
Args:
|
| 776 |
+
model (torch.nn.Module): 要检查的模型。
|
| 777 |
+
model_name (str): 模型的名称,用于打印日志。
|
| 778 |
+
|
| 779 |
+
Returns:
|
| 780 |
+
bool: 如果找到 NaN 或 Inf,则返回 True,否则返回 False。
|
| 781 |
+
"""
|
| 782 |
+
found_issue = False
|
| 783 |
+
for name, param in model.named_parameters():
|
| 784 |
+
if torch.isnan(param.data).any():
|
| 785 |
+
print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 NaN 值!")
|
| 786 |
+
found_issue = True
|
| 787 |
+
if torch.isinf(param.data).any():
|
| 788 |
+
print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 Inf 值!")
|
| 789 |
+
found_issue = True
|
| 790 |
+
return found_issue
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
|
| 794 |
+
"""
|
| 795 |
+
修改后的函数,确保一对一匹配,并优先匹配最近的点对。
|
| 796 |
+
"""
|
| 797 |
+
pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
|
| 798 |
+
gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
|
| 799 |
+
|
| 800 |
+
pred_total = len(pred_array)
|
| 801 |
+
gt_total = len(gt_array)
|
| 802 |
+
|
| 803 |
+
if pred_total == 0 or gt_total == 0:
|
| 804 |
+
return {
|
| 805 |
+
'recall': 0.0,
|
| 806 |
+
'precision': 0.0,
|
| 807 |
+
'f1': 0.0,
|
| 808 |
+
'matches': 0,
|
| 809 |
+
'pred_count': pred_total,
|
| 810 |
+
'gt_count': gt_total
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
# 依然在预测点上构建KD-Tree,为每个真实点查找最近的预测点
|
| 814 |
+
tree = cKDTree(pred_array)
|
| 815 |
+
dists, pred_idxs = tree.query(gt_array, k=1)
|
| 816 |
+
|
| 817 |
+
# --- 核心修改部分 ---
|
| 818 |
+
|
| 819 |
+
# 1. 创建一个列表,包含 (距离, 真实点索引, 预测点索引)
|
| 820 |
+
# 这样我们就可以按距离对所有可能的匹配进行排序
|
| 821 |
+
possible_matches = []
|
| 822 |
+
for gt_idx, (dist, pred_idx) in enumerate(zip(dists, pred_idxs)):
|
| 823 |
+
if dist <= threshold:
|
| 824 |
+
possible_matches.append((dist, gt_idx, pred_idx))
|
| 825 |
+
|
| 826 |
+
# 2. 按距离从小到大排序(贪心策略)
|
| 827 |
+
possible_matches.sort(key=lambda x: x[0])
|
| 828 |
+
|
| 829 |
+
matches = 0
|
| 830 |
+
# 使用集合来跟踪已经使用过的预测点和真实点,确保一对一匹配
|
| 831 |
+
used_pred_indices = set()
|
| 832 |
+
used_gt_indices = set() # 虽然当前逻辑下gt不会重复,但加上更严谨
|
| 833 |
+
|
| 834 |
+
# 3. 遍历排序后的可能匹配,进行一对一分配
|
| 835 |
+
for dist, gt_idx, pred_idx in possible_matches:
|
| 836 |
+
# 如果这个预测点和这个真实点都还没有被使用过
|
| 837 |
+
if pred_idx not in used_pred_indices and gt_idx not in used_gt_indices:
|
| 838 |
+
matches += 1
|
| 839 |
+
used_pred_indices.add(pred_idx)
|
| 840 |
+
used_gt_indices.add(gt_idx)
|
| 841 |
+
|
| 842 |
+
# --- 修改结束 ---
|
| 843 |
+
|
| 844 |
+
# matches 现在是真正的 True Positives 数量,它绝不会超过 pred_total 或 gt_total
|
| 845 |
+
recall = matches / gt_total if gt_total > 0 else 0.0
|
| 846 |
+
precision = matches / pred_total if pred_total > 0 else 0.0
|
| 847 |
+
|
| 848 |
+
# 计算F1时,使用标准的 Precision 和 Recall 定义
|
| 849 |
+
if (precision + recall) == 0:
|
| 850 |
+
f1 = 0.0
|
| 851 |
+
else:
|
| 852 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 853 |
+
|
| 854 |
+
return {
|
| 855 |
+
'recall': recall,
|
| 856 |
+
'precision': precision,
|
| 857 |
+
'f1': f1,
|
| 858 |
+
'matches': matches,
|
| 859 |
+
'pred_count': pred_total,
|
| 860 |
+
'gt_count': gt_total
|
| 861 |
+
}
|
| 862 |
+
|
| 863 |
+
def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
|
| 864 |
+
"""
|
| 865 |
+
一个折衷的顶点指标计算方案。
|
| 866 |
+
它沿用“为每个真实点寻找最近预测点”的逻辑,
|
| 867 |
+
但通过修正计算方式,确保Precision和F1值不会超过1.0。
|
| 868 |
+
"""
|
| 869 |
+
# 假设 pred_coords 和 gt_coords 是 PyTorch 张量
|
| 870 |
+
pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
|
| 871 |
+
gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
|
| 872 |
+
|
| 873 |
+
pred_total = len(pred_array)
|
| 874 |
+
gt_total = len(gt_array)
|
| 875 |
+
|
| 876 |
+
if pred_total == 0 or gt_total == 0:
|
| 877 |
+
return {
|
| 878 |
+
'recall': 0.0,
|
| 879 |
+
'precision': 0.0,
|
| 880 |
+
'f1': 0.0,
|
| 881 |
+
'matches': 0,
|
| 882 |
+
'pred_count': pred_total,
|
| 883 |
+
'gt_count': gt_total
|
| 884 |
+
}
|
| 885 |
+
|
| 886 |
+
# 在预测点上构建KD-Tree,为每个真实点查找最近的预测点
|
| 887 |
+
tree = cKDTree(pred_array)
|
| 888 |
+
dists, _ = tree.query(gt_array, k=1) # 我们在这里其实不需要 pred 的索引
|
| 889 |
+
|
| 890 |
+
# 1. 计算从 gt 角度出发的匹配数 (True Positives for Recall)
|
| 891 |
+
# 这和您的第一个函数完全一样。
|
| 892 |
+
# 这个值代表了“有多少个真实点被成功找到了”。
|
| 893 |
+
matches_from_gt = np.sum(dists <= threshold)
|
| 894 |
+
|
| 895 |
+
# 2. 计算 Recall (召回率)
|
| 896 |
+
# 召回率的分母是真实点的总数,所以这里的计算是合理的。
|
| 897 |
+
recall = matches_from_gt / gt_total if gt_total > 0 else 0.0
|
| 898 |
+
|
| 899 |
+
# 3. 计算 Precision (精确率) - ✅ 这是核心修正点
|
| 900 |
+
# 精确率的分母是预测点的总数。
|
| 901 |
+
# 分子(True Positives)不能超过预测点的总数。
|
| 902 |
+
# 因此,我们取 matches_from_gt 和 pred_total 中的较小值。
|
| 903 |
+
# 这解决了 Precision > 1 的问题。
|
| 904 |
+
tp_for_precision = min(matches_from_gt, pred_total)
|
| 905 |
+
precision = tp_for_precision / pred_total if pred_total > 0 else 0.0
|
| 906 |
+
|
| 907 |
+
# 4. 使用标准的F1分数公式
|
| 908 |
+
# 您原来的 F1 公式 `2 * matches / (pred + gt)` 是 L1-Score,
|
| 909 |
+
# 更常用的是基于 Precision 和 Recall 的调和平均数。
|
| 910 |
+
if (precision + recall) == 0:
|
| 911 |
+
f1 = 0.0
|
| 912 |
+
else:
|
| 913 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 914 |
+
|
| 915 |
+
return {
|
| 916 |
+
'recall': recall,
|
| 917 |
+
'precision': precision,
|
| 918 |
+
'f1': f1,
|
| 919 |
+
'matches': matches_from_gt, # 仍然报告原始的匹配数,便于观察
|
| 920 |
+
'pred_count': pred_total,
|
| 921 |
+
'gt_count': gt_total
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
def _compute_chamfer_distance(self, p1: torch.Tensor, p2: torch.Tensor, one_sided: bool = False):
|
| 925 |
+
if len(p1) == 0 or len(p2) == 0:
|
| 926 |
+
return float('nan')
|
| 927 |
+
|
| 928 |
+
dist_p1_p2 = torch.min(torch.cdist(p1, p2), dim=1)[0].mean()
|
| 929 |
+
|
| 930 |
+
if one_sided:
|
| 931 |
+
return dist_p1_p2.item()
|
| 932 |
+
else:
|
| 933 |
+
dist_p2_p1 = torch.min(torch.cdist(p2, p1), dim=1)[0].mean()
|
| 934 |
+
return (dist_p1_p2 + dist_p2_p1).item() / 2
|
| 935 |
+
|
| 936 |
+
def visualize_latent_space_pca(self, sample_idx: int):
    """Export a PCA-coloured point cloud of one sample's latent grid.

    The sample is encoded, its latent feature vectors are reduced to three
    principal components, and a PLY file is written to
    ``self.output_voxel_dir``: point positions are the latent grid
    coordinates, point colours are the three components, each rescaled
    independently to the 0-255 range.

    Args:
        sample_idx: Index of the dataset sample to visualise.
    """
    print(f"--- Starting Latent Space PCA Visualization for Sample {sample_idx} ---")
    self.vae.eval()

    # Step 1: latent representation of the requested sample.
    try:
        latent = self._get_latent_for_sample(sample_idx)
    except ValueError as e:
        print(f"Error: {e}")
        return

    latent_coords = latent.coords.detach().cpu().numpy()
    latent_feats = latent.feats.detach().cpu().numpy()
    n_points, n_dims = latent_feats.shape[0], latent_feats.shape[1]

    if n_points < 3:
        print(f"Warning: Not enough latent points ({n_points}) to perform PCA. Skipping.")
        return

    print(f"--> Performing PCA on {n_points} latent vectors of dimension {n_dims}...")

    # Step 2: project the features onto three principal components.
    pca = PCA(n_components=3)
    projected = pca.fit_transform(latent_feats)

    print(f"    Explained variance ratio by 3 components: {pca.explained_variance_ratio_}")
    print(f"    Total explained variance: {np.sum(pca.explained_variance_ratio_):.4f}")

    # Step 3: min-max scale each component on its own for maximum contrast;
    # a (near-)constant component maps to mid-grey.
    unit_colors = np.zeros_like(projected)
    for channel in (0, 1, 2):
        column = projected[:, channel]
        lo = column.min()
        hi = column.max()
        if hi - lo > 1e-6:
            unit_colors[:, channel] = (column - lo) / (hi - lo)
        else:
            unit_colors[:, channel] = 0.5

    colors_uint8 = (unit_colors * 255).astype(np.uint8)

    # Step 4: drop the first coordinate column (batch index) to keep x, y, z.
    spatial_coords = latent_coords[:, 1:]

    # Step 5: write the coloured point cloud.
    try:
        cloud = trimesh.points.PointCloud(vertices=spatial_coords, colors=colors_uint8)
        ply_path = os.path.join(self.output_voxel_dir, f"sample_{sample_idx}_latent_pca.ply")
        cloud.export(ply_path)
        print(f"--> Successfully saved PCA visualization to: {ply_path}")
    except Exception as e:
        print(f"Error during Trimesh export: {e}")
        print("Please ensure 'trimesh' is installed correctly.")
|
| 1006 |
+
|
| 1007 |
+
def _get_latent_for_sample(self, sample_idx: int) -> SparseTensor:
    """Encode one dataset sample and return its latent sparse tensor.

    Args:
        sample_idx: Index into ``self.dataset``.

    Returns:
        The latent produced by ``self.vae.encode`` with
        ``sample_posterior=True``.

    Raises:
        ValueError: If the dataset returns ``None`` for this index.
    """
    print(f"--> Encoding sample {sample_idx} to get its latent vector...")

    sample = self.dataset[sample_idx]
    if sample is None:
        raise ValueError(f"Sample at index {sample_idx} could not be loaded.")

    # Wrap the single sample into a batch with the regular collate function.
    batch = collate_fn_pointnet([sample])

    with torch.no_grad():
        active_coords = batch['active_voxels_128'].to(self.device)
        cloud = batch['point_cloud_128'].to(self.device)

        voxel_feats = self.voxel_encoder(
            p=cloud,
            sparse_coords=active_coords,
            res=128,
            bbox_size=(-0.5, 0.5),
        )
        encoder_input = SparseTensor(feats=voxel_feats, coords=active_coords.int())

        # Only the latent is needed; the posterior is discarded.
        latent_128, _ = self.vae.encode(encoder_input, sample_posterior=True,)
        print(f"    Latent for sample {sample_idx} obtained. Shape: {latent_128.feats.shape}")
        return latent_128
|
| 1040 |
+
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
def evaluate(self, num_samples=None, visualize=False, chamfer_threshold=0.9, threshold=1.):
    """Decode pre-saved latents, compare against ground truth and dump metrics to YAML.

    For each selected sample the ground-truth batch and a latent are loaded
    from ``.pt`` files in a fixed directory, the latent is decoded by the
    VAE, diagnostic PLY files are written next to the latent file, and
    vertex / edge recall, precision and F1 are computed at resolutions 128,
    256 and 512. All metric values are stored as builtin floats so the YAML
    result file contains plain numbers rather than numpy object tags.

    NOTE(review): the two input paths below are hard-coded absolute paths
    indexed by ``batch_idx``; the batch produced by the DataLoader is
    discarded in favour of the saved file, and the latent produced by
    ``self.vae.encode`` is replaced by the loaded one.

    Args:
        num_samples: Number of randomly drawn samples to evaluate; a falsy
            value evaluates the whole dataset in order.
        visualize: If True, also predict vertex connectivity and export a
            reconstructed mesh per sample.
        chamfer_threshold: Not used by this method.
        threshold: Maximum distance for a point to count as matched.

    Returns:
        dict with 'config', 'checkpoint', 'dataset', 'num_samples',
        'per_sample_metrics' and 'avg_metrics'. The same dict is written to
        ``<self.result_dir>/evaluation_results_epoch<self.epoch>.yaml``.
    """
    total_samples = len(self.dataset)
    eval_samples = min(num_samples or total_samples, total_samples)
    sample_indices = random.sample(range(total_samples), eval_samples) if num_samples else range(total_samples)

    eval_dataset = Subset(self.dataset, sample_indices)
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=partial(collate_fn_pointnet),
        num_workers=self.config['training']['num_workers'],
        pin_memory=True,
    )

    per_sample_metrics = {
        'vertex': {res: [] for res in [128, 256, 512]},
        'edge': {res: [] for res in [128, 256, 512]},
        'sample_names': []
    }
    avg_metrics = {
        'vertex': {res: defaultdict(list) for res in [128, 256, 512]},
        'edge': {res: defaultdict(list) for res in [128, 256, 512]},
    }

    self.vae.eval()

    for batch_idx, batch_data in enumerate(tqdm(eval_loader, desc="Evaluating")):
        if batch_data is None:
            continue
        sample_idx = sample_indices[batch_idx]
        sample_name = f'sample_{sample_idx}'
        # NOTE(review): the name is recorded before the checks below, so
        # 'sample_names' can be longer than the per-resolution metric lists.
        per_sample_metrics['sample_names'].append(sample_name)

        batch_save_path = f"/home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope/143500_sample_active_vis_42seed_trellis/gt_data_batch_{batch_idx}.pt"
        if not os.path.exists(batch_save_path):
            print(f"Warning: Saved batch file not found: {batch_save_path}")
            continue
        batch_data = torch.load(batch_save_path, map_location=self.device)

        with torch.no_grad():
            # 1. Ground-truth tensors from the saved batch.
            combined_voxels_512 = batch_data['combined_voxels_512'].to(self.device)
            combined_voxel_labels_512 = batch_data['combined_voxel_labels_512'].to(self.device)
            gt_combined_endpoints_512 = batch_data['gt_combined_endpoints_512'].to(self.device)
            gt_combined_errors_512 = batch_data['gt_combined_errors_512'].to(self.device)

            edge_mask = (combined_voxel_labels_512 == 1)

            gt_edge_endpoints_512 = gt_combined_endpoints_512[edge_mask].to(self.device)
            gt_edge_voxels_512 = combined_voxels_512[edge_mask].to(self.device)

            p1 = gt_edge_endpoints_512[:, 1:4].float()
            p2 = gt_edge_endpoints_512[:, 4:7].float()

            # Order each endpoint pair lexicographically (x, then y, then z)
            # so the ground-truth direction has a canonical sign.
            mask = ((p1[:, 0] < p2[:, 0]) |
                    ((p1[:, 0] == p2[:, 0]) & (p1[:, 1] < p2[:, 1])) |
                    ((p1[:, 0] == p2[:, 0]) & (p1[:, 1] == p2[:, 1]) & (p1[:, 2] <= p2[:, 2])))

            pA = torch.where(mask[:, None], p1, p2)  # smaller endpoint
            pB = torch.where(mask[:, None], p2, p1)  # larger endpoint

            d = pB - pA
            dir_gt = F.normalize(d, dim=-1, eps=1e-6)

            gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'].to(self.device).int()

            # vtx_128 / vtx_256 are not read again below; the calls are kept as in the original.
            vtx_128 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=128)
            vtx_256 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=256)

            edge_128 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=128)
            edge_256 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=256)
            edge_512 = combined_voxels_512

            # Ground-truth edge voxels indexed like [128, 256, 512].
            gt_edge_voxels_list = [
                edge_128,
                edge_256,
                edge_512,
            ]

            active_coords = batch_data['active_voxels_128'].to(self.device)
            point_cloud = batch_data['point_cloud_128'].to(self.device)

            active_voxel_feats = self.voxel_encoder(
                p=point_cloud,
                sparse_coords=active_coords,
                res=128,
                bbox_size=(-0.5, 0.5),
            )

            sparse_input = SparseTensor(
                feats=active_voxel_feats,
                coords=active_coords.int()
            )

            # The encoder output is only used for the posterior statistics
            # printed below; the latent itself is replaced by the loaded one.
            latent_128, posterior = self.vae.encode(sparse_input)

            load_path = f'/home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope/143500_sample_active_vis_42seed_trellis/sample_latent_{batch_idx}.pt'
            latent_128 = torch.load(load_path, map_location=self.device)

            print('latent_128.feats.mean()', latent_128.feats.mean(), 'latent_128.feats.std()', latent_128.feats.std())
            print('posterior.mean', posterior.mean.mean(), 'posterior.std', posterior.std.mean(), 'posterior.var', posterior.var.mean())
            print('latent_128.coords.shape', latent_128.coords.shape)

            # Noise scale is 0, so the features are unchanged; the randn call
            # is kept so the random number stream stays as before.
            latent_128 = SparseTensor(
                coords=latent_128.coords,
                feats=latent_128.feats + 0. * torch.randn_like(latent_128.feats),
            )

            # All outputs of this sample go next to the loaded latent file.
            self.output_voxel_dir = os.path.dirname(load_path)
            self.output_obj_dir = os.path.dirname(load_path)

            # 2. Decode vertices and edges.
            decoded_results = self.vae.decode(
                latent_128,
                gt_vertex_voxels_list=[],
                gt_edge_voxels_list=[],
                training=False,
                inference_threshold=0.5,
                vis_last_layer=False,
            )

            if self.config['model'].get('pred_direction', False):
                pred_dir = decoded_results[-1]['edge']['predicted_direction_feats']
                zero_mask = (pred_dir == 0).all(dim=1)  # True where a whole row is zero
                num_zeros = zero_mask.sum().item()
                print("Number of zero vectors:", num_zeros)

                pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
                print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
                print('pred_dir.shape', pred_dir.shape)
                if pred_edge_coords_3d.shape[-1] == 4:
                    pred_edge_coords_3d = pred_edge_coords_3d[:, 1:]

                save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction.ply")
                visualize_colored_points_ply(pred_edge_coords_3d, pred_dir, save_pth)

                save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction_gt.ply")
                visualize_colored_points_ply((gt_edge_voxels_512[:, 1:]), dir_gt, save_pth)

            pred_vtx_coords_3d = decoded_results[-1]['vertex']['coords']
            pred_edge_coords_3d = decoded_results[-1]['edge']['coords']

            gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'][:, 1:].to(self.device)
            gt_edge_voxels_512 = batch_data['gt_edge_voxels_512'][:, 1:].to(self.device)

            # 3. Full-resolution vertex matching summary and PLY dumps.
            matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_vtx_coords_3d, gt_vertex_voxels_512, threshold=threshold,)
            print(f"\n----- Resolution {512} vtx -----")
            print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
            print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")

            self._save_voxel_ply(pred_vtx_coords_3d / 512., torch.zeros(len(pred_vtx_coords_3d)), f"{sample_name}_pred_vtx")
            self._save_voxel_ply((pred_edge_coords_3d) / 512, torch.zeros(len(pred_edge_coords_3d)), f"{sample_name}_pred_edge")

            self._save_voxel_ply(gt_vertex_voxels_512 / 512, torch.zeros(len(gt_vertex_voxels_512)), f"{sample_name}_gt_vertex")
            self._save_voxel_ply((combined_voxels_512[:, 1:]) / 512., torch.zeros(len(gt_combined_errors_512)), f"{sample_name}_gt_edge")

            # 4. Full-resolution edge-voxel matching summary.
            matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_edge_coords_3d, combined_voxels_512[:, 1:], threshold=threshold,)
            print(f"\n----- Resolution {512} edge -----")
            print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
            print('gt_edge_voxels_512.shape', gt_edge_voxels_512.shape)

            print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
            print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")

            if visualize:
                if pred_vtx_coords_3d.shape[-1] == 4:
                    pred_vtx_coords_float = pred_vtx_coords_3d[:, 1:].float()
                else:
                    pred_vtx_coords_float = pred_vtx_coords_3d.float()

                pred_vtx_feats = decoded_results[-1]['vertex']['feats']

                # Link prediction and mesh generation.
                print("Predicting connectivity...")

                # k_neighbors=None is passed through unchanged; how the
                # helper interprets it is not visible here.
                pred_edges = predict_mesh_connectivity(
                    connection_head=self.connection_head,
                    vtx_feats=pred_vtx_feats,
                    vtx_coords=pred_vtx_coords_float,
                    batch_size=4096,
                    threshold=0.5,
                    k_neighbors=None,
                    device=self.device
                )
                print(f"Predicted {len(pred_edges)} edges.")

                # Build triangles from the predicted edge list.
                num_verts = pred_vtx_coords_float.shape[0]
                pred_faces = build_triangles_from_edges(pred_edges, num_verts)
                print(f"Constructed {len(pred_faces)} triangles.")

                import trimesh

                # Scale 0-512 voxel coordinates to [0, 1), then centre on the origin.
                mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0
                mesh_verts = mesh_verts - 0.5

                mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
                trimesh.repair.fix_normals(mesh)

                output_obj_path = os.path.join(self.output_voxel_dir, f"{sample_name}_recon.obj")
                mesh.export(output_obj_path)
                print(f"Saved mesh to {output_obj_path}")

            # 5. Per-resolution vertex and edge metrics.
            for i, res in enumerate([128, 256, 512]):
                if i >= len(decoded_results):
                    continue

                gt_key = f'gt_vertex_voxels_{res}'
                if gt_key not in batch_data:
                    continue

                # The coarsest level exposes predictions as a sparse tensor
                # (leading column dropped); finer levels expose coordinates directly.
                if i == 0:
                    pred_coords_res = decoded_results[i]['vtx_sp'].coords[:, 1:].float()
                else:
                    pred_coords_res = decoded_results[i]['vertex']['coords'].float()
                gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)

                v_metrics = self._compute_vertex_metrics(pred_coords_res, gt_coords_res, threshold=threshold)
                # Builtin floats keep the YAML dump free of numpy object tags.
                v_recall = float(v_metrics['recall'])
                v_precision = float(v_metrics['precision'])
                v_f1 = float(v_metrics['f1'])

                per_sample_metrics['vertex'][res].append({
                    'recall': v_recall,
                    'precision': v_precision,
                    'f1': v_f1,
                    'num_pred': len(pred_coords_res),
                    'num_gt': len(gt_coords_res)
                })

                avg_metrics['vertex'][res]['recall'].append(v_recall)
                avg_metrics['vertex'][res]['precision'].append(v_precision)
                avg_metrics['vertex'][res]['f1'].append(v_f1)

                gt_edge_key = f'gt_edge_voxels_{res}'
                if gt_edge_key not in batch_data:
                    continue

                if i == 0:
                    pred_edge_coords_res = decoded_results[i]['edge_sp'].coords[:, 1:].float()
                else:
                    pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
                # Edge ground truth comes from the down-sampled combined
                # voxels, not from batch_data[gt_edge_key].
                gt_edge_coords_res = gt_edge_voxels_list[i][:, 1:].float().to(self.device)

                e_metrics = self._compute_vertex_metrics(pred_edge_coords_res, gt_edge_coords_res, threshold=threshold)
                e_recall = float(e_metrics['recall'])
                e_precision = float(e_metrics['precision'])
                e_f1 = float(e_metrics['f1'])

                per_sample_metrics['edge'][res].append({
                    'recall': e_recall,
                    'precision': e_precision,
                    'f1': e_f1,
                    'num_pred': len(pred_edge_coords_res),
                    'num_gt': len(gt_edge_coords_res)
                })

                avg_metrics['edge'][res]['recall'].append(e_recall)
                avg_metrics['edge'][res]['precision'].append(e_precision)
                avg_metrics['edge'][res]['f1'].append(e_f1)

    # Average each metric over the evaluated samples (NaN when nothing was collected).
    avg_metrics_processed = {}
    for category, res_dict in avg_metrics.items():
        avg_metrics_processed[category] = {}
        for resolution, metric_dict in res_dict.items():
            avg_metrics_processed[category][resolution] = {
                metric_name: float(np.mean(values)) if values else float('nan')
                for metric_name, values in metric_dict.items()
            }

    result_data = {
        'config': self.config,
        'checkpoint': self.ckpt_path,
        'dataset': self.dataset_path,
        'num_samples': eval_samples,
        'per_sample_metrics': per_sample_metrics,
        'avg_metrics': avg_metrics_processed
    }

    results_file_path = os.path.join(self.result_dir, f"evaluation_results_epoch{self.epoch}.yaml")
    with open(results_file_path, 'w') as f:
        yaml.dump(result_data, f, default_flow_style=False)

    return result_data
|
| 1385 |
+
|
| 1386 |
+
def _generate_line_voxels(
|
| 1387 |
+
self,
|
| 1388 |
+
p1: torch.Tensor,
|
| 1389 |
+
p2: torch.Tensor
|
| 1390 |
+
) -> Tuple[
|
| 1391 |
+
List[Tuple[int, int, int]],
|
| 1392 |
+
List[Tuple[torch.Tensor, torch.Tensor]],
|
| 1393 |
+
List[np.ndarray]
|
| 1394 |
+
]:
|
| 1395 |
+
"""
|
| 1396 |
+
Improved version using better sampling strategy
|
| 1397 |
+
"""
|
| 1398 |
+
p1_np = p1 #.cpu().numpy()
|
| 1399 |
+
p2_np = p2 #.cpu().numpy()
|
| 1400 |
+
voxel_dict = OrderedDict()
|
| 1401 |
+
|
| 1402 |
+
# Use proper 3D line voxelization algorithm
|
| 1403 |
+
def bresenham_3d(p1, p2):
|
| 1404 |
+
"""3D Bresenham's line algorithm"""
|
| 1405 |
+
x1, y1, z1 = np.round(p1).astype(int)
|
| 1406 |
+
x2, y2, z2 = np.round(p2).astype(int)
|
| 1407 |
+
|
| 1408 |
+
points = []
|
| 1409 |
+
dx = abs(x2 - x1)
|
| 1410 |
+
dy = abs(y2 - y1)
|
| 1411 |
+
dz = abs(z2 - z1)
|
| 1412 |
+
|
| 1413 |
+
xs = 1 if x2 > x1 else -1
|
| 1414 |
+
ys = 1 if y2 > y1 else -1
|
| 1415 |
+
zs = 1 if z2 > z1 else -1
|
| 1416 |
+
|
| 1417 |
+
# Driving axis is X
|
| 1418 |
+
if dx >= dy and dx >= dz:
|
| 1419 |
+
err_1 = 2 * dy - dx
|
| 1420 |
+
err_2 = 2 * dz - dx
|
| 1421 |
+
for i in range(dx + 1):
|
| 1422 |
+
points.append((x1, y1, z1))
|
| 1423 |
+
if err_1 > 0:
|
| 1424 |
+
y1 += ys
|
| 1425 |
+
err_1 -= 2 * dx
|
| 1426 |
+
if err_2 > 0:
|
| 1427 |
+
z1 += zs
|
| 1428 |
+
err_2 -= 2 * dx
|
| 1429 |
+
err_1 += 2 * dy
|
| 1430 |
+
err_2 += 2 * dz
|
| 1431 |
+
x1 += xs
|
| 1432 |
+
|
| 1433 |
+
# Driving axis is Y
|
| 1434 |
+
elif dy >= dx and dy >= dz:
|
| 1435 |
+
err_1 = 2 * dx - dy
|
| 1436 |
+
err_2 = 2 * dz - dy
|
| 1437 |
+
for i in range(dy + 1):
|
| 1438 |
+
points.append((x1, y1, z1))
|
| 1439 |
+
if err_1 > 0:
|
| 1440 |
+
x1 += xs
|
| 1441 |
+
err_1 -= 2 * dy
|
| 1442 |
+
if err_2 > 0:
|
| 1443 |
+
z1 += zs
|
| 1444 |
+
err_2 -= 2 * dy
|
| 1445 |
+
err_1 += 2 * dx
|
| 1446 |
+
err_2 += 2 * dz
|
| 1447 |
+
y1 += ys
|
| 1448 |
+
|
| 1449 |
+
# Driving axis is Z
|
| 1450 |
+
else:
|
| 1451 |
+
err_1 = 2 * dx - dz
|
| 1452 |
+
err_2 = 2 * dy - dz
|
| 1453 |
+
for i in range(dz + 1):
|
| 1454 |
+
points.append((x1, y1, z1))
|
| 1455 |
+
if err_1 > 0:
|
| 1456 |
+
x1 += xs
|
| 1457 |
+
err_1 -= 2 * dz
|
| 1458 |
+
if err_2 > 0:
|
| 1459 |
+
y1 += ys
|
| 1460 |
+
err_2 -= 2 * dz
|
| 1461 |
+
err_1 += 2 * dx
|
| 1462 |
+
err_2 += 2 * dy
|
| 1463 |
+
z1 += zs
|
| 1464 |
+
|
| 1465 |
+
return points
|
| 1466 |
+
|
| 1467 |
+
# Get all voxels using Bresenham algorithm
|
| 1468 |
+
voxel_coords = bresenham_3d(p1_np, p2_np)
|
| 1469 |
+
|
| 1470 |
+
# Add all voxels to dictionary
|
| 1471 |
+
for coord in voxel_coords:
|
| 1472 |
+
voxel_dict[tuple(coord)] = (p1, p2)
|
| 1473 |
+
|
| 1474 |
+
voxel_coords = list(voxel_dict.keys())
|
| 1475 |
+
endpoint_pairs = list(voxel_dict.values())
|
| 1476 |
+
|
| 1477 |
+
# --- compute error vectors ---
|
| 1478 |
+
error_vectors = []
|
| 1479 |
+
diff = p2_np - p1_np
|
| 1480 |
+
d_norm_sq = np.dot(diff, diff)
|
| 1481 |
+
|
| 1482 |
+
for v in voxel_coords:
|
| 1483 |
+
v_center = np.array(v, dtype=float) + 0.5
|
| 1484 |
+
if d_norm_sq == 0: # degenerate line
|
| 1485 |
+
closest = p1_np
|
| 1486 |
+
else:
|
| 1487 |
+
t = np.dot(v_center - p1_np, diff) / d_norm_sq
|
| 1488 |
+
t = np.clip(t, 0.0, 1.0)
|
| 1489 |
+
closest = p1_np + t * diff
|
| 1490 |
+
error_vectors.append(v_center - closest)
|
| 1491 |
+
|
| 1492 |
+
return voxel_coords, endpoint_pairs, error_vectors
|
| 1493 |
+
|
| 1494 |
+
|
| 1495 |
+
# Usage example
|
| 1496 |
+
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Value handed to every random number generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        # Seed the current device explicitly, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
| 1505 |
+
|
| 1506 |
+
def _write_metric_line(handle, key, value):
    """Write one ``key: value`` line, with 4 decimals for numeric values."""
    if isinstance(value, (int, float)):
        handle.write(f" {str(key).ljust(15)}: {value:.4f}\n")
    else:
        handle.write(f" {str(key).ljust(15)}: {str(value)}\n")


def _write_sample_table(handle, title, resolutions, rows):
    """Write one per-sample recall/precision/F1 table, one row per resolution."""
    handle.write(f"[{title}]\n")
    handle.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
    for res, metrics in zip(resolutions, rows):
        handle.write(f" {str(res).ljust(10)} "
                     f"{metrics['recall']:.4f} "
                     f"{metrics['precision']:.4f} "
                     f"{metrics['f1']:.4f} "
                     f"{metrics['num_pred']}/{metrics['num_gt']}\n")


def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
    """Evaluate one checkpoint and write a plain-text summary report.

    Seeds all RNGs with 42, builds a ``Tester`` for the checkpoint/dataset
    pair, runs ``Tester.evaluate`` and writes the averaged and per-sample
    metrics to ``<eval_dir>/epoch<E>_<dataset>_summary_threshold<T>_one2one.txt``.

    The evaluation settings ``NUM_SAMPLES``, ``VISUALIZE``,
    ``CHAMFER_EDGE_THRESHOLD`` and ``THRESHOLD`` are read from module globals.

    Args:
        ckpt_path: Path of the checkpoint file; the epoch tag is parsed from
            the second ``_``-separated token of its file name.
        dataset_path: Root directory of the evaluation dataset.
        eval_dir: Existing directory that receives the summary file.

    Returns:
        The result dictionary produced by ``Tester.evaluate``.
    """
    set_seed(42)
    tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
    result_data = tester.evaluate(num_samples=NUM_SAMPLES, visualize=VISUALIZE, chamfer_threshold=CHAMFER_EDGE_THRESHOLD, threshold=THRESHOLD)

    # Build the file name. The token looks like "epoch1"; strip the literal
    # prefix so the report is named "epoch1_..." rather than "epochepoch1_...".
    name_parts = os.path.basename(ckpt_path).split('_')
    if len(name_parts) > 1:
        epoch_str = name_parts[1].split('.')[0]
    else:
        # Unexpected naming scheme: fall back to the file stem instead of crashing
        # after the (expensive) evaluation has already finished.
        epoch_str = os.path.splitext(name_parts[0])[0]
    if epoch_str.startswith('epoch'):
        epoch_str = epoch_str[len('epoch'):]
    dataset_name = os.path.basename(os.path.normpath(dataset_path))

    resolutions = (128, 256, 512)
    per_sample = result_data['per_sample_metrics']

    # Save the condensed report (TXT). UTF-8 is required because the report
    # contains a non-ASCII bullet character.
    summary_path = os.path.join(eval_dir, f"epoch{epoch_str}_{dataset_name}_summary_threshold{THRESHOLD}_one2one.txt")
    with open(summary_path, 'w', encoding='utf-8') as f:
        # Header
        f.write(f"Checkpoint: {os.path.basename(ckpt_path)}\n")
        f.write(f"Dataset: {dataset_name}\n")
        f.write(f"Evaluation Samples: {result_data['num_samples']}\n\n")

        # Averaged metrics
        f.write("=== Average Metrics ===\n")
        for category, data in result_data['avg_metrics'].items():
            f.write(f"\n{category.upper()}:\n")
            for key, value in data.items():
                if isinstance(value, dict):
                    # Multi-resolution category: one sub-section per resolution.
                    f.write(f" Resolution {key}:\n")
                    for k, v in value.items():
                        _write_metric_line(f, k, v)
                else:
                    # Flat category: metrics sit directly under the category.
                    _write_metric_line(f, key, value)

        # Detailed per-sample statistics
        f.write("\n\n=== Detailed Per-Sample Metrics ===\n")
        for name, vertex_metrics, edge_metrics in zip(
            per_sample['sample_names'],
            zip(*[per_sample['vertex'][res] for res in resolutions]),
            zip(*[per_sample['edge'][res] for res in resolutions])
        ):
            # Sample title
            f.write(f"\n◆ Sample: {name}\n")

            _write_sample_table(f, "Vertex Prediction", resolutions, vertex_metrics)
            _write_sample_table(f, "Edge Prediction", resolutions, edge_metrics)

            f.write("-"*60 + "\n")

    print(f"Saved summary to: {summary_path}")
    return result_data
|
| 1578 |
+
|
| 1579 |
+
|
| 1580 |
+
# Script entry point: evaluate either one fixed checkpoint or every checkpoint
# in its directory whose epoch number lies in [EPOCH_START, EPOCH_END].
if __name__ == '__main__':
    # Everything below runs under CUDA autocast with bfloat16.
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        evaluate_all_checkpoints = False # Set to True to enable epoch-range filtering
        EPOCH_START = 1
        EPOCH_END = 12
        # NUM_SAMPLES, VISUALIZE, CHAMFER_EDGE_THRESHOLD and THRESHOLD are read
        # as module globals by evaluate_checkpoint().
        CHAMFER_EDGE_THRESHOLD=0.5
        NUM_SAMPLES=50
        VISUALIZE=True
        THRESHOLD=1.5
        # NOTE(review): VISUAL_FIELD is not read in the visible part of this file.
        VISUAL_FIELD=False

        ckpt_path = '/home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch1_batch12000_loss0.1140.pt'
        dataset_path = '/home/tiger/yy/src/trellis_clean_mesh/mesh_data'

        # Only one known dataset has a renders directory; every other dataset
        # gets an empty string.
        # NOTE(review): RENDERS_DIR is presumably consumed as a global elsewhere - confirm.
        if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
            RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
        else:
            RENDERS_DIR = ''

        # Reports are written to an "evaluate" folder next to the checkpoint.
        ckpt_dir = os.path.dirname(ckpt_path)
        eval_dir = os.path.join(ckpt_dir, "evaluate")
        os.makedirs(eval_dir, exist_ok=True)

        # Disabled debugging path: latent-space PCA visualisation per sample.
        if False:
            for i in range(NUM_SAMPLES):
                print("--- Starting Latent Space PCA Visualization ---")
                tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
                tester.visualize_latent_space_pca(sample_idx=i)
                print("--- PCA Visualization Finished ---")

        if not evaluate_all_checkpoints:
            evaluate_checkpoint(ckpt_path, dataset_path, eval_dir)
        else:
            # NOTE: sorted() orders the names lexicographically, so e.g.
            # "epoch10" is visited before "epoch2".
            pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.endswith('.pt')])

            filtered_pt_files = []
            for f in pt_files:
                try:
                    # Expected name: "<prefix>_epoch<N>_..."; the second token
                    # carries the epoch number.
                    parts = f.split('_')
                    epoch_str = parts[1].replace('epoch', '')
                    epoch = int(epoch_str)
                    if EPOCH_START <= epoch <= EPOCH_END:
                        filtered_pt_files.append(f)
                except Exception as e:
                    # Files that do not follow the naming scheme are skipped.
                    print(f"Warning: Could not parse epoch from {f}: {e}")
                    continue

            for pt_file in filtered_pt_files:
                full_ckpt_path = os.path.join(ckpt_dir, pt_file)
                evaluate_checkpoint(full_ckpt_path, dataset_path, eval_dir)
|
test_slat_vae_128to512_pointnet_vae_head.py
CHANGED
|
@@ -744,7 +744,7 @@ class Tester:
|
|
| 744 |
root_dir=self.dataset_path,
|
| 745 |
base_resolution=self.config['dataset']['base_resolution'],
|
| 746 |
min_resolution=self.config['dataset']['min_resolution'],
|
| 747 |
-
cache_dir='/
|
| 748 |
# cache_dir=self.config['dataset']['cache_dir'],
|
| 749 |
renders_dir=self.config['dataset']['renders_dir'],
|
| 750 |
|
|
@@ -1262,15 +1262,12 @@ class Tester:
|
|
| 1262 |
# 如果想保存为归一化坐标:
|
| 1263 |
mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0
|
| 1264 |
|
| 1265 |
-
# 如果有 error offset,记得加上!
|
| 1266 |
-
# 你之前的代码好像没有对 vertex 加 offset,只对 edge 加了
|
| 1267 |
-
# 如果 vertex 也有 offset (如 dual contouring),在这里加上
|
| 1268 |
-
|
| 1269 |
# 移动到中心 (可选)
|
| 1270 |
mesh_verts = mesh_verts - 0.5
|
| 1271 |
|
| 1272 |
mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
|
| 1273 |
|
|
|
|
| 1274 |
# 过滤孤立点 (可选)
|
| 1275 |
# mesh.remove_unreferenced_vertices()
|
| 1276 |
|
|
@@ -1581,22 +1578,23 @@ def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
|
|
| 1581 |
if __name__ == '__main__':
|
| 1582 |
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 1583 |
evaluate_all_checkpoints = True # 设置为 True 启用范围过滤
|
| 1584 |
-
EPOCH_START =
|
| 1585 |
EPOCH_END = 12
|
| 1586 |
CHAMFER_EDGE_THRESHOLD=0.5
|
| 1587 |
-
NUM_SAMPLES=
|
| 1588 |
VISUALIZE=True
|
| 1589 |
THRESHOLD=1.5
|
| 1590 |
VISUAL_FIELD=False
|
| 1591 |
|
| 1592 |
-
ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch10433_loss1.2657.pt'
|
| 1593 |
-
ckpt_path = '/
|
| 1594 |
-
|
| 1595 |
-
dataset_path = '/gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/test'
|
| 1596 |
-
dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized'
|
| 1597 |
-
# dataset_path = '/gemini/user/private/zhaotianhao/data/trellis500k_compress_glb'
|
| 1598 |
-
dataset_path = '/gemini/user/private/zhaotianhao/data/unique_files_glb_under6000face_2degree_30ratio_0.01'
|
| 1599 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1600 |
if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
|
| 1601 |
RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
|
| 1602 |
else:
|
|
|
|
| 744 |
root_dir=self.dataset_path,
|
| 745 |
base_resolution=self.config['dataset']['base_resolution'],
|
| 746 |
min_resolution=self.config['dataset']['min_resolution'],
|
| 747 |
+
cache_dir='/home/tiger/yy/src/dataset_cache',
|
| 748 |
# cache_dir=self.config['dataset']['cache_dir'],
|
| 749 |
renders_dir=self.config['dataset']['renders_dir'],
|
| 750 |
|
|
|
|
| 1262 |
# 如果想保存为归一化坐标:
|
| 1263 |
mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0
|
| 1264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1265 |
# 移动到中心 (可选)
|
| 1266 |
mesh_verts = mesh_verts - 0.5
|
| 1267 |
|
| 1268 |
mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
|
| 1269 |
|
| 1270 |
+
trimesh.repair.fix_normals(mesh)
|
| 1271 |
# 过滤孤立点 (可选)
|
| 1272 |
# mesh.remove_unreferenced_vertices()
|
| 1273 |
|
|
|
|
| 1578 |
if __name__ == '__main__':
|
| 1579 |
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 1580 |
evaluate_all_checkpoints = True # 设置为 True 启用范围过滤
|
| 1581 |
+
EPOCH_START = 1
|
| 1582 |
EPOCH_END = 12
|
| 1583 |
CHAMFER_EDGE_THRESHOLD=0.5
|
| 1584 |
+
NUM_SAMPLES=50
|
| 1585 |
VISUALIZE=True
|
| 1586 |
THRESHOLD=1.5
|
| 1587 |
VISUAL_FIELD=False
|
| 1588 |
|
| 1589 |
+
# ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch10433_loss1.2657.pt'
|
| 1590 |
+
ckpt_path = '/home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch0_batch6000_loss0.1150.pt'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1591 |
|
| 1592 |
+
dataset_path = '/home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01'
|
| 1593 |
+
# dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized'
|
| 1594 |
+
# # dataset_path = '/gemini/user/private/zhaotianhao/data/trellis500k_compress_glb'
|
| 1595 |
+
# dataset_path = '/gemini/user/private/zhaotianhao/data/unique_files_glb_under6000face_2degree_30ratio_0.01'
|
| 1596 |
+
dataset_path = '/home/tiger/yy/src/trellis_clean_mesh/mesh_data'
|
| 1597 |
+
|
| 1598 |
if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
|
| 1599 |
RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
|
| 1600 |
else:
|
train_slat_flow_128to512_pointnet_head.py
ADDED
|
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# os.environ['ATTN_BACKEND'] = 'xformers' # xformers is generally compatible with DDP
|
| 3 |
+
# os.environ["OMP_NUM_THREADS"] = "1"
|
| 4 |
+
# os.environ["MKL_NUM_THREADS"] = "1"
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import yaml
|
| 8 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 9 |
+
from functools import partial
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.optim import AdamW
|
| 12 |
+
from torch.amp import GradScaler, autocast
|
| 13 |
+
from typing import *
|
| 14 |
+
from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 17 |
+
|
| 18 |
+
# --- Updated Imports based on VAE script ---
|
| 19 |
+
from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
|
| 20 |
+
from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
|
| 21 |
+
from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
|
| 22 |
+
from triposf.modules.sparse.basic import SparseTensor
|
| 23 |
+
|
| 24 |
+
from trellis.models.structured_latent_flow import SLatFlowModel
|
| 25 |
+
from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
|
| 26 |
+
from safetensors.torch import load_file
|
| 27 |
+
import torch.multiprocessing as mp
|
| 28 |
+
from PIL import Image
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
|
| 31 |
+
from triposf.modules.utils import DiagonalGaussianDistribution
|
| 32 |
+
import torchvision.transforms as transforms
|
| 33 |
+
import re
|
| 34 |
+
import contextlib
|
| 35 |
+
|
| 36 |
+
# --- Distributed Setup Functions ---
|
| 37 |
+
def setup_distributed(backend="nccl"):
    """Join the default process group if necessary and report this process's ranks.

    The launcher is expected to export ``RANK``, ``WORLD_SIZE`` and
    ``LOCAL_RANK``. When no process group exists yet, the CUDA device
    ``LOCAL_RANK`` is selected and a group is created with ``backend``.

    Args:
        backend: Name of the ``torch.distributed`` backend to initialise.

    Returns:
        Tuple ``(rank, local_rank, world_size)`` read from the environment.

    Raises:
        KeyError: If one of the required environment variables is missing.
    """
    def env_int(key):
        return int(os.environ[key])

    if not dist.is_initialized():
        # Read all three variables first so a misconfigured launch fails
        # before any CUDA / communication work is attempted.
        _rank, _world_size, device_index = (
            env_int("RANK"), env_int("WORLD_SIZE"), env_int("LOCAL_RANK")
        )
        torch.cuda.set_device(device_index)
        dist.init_process_group(backend=backend)

    return tuple(env_int(key) for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
|
| 48 |
+
|
| 49 |
+
def cleanup_distributed():
    """Destroy the default process group if one is active.

    Safe to call when ``torch.distributed`` was never initialised (for example
    from a ``finally`` block after a failed start-up), in which case it does
    nothing instead of raising and masking the original error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
| 51 |
+
|
| 52 |
+
# --- Modified Trainer Class ---
|
| 53 |
+
class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
|
| 54 |
+
def __init__(self, *args, rank: int, local_rank: int, world_size: int, **kwargs):
    """Build the distributed flow-matching trainer.

    Args:
        *args: Forwarded unchanged to the parent trainer constructor.
        rank: Global rank of this process.
        local_rank: Index of the CUDA device used by this process.
        world_size: Total number of participating processes.
        **kwargs: Forwarded to the parent constructor; must contain ``cfg``,
            the configuration dictionary (``training``, ``dataset`` and
            ``model`` sections are read).

    Raises:
        ValueError: If ``cfg`` is missing from ``kwargs``.
    """
    super().__init__(*args, **kwargs)
    # NOTE(review): `cfg` is popped only after the parent constructor has
    # already received it through **kwargs - confirm the parent accepts it.
    self.cfg = kwargs.pop('cfg', None)
    if self.cfg is None:
        raise ValueError("Configuration dictionary 'cfg' must be provided.")

    # --- Distributed-related attributes ---
    self.rank = rank
    self.local_rank = local_rank
    self.world_size = world_size
    self.device = torch.device(f"cuda:{self.local_rank}")
    self.is_master = (self.rank == 0)
    # Hard-coded; not read from cfg.
    self.gradient_accumulation_steps = 8

    self.i_save = self.cfg['training']['save_every']
    self.save_dir = self.cfg['training']['output_dir']

    # Fixed run settings: image conditioning is selected, but `is_cond = False`
    # makes train() use an all-zero (unconditional) prompt.
    self.resolution = 128
    self.condition_type = 'image'
    self.is_cond = False
    self.img_res = 518

    # Only the master rank creates the output directory.
    if self.is_master:
        os.makedirs(self.save_dir, exist_ok=True)
        print(f"Checkpoints and logs will be saved to: {self.save_dir}")

    # Initialize components and set up for DDP
    self._init_components(
        clip_model_path=self.cfg['training'].get('clip_model_path', None),
        dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None),
        vae_path=self.cfg['training']['vae_path'],
    )

    # Wraps self.denoiser in DDP.
    # NOTE(review): self.denoiser is presumably created by the parent constructor - confirm.
    self._setup_ddp()

    # Optional resume checkpoint; it is loaded later, by train().
    self.denoiser_checkpoint_path = self.cfg['training'].get('denoiser_checkpoint_path', None)

    # Only the denoiser is optimised; learning rate defaults to 1e-4.
    trainable_params = list(self.denoiser.parameters())
    self.optimizer = AdamW(trainable_params, lr=self.cfg['training'].get('lr', 0.0001), weight_decay=0.0)

    self.scaler = GradScaler()

    if self.is_master:
        print("Using Automatic Mixed Precision (AMP) with GradScaler.")
|
| 98 |
+
|
| 99 |
+
def _init_components(self,
                     clip_model_path=None,
                     dinov2_model_path=None,
                     vae_path=None,
                     ):
    """
    Initializes the dataset/dataloader, VAE, VoxelEncoder (PointNet), and condition model.

    Pretrained weights are read on the master rank only and broadcast to the
    other ranks. The VAE, the voxel encoder and the condition model are put
    in eval mode and frozen.

    Args:
        clip_model_path: Location of the CLIP text model / tokenizer; only
            used when ``self.condition_type == 'text'``.
        dinov2_model_path: Currently unused - the DINOv2 repository and weight
            paths are hard-coded below.
        vae_path: Checkpoint file expected to contain ``'vae'`` and
            ``'voxel_encoder'`` state dicts.

    Raises:
        RuntimeError: On the master rank, if a weight file cannot be loaded.
        ValueError: If ``self.condition_type`` is neither ``'text'`` nor ``'image'``.
    """
    # Use the Dataset from the VAE script
    self.dataset = VoxelVertexDataset_edge(
        root_dir=self.cfg['dataset']['path'],
        base_resolution=self.cfg['dataset']['base_resolution'],
        min_resolution=self.cfg['dataset']['min_resolution'],
        cache_dir=self.cfg['dataset']['cache_dir'],
        renders_dir=self.cfg['dataset']['renders_dir'],

        filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
        cache_filter_path=self.cfg['dataset']['cache_filter_path'],

        active_voxel_res=128,
        pc_sample_number=819200,

        sample_type=self.cfg['dataset']['sample_type'],
    )

    self.sampler = DistributedSampler(
        self.dataset,
        num_replicas=self.world_size,
        rank=self.rank,
        shuffle=True
    )

    # prefetch_factor / persistent_workers are only legal with worker
    # processes; DataLoader raises ValueError if they are set while
    # num_workers == 0, so pass them only when workers are requested.
    num_workers = self.cfg['training']['num_workers']
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs = {'prefetch_factor': 4, 'persistent_workers': True}

    # Use collate_fn_pointnet; shuffling is delegated to the sampler.
    self.dataloader = DataLoader(
        self.dataset,
        batch_size=self.cfg['training']['batch_size'],
        shuffle=False,
        collate_fn=partial(collate_fn_pointnet,),
        num_workers=num_workers,
        pin_memory=True,
        sampler=self.sampler,
        drop_last=True,
        **worker_kwargs,
    )

    def load_file_func(path, device='cpu'):
        return torch.load(path, map_location=device)

    def _load_and_broadcast(model, load_fn=None, path=None, strict=True):
        """Load a state dict on rank 0, broadcast it, and apply it to ``model``.

        Without ``load_fn`` the master's current ``model.state_dict()`` is
        broadcast instead. If strict loading fails, loading is retried with
        ``strict=False``.
        """
        if self.is_master:
            try:
                state = load_fn(path) if load_fn else model.state_dict()
            except Exception as e:
                # Keep the original traceback attached for debugging.
                raise RuntimeError(f"Failed to load weights from {path}: {e}") from e
        else:
            state = None

        dist.barrier()
        state_b = [state] if self.is_master else [None]
        dist.broadcast_object_list(state_b, src=0)

        try:
            # Handle potential key mismatches (e.g. 'module.' prefix)
            model.load_state_dict(state_b[0], strict=strict)
        except Exception as e:
            if self.is_master: print(f"Strict loading failed for {model.__class__.__name__}, trying non-strict: {e}")
            model.load_state_dict(state_b[0], strict=False)

    # ------------------------- Voxel Encoder (PointNet) -------------------------
    # Matching the VAE script configuration
    self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
        in_channels=15,
        hidden_dim=256,
        out_channels=1024,
        scatter_type='mean',
        n_blocks=5,
        resolution=128,
        add_label=False,
    ).to(self.device)

    # ------------------------- VAE -------------------------
    self.vae = VoxelVAE(
        in_channels=self.cfg['model']['in_channels'],
        latent_dim=self.cfg['model']['latent_dim'],
        encoder_blocks=self.cfg['model']['encoder_blocks'],
        decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
        decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
        num_heads=8,
        num_head_channels=64,
        mlp_ratio=4.0,
        attn_mode="swin",
        window_size=8,
        pe_mode="ape",
        use_fp16=False,
        use_checkpoint=True,
        qk_rms_norm=False,
        using_subdivide=True,
        using_attn=self.cfg['model']['using_attn'],
        attn_first=self.cfg['model'].get('attn_first', True),
        pred_direction=self.cfg['model'].get('pred_direction', False),
    ).to(self.device)

    # ------------------------- Conditioning -------------------------
    if self.condition_type == 'text':
        self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
        # Only the master reads pretrained weights from disk; the other ranks
        # build the architecture and receive the weights via broadcast.
        if self.is_master:
            self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
        else:
            config = CLIPTextConfig.from_pretrained(clip_model_path)
            self.condition_model = CLIPTextModel(config)
        _load_and_broadcast(self.condition_model)

    elif self.condition_type == 'image':
        if self.is_master:
            print("Initializing for IMAGE conditioning (DINOv2).")

        # Update paths as per your environment
        # NOTE(review): these paths are hard-coded and `dinov2_model_path` is ignored.
        local_repo_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main"
        weights_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"

        dinov2_model = torch.hub.load(
            repo_or_dir=local_repo_path,
            model='dinov2_vitl14_reg',
            source='local',
            pretrained=False
        )
        self.condition_model = dinov2_model

        _load_and_broadcast(self.condition_model, load_fn=torch.load, path=weights_path)

        self.image_cond_model_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    else:
        raise ValueError(f"Unsupported condition type: {self.condition_type}")

    # The condition model is a frozen feature extractor.
    self.condition_model.to(self.device).eval()
    for p in self.condition_model.parameters(): p.requires_grad = False

    # ------------------------- Load VAE/Encoder Weights -------------------------
    # Load weights corresponding to the logic in VAE script's `load_pretrained_woself`
    # Assuming checkpoint contains 'vae' and 'voxel_encoder' keys
    _load_and_broadcast(self.vae,
                        load_fn=lambda p: load_file_func(p)['vae'],
                        path=vae_path)

    _load_and_broadcast(self.voxel_encoder,
                        load_fn=lambda p: load_file_func(p)['voxel_encoder'],
                        path=vae_path)

    # Both are frozen: only the denoiser is trained.
    self.vae.eval()
    self.voxel_encoder.eval()
    for p in self.vae.parameters(): p.requires_grad = False
    for p in self.voxel_encoder.parameters(): p.requires_grad = False
|
| 257 |
+
|
| 258 |
+
def _load_denoiser(self):
|
| 259 |
+
"""Loads a checkpoint for the denoiser."""
|
| 260 |
+
path = self.denoiser_checkpoint_path
|
| 261 |
+
if not path or not os.path.isfile(path):
|
| 262 |
+
if self.is_master:
|
| 263 |
+
print("No valid checkpoint path provided for denoiser. Starting from scratch.")
|
| 264 |
+
return
|
| 265 |
+
|
| 266 |
+
if self.is_master:
|
| 267 |
+
print(f"Loading denoiser checkpoint from: {path}")
|
| 268 |
+
checkpoint = torch.load(path, map_location=self.device)
|
| 269 |
+
else:
|
| 270 |
+
checkpoint = None
|
| 271 |
+
|
| 272 |
+
dist.barrier()
|
| 273 |
+
dist_list = [checkpoint] if self.is_master else [None]
|
| 274 |
+
dist.broadcast_object_list(dist_list, src=0)
|
| 275 |
+
checkpoint = dist_list[0]
|
| 276 |
+
|
| 277 |
+
try:
|
| 278 |
+
self.denoiser.module.load_state_dict(checkpoint['denoiser'])
|
| 279 |
+
if self.is_master: print("Denoiser weights loaded successfully.")
|
| 280 |
+
except Exception as e:
|
| 281 |
+
if self.is_master: print(f"[ERROR] Failed to load denoiser state_dict: {e}")
|
| 282 |
+
|
| 283 |
+
if 'step' in checkpoint and self.is_master:
|
| 284 |
+
print(f"Checkpoint from step {checkpoint['step']}.")
|
| 285 |
+
|
| 286 |
+
dist.barrier()
|
| 287 |
+
|
| 288 |
+
def _setup_ddp(self):
    """Move the denoiser to this rank's device and wrap it in DDP.

    All denoiser parameters are made trainable *before* the wrap: DDP
    registers its gradient-reduction hooks at construction time, so
    parameters that only become trainable afterwards would not be
    synchronised across ranks.
    """
    self.denoiser = self.denoiser.to(self.device)

    for param in self.denoiser.parameters():
        param.requires_grad = True

    self.denoiser = DDP(self.denoiser, device_ids=[self.local_rank])
|
| 295 |
+
|
| 296 |
+
@torch.no_grad()
def encode_image(self, images) -> torch.Tensor:
    """Encode images into layer-normalised patch tokens with the condition model.

    Args:
        images: Either a batched image tensor, or a list of PIL images.
            PIL images are resized to 518x518, converted to RGB and scaled
            to ``[0, 1]``; tensors of another spatial size are bicubically
            resized to 518x518.

    Returns:
        The condition model's ``'x_prenorm'`` features after a parameter-free
        layer norm over the last dimension.

    Raises:
        ValueError: If ``images`` is neither a tensor nor a list.
        AssertionError: If a list contains something other than PIL images.
    """
    if isinstance(images, torch.Tensor):
        batch_tensor = images.to(self.device)
    elif isinstance(images, list):
        assert all(isinstance(i, Image.Image) for i in images), "Image list should be list of PIL images"
        tensors = []
        for pil_img in images:
            resized = pil_img.resize((518, 518), Image.LANCZOS)
            rgb = np.array(resized.convert('RGB')).astype(np.float32) / 255
            # HWC -> CHW
            tensors.append(torch.from_numpy(rgb).permute(2, 0, 1).float())
        batch_tensor = torch.stack(tensors).to(self.device)
    else:
        # Report the offending argument itself.
        raise ValueError(f"Unsupported type of image: {type(images)}")

    if batch_tensor.shape[-2:] != (518, 518):
        batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False)

    # NOTE(review): no mean/std normalisation is applied here; confirm that
    # tensor inputs arrive already normalised.
    features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
    patchtokens = F.layer_norm(features, features.shape[-1:])
    return patchtokens
|
| 315 |
+
|
| 316 |
+
def process_batch(self, batch):
    """Return the conditioning tokens for the images stored under ``batch['image']``."""
    return self.encode_image(batch['image'])
|
| 320 |
+
|
| 321 |
+
def train(self, num_epochs=1000):
    """Run the flow-matching training loop with gradient accumulation.

    For every batch the point cloud is encoded to a latent with the frozen
    voxel encoder and VAE (under ``torch.no_grad``), the flow-matching loss
    is computed by ``self.training_losses``, and gradients are accumulated
    over ``self.gradient_accumulation_steps`` micro-batches before each
    optimizer step. The master rank logs every 10 optimizer steps, saves a
    checkpoint every ``self.i_save`` steps and appends one line per epoch
    to ``loss_log.txt`` in ``self.save_dir``.

    Args:
        num_epochs: Number of passes over ``self.dataloader``.

    Raises:
        NotImplementedError: If conditional training is requested with a
            ``condition_type`` other than ``'image'``; no per-batch
            conditioning is produced for that combination.
    """
    # 1. Fixed conditioning for unconditional training.
    uncond = None
    if not self.is_cond:
        if self.condition_type == 'text':
            txt = ['']
            encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
            tokens = encoding['input_ids'].to(self.device)
            with torch.no_grad():
                uncond = self.condition_model(input_ids=tokens).last_hidden_state
        else:
            # A blank image is encoded only to obtain the conditioning shape;
            # the actual unconditional prompt is all zeros.
            blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
            with torch.no_grad():
                dummy_cond = self.encode_image([blank_img])
                uncond = torch.zeros_like(dummy_cond)
            if self.is_master:
                print(f"Generated unconditional image prompt with shape: {uncond.shape}")

    per_batch_cond = bool(self.is_cond) and self.condition_type == 'image'
    if self.is_cond and not per_batch_cond:
        # Previously this configuration crashed with a NameError on the
        # first batch because no conditioning tensor was ever created.
        raise NotImplementedError(
            f"Conditional training is only implemented for condition_type='image', "
            f"got {self.condition_type!r}"
        )

    self._load_denoiser()
    self.denoiser.train()

    # Resume the global step counter from the checkpoint file name, if any.
    global_step = 0
    if self.denoiser_checkpoint_path:
        match = re.search(r'step(\d+)', self.denoiser_checkpoint_path)
        if match:
            global_step = int(match.group(1))

    accum_steps = self.gradient_accumulation_steps
    if self.is_master:
        print(f"Training with Gradient Accumulation Steps: {accum_steps}")

    # Make sure no stale gradients exist before the loop starts.
    self.optimizer.zero_grad(set_to_none=True)

    for epoch in range(num_epochs):
        self.dataloader.sampler.set_epoch(epoch)
        epoch_losses = []
        # Micro-batches whose gradients have been accumulated but not yet
        # applied. Tracking this explicitly (instead of reusing the loop
        # index after the loop) stays correct for an empty dataloader.
        pending = 0

        for batch in self.dataloader:

            # --- A. Data preparation ---
            if per_batch_cond:
                cond_ = self.process_batch(batch)
            else:
                # Always start from the original unconditional tensor: an
                # already-expanded tensor cannot be expanded again to a
                # different batch size.
                cond_ = uncond

            point_cloud = batch['point_cloud_128'].to(self.device)
            active_coords = batch['active_voxels_128'].to(self.device)

            # Batch size is inferred from the largest value in column 0 of
            # the coordinates (used as the batch index here).
            batch_size = int(active_coords[:, 0].max().item() + 1)
            if cond_.shape[0] != batch_size:
                cond_ = cond_.expand(batch_size, -1, -1).contiguous()
            cond_ = cond_.to(self.device)

            # --- B. Forward pass and loss ---
            with autocast(device_type='cuda', dtype=torch.bfloat16):
                with torch.no_grad():
                    active_voxel_feats = self.voxel_encoder(
                        p=point_cloud,
                        sparse_coords=active_coords,
                        res=128,
                        bbox_size=(-0.5, 0.5),
                    )
                    sparse_input = SparseTensor(
                        feats=active_voxel_feats,
                        coords=active_coords.int()
                    )
                    latent_128, _ = self.vae.encode(sparse_input)

                terms, _ = self.training_losses(x_0=latent_128, cond=cond_)
                loss = terms['loss']

                # Scale the loss so accumulated gradients average over the window.
                loss = loss / accum_steps

            # --- C. Backward pass ---
            # No DDP no_sync() is used, so gradients are all-reduced on
            # every micro-batch.
            self.scaler.scale(loss).backward()
            pending += 1

            # Undo the accumulation scaling for reporting.
            current_real_loss = loss.item() * accum_steps
            epoch_losses.append(current_real_loss)

            # --- D. Optimizer step once the accumulation window is full ---
            if pending == accum_steps:
                # To clip gradients, unscale first:
                #   self.scaler.unscale_(self.optimizer)
                #   torch.nn.utils.clip_grad_norm_(self.denoiser.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad(set_to_none=True)
                pending = 0

                global_step += 1

                # --- Logging / checkpointing (only on optimizer steps) ---
                if self.is_master:
                    if global_step % 10 == 0:
                        print(f"Epoch {epoch+1} Step {global_step}: "
                              f"Batch_Loss = {current_real_loss:.4f}, "
                              f"Epoch_Mean = {np.mean(epoch_losses):.4f}")

                    if global_step % self.i_save == 0:
                        checkpoint = {
                            'denoiser': self.denoiser.module.state_dict(),
                            'step': global_step
                        }
                        loss_str = f"{np.mean(epoch_losses):.6f}".replace('.', '_')
                        save_path = os.path.join(self.save_dir, f"checkpoint_step{global_step}_loss{loss_str}.pt")
                        torch.save(checkpoint, save_path)
                        print(f"Saved checkpoint to {save_path}")

        # --- E. Flush leftover micro-batches at the end of the epoch ---
        # Happens when the number of batches is not a multiple of
        # accum_steps. The loss was still divided by accum_steps, so this
        # partial step carries a smaller gradient than a full one, and
        # global_step is deliberately not incremented for it.
        if pending:
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad(set_to_none=True)

        if self.is_master:
            avg_loss = np.mean(epoch_losses) if epoch_losses else 0
            log_path = os.path.join(self.save_dir, "loss_log.txt")
            with open(log_path, "a") as f:
                f.write(f"Epoch {epoch+1}, Step {global_step}, AvgLoss {avg_loss:.6f}\n")

        # Keep all ranks in step while the master writes the log.
        dist.barrier()

    if self.is_master:
        print("Training complete.")
|
| 454 |
+
|
| 455 |
+
def main(config_path="/home/tiger/yy/src/Michelangelo-master/config_slat_flow_128to512_pointnet_head.yaml"):
    """Entry point: set up distributed training, build the model and train.

    Args:
        config_path: Path to the YAML training config. Defaults to the
            location previously hard-coded in this function, so existing
            ``main()`` callers behave as before while other machines can
            pass their own path.
    """
    rank, local_rank, world_size = setup_distributed()
    # Seed each rank differently.
    torch.manual_seed(42 + rank)
    np.random.seed(42 + rank)

    # Explicit encoding so the config parses the same regardless of locale.
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # Initialize the flow model (on CPU first).
    flow_cfg = cfg['flow']
    diffusion_model = SLatFlowModel(
        resolution=flow_cfg['resolution'],
        in_channels=flow_cfg['in_channels'],
        out_channels=flow_cfg['out_channels'],
        model_channels=flow_cfg['model_channels'],
        cond_channels=flow_cfg['cond_channels'],
        num_blocks=flow_cfg['num_blocks'],
        num_heads=flow_cfg['num_heads'],
        mlp_ratio=flow_cfg['mlp_ratio'],
        patch_size=flow_cfg['patch_size'],
        num_io_res_blocks=flow_cfg['num_io_res_blocks'],
        io_block_channels=flow_cfg['io_block_channels'],
        pe_mode=flow_cfg['pe_mode'],
        qk_rms_norm=flow_cfg['qk_rms_norm'],
        qk_rms_norm_cross=flow_cfg['qk_rms_norm_cross'],
        use_fp16=flow_cfg.get('use_fp16', False),
    )

    # Reset both RNGs to the per-rank seed again after model construction.
    torch.manual_seed(42 + rank)
    np.random.seed(42 + rank)

    trainer = SLatFlowMatchingTrainer(
        denoiser=diffusion_model,
        t_schedule=cfg['t_schedule'],
        sigma_min=cfg['sigma_min'],
        cfg=cfg,
        rank=rank,
        local_rank=local_rank,
        world_size=world_size,
    )

    trainer.train()
    cleanup_distributed()
|
| 505 |
+
|
| 506 |
+
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
trellis/__init__.py
CHANGED
|
@@ -2,5 +2,5 @@ from . import models
|
|
| 2 |
from . import modules
|
| 3 |
from . import pipelines
|
| 4 |
from . import renderers
|
| 5 |
-
from . import representations
|
| 6 |
from . import utils
|
|
|
|
| 2 |
from . import modules
|
| 3 |
from . import pipelines
|
| 4 |
from . import renderers
|
| 5 |
+
# from . import representations
|
| 6 |
from . import utils
|
trellis/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/__pycache__/__init__.cpython-310.pyc and b/trellis/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/models/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/models/__pycache__/__init__.cpython-310.pyc and b/trellis/models/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc
CHANGED
|
Binary files a/trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc and b/trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc differ
|
|
|
trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc
CHANGED
|
Binary files a/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc and b/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc differ
|
|
|
trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc
CHANGED
|
Binary files a/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc and b/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc differ
|
|
|
trellis/modules/__pycache__/norm.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/__pycache__/norm.cpython-310.pyc and b/trellis/modules/__pycache__/norm.cpython-310.pyc differ
|
|
|
trellis/modules/__pycache__/spatial.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/__pycache__/spatial.cpython-310.pyc and b/trellis/modules/__pycache__/spatial.cpython-310.pyc differ
|
|
|
trellis/modules/__pycache__/utils.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/__pycache__/utils.cpython-310.pyc and b/trellis/modules/__pycache__/utils.cpython-310.pyc differ
|
|
|
trellis/modules/attention/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc and b/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc differ
|
|
|
trellis/modules/attention/__pycache__/modules.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/attention/__pycache__/modules.cpython-310.pyc and b/trellis/modules/attention/__pycache__/modules.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/__pycache__/basic.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/__pycache__/linear.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc and b/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc differ
|
|
|
trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc differ
|
|
|
trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc differ
|
|
|
trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc
CHANGED
|
Binary files a/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc differ
|
|
|
trellis/pipelines/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/pipelines/__pycache__/__init__.cpython-310.pyc and b/trellis/pipelines/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/pipelines/__pycache__/base.cpython-310.pyc
CHANGED
|
Binary files a/trellis/pipelines/__pycache__/base.cpython-310.pyc and b/trellis/pipelines/__pycache__/base.cpython-310.pyc differ
|
|
|
trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc
CHANGED
|
Binary files a/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc and b/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc differ
|
|
|
trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc
CHANGED
|
Binary files a/trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc and b/trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc differ
|
|
|
trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc
CHANGED
|
Binary files a/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc differ
|
|
|
trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc
CHANGED
|
Binary files a/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc differ
|
|
|
trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc
CHANGED
|
Binary files a/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc differ
|
|
|
trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc
CHANGED
|
Binary files a/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc differ
|
|
|
trellis/renderers/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/renderers/__pycache__/__init__.cpython-310.pyc and b/trellis/renderers/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/representations/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/representations/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc differ
|
|
|
trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc
CHANGED
|
Binary files a/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc differ
|
|
|
trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc
CHANGED
|
Binary files a/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc differ
|
|
|