Commit 4605a84 (verified) · committed by udbbdh · 1 Parent(s): f39c012

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. config_slat_flow_128to512_pointnet_head.yaml +122 -0
  2. config_slat_flow_128to512_pointnet_head_test.yaml +123 -0
  3. test_slat_flow_128to512_pointnet_head.py +404 -0
  4. test_slat_flow_128to512_pointnet_head_tomesh.py +1630 -0
  5. test_slat_vae_128to512_pointnet_vae_head.py +12 -14
  6. train_slat_flow_128to512_pointnet_head.py +507 -0
  7. trellis/__init__.py +1 -1
  8. trellis/__pycache__/__init__.cpython-310.pyc +0 -0
  9. trellis/models/__pycache__/__init__.cpython-310.pyc +0 -0
  10. trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc +0 -0
  11. trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc +0 -0
  12. trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc +0 -0
  13. trellis/modules/__pycache__/norm.cpython-310.pyc +0 -0
  14. trellis/modules/__pycache__/spatial.cpython-310.pyc +0 -0
  15. trellis/modules/__pycache__/utils.cpython-310.pyc +0 -0
  16. trellis/modules/attention/__pycache__/__init__.cpython-310.pyc +0 -0
  17. trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc +0 -0
  18. trellis/modules/attention/__pycache__/modules.cpython-310.pyc +0 -0
  19. trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc +0 -0
  20. trellis/modules/sparse/__pycache__/basic.cpython-310.pyc +0 -0
  21. trellis/modules/sparse/__pycache__/linear.cpython-310.pyc +0 -0
  22. trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc +0 -0
  23. trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc +0 -0
  24. trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc +0 -0
  25. trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc +0 -0
  26. trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc +0 -0
  27. trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc +0 -0
  28. trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc +0 -0
  29. trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc +0 -0
  30. trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc +0 -0
  31. trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  32. trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc +0 -0
  33. trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc +0 -0
  34. trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  35. trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc +0 -0
  36. trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc +0 -0
  37. trellis/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
  38. trellis/pipelines/__pycache__/base.cpython-310.pyc +0 -0
  39. trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc +0 -0
  40. trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc +0 -0
  41. trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc +0 -0
  42. trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc +0 -0
  43. trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc +0 -0
  44. trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc +0 -0
  45. trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc +0 -0
  46. trellis/renderers/__pycache__/__init__.cpython-310.pyc +0 -0
  47. trellis/representations/__pycache__/__init__.cpython-310.pyc +0 -0
  48. trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc +0 -0
  49. trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc +0 -0
  50. trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc +0 -0
config_slat_flow_128to512_pointnet_head.yaml ADDED
@@ -0,0 +1,122 @@
+model:
+  pred_direction: false
+  relative_embed: true
+  using_attn: false
+  add_block_embed: true
+  multires: 12
+
+  embed_dim: 1024
+  in_channels: 1024
+  model_channels: 384
+  latent_dim: 16
+
+  block_size: 16
+  pos_encoding: 'nerf'
+  attn_first: false
+
+  add_edge_glb_feats: true
+  add_direction: false
+
+  encoder_blocks:
+    - in_channels: 1024
+      model_channels: 512
+      num_blocks: 8
+      num_heads: 8
+      out_channels: 512
+
+  decoder_blocks_edge:
+    - in_channels: 512
+      model_channels: 512
+      num_blocks: 0
+      num_heads: 0
+      out_channels: 256
+      resolution: 128
+    - in_channels: 256
+      model_channels: 256
+      num_blocks: 0
+      num_heads: 0
+      out_channels: 128
+      resolution: 256
+    # - in_channels: 64
+    #   model_channels: 64
+    #   num_blocks: 0
+    #   num_heads: 0
+    #   out_channels: 32
+    #   resolution: 512
+
+  decoder_blocks_vtx:
+    - in_channels: 512
+      model_channels: 512
+      num_blocks: 0
+      num_heads: 0
+      out_channels: 256
+      resolution: 128
+    - in_channels: 256
+      model_channels: 256
+      num_blocks: 0
+      num_heads: 0
+      out_channels: 128
+      resolution: 256
+    # - in_channels: 64
+    #   model_channels: 64
+    #   num_blocks: 0
+    #   num_heads: 0
+    #   out_channels: 32
+    #   resolution: 512
+
+"t_schedule":
+  "name": "logitNormal"
+  "args":
+    "mean": 1.0
+    "std": 1.0
+
+"sigma_min": 1.e-5
+
+training:
+  batch_size: 1
+  lr: 1.e-4
+  step_size: 20
+  gamma: 0.95
+  save_every: 500
+  start_epoch: 0
+  max_epochs: 300
+  num_workers: 32
+
+  output_dir: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope
+  clip_model_path: None
+  dinov2_model_path: None
+
+  vae_path: /home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch1_batch12000_loss0.1140.pt
+  denoiser_checkpoint_path: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512/checkpoint_step143000_loss0_736924.pt
+
+
+dataset:
+  path: /home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01
+  cache_dir: /home/tiger/yy/src/dataset_cache/unique_files_glb_under6000face_2degree_30ratio_0.01
+
+  renders_dir: None
+  filter_active_voxels: true
+  cache_filter_path: /home/tiger/yy/src/40w_2000-100000edge_2000-100000active.txt
+
+  base_resolution: 1024
+  min_resolution: 128
+
+  n_train_samples: 1024
+  sample_type: dora
+
+flow:
+  "resolution": 128
+  "in_channels": 16
+  "out_channels": 16
+  "model_channels": 768
+  "cond_channels": 1024
+  "num_blocks": 12
+  "num_heads": 12
+  "mlp_ratio": 4
+  "patch_size": 2
+  "num_io_res_blocks": 2
+  "io_block_channels": [128]
+  "pe_mode": "rope"
+  "qk_rms_norm": true
+  "qk_rms_norm_cross": false
+  "use_fp16": false
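Note on the `None` values in this config: `clip_model_path: None`, `dinov2_model_path: None`, and `renders_dir: None` are not YAML nulls. YAML spells null as `null` or `~`, so a loader such as PyYAML's `yaml.safe_load` returns the string `'None'` for these keys. The following is a minimal sketch of a guard, assuming a hypothetical helper `none_if_placeholder` (not part of this commit) that is applied to the already-loaded dictionary; the dictionary below is a stand-in, not the real config.

```python
def none_if_placeholder(value):
    """Map the strings 'None', 'null', '~' and '' (any case) to Python None."""
    if isinstance(value, str) and value.strip().lower() in {"none", "null", "~", ""}:
        return None
    return value


# Stand-in for what a YAML loader would return for the `training` section.
training_cfg = {
    "clip_model_path": "None",
    "dinov2_model_path": "None",
    "vae_path": "/path/to/vae_checkpoint.pt",  # hypothetical path for illustration
}

cleaned = {key: none_if_placeholder(value) for key, value in training_cfg.items()}
print(cleaned["clip_model_path"])  # None
print(cleaned["vae_path"])
```

Writing `null` in the YAML itself would avoid the need for such a guard.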
config_slat_flow_128to512_pointnet_head_test.yaml ADDED
@@ -0,0 +1,123 @@
+model:
+  pred_direction: false
+  relative_embed: true
+  using_attn: false
+  add_block_embed: true
+  multires: 12
+
+  embed_dim: 1024
+  in_channels: 1024
+  model_channels: 384
+  latent_dim: 16
+
+  block_size: 16
+  pos_encoding: 'nerf'
+  attn_first: false
+
+  add_edge_glb_feats: true
+  add_direction: false
+
+  encoder_blocks:
+    - in_channels: 1024
+      model_channels: 512
+      num_blocks: 8
+      num_heads: 8
+      out_channels: 512
+
+  decoder_blocks_edge:
+    - in_channels: 512
+      model_channels: 512
+      num_blocks: 0
+      num_heads: 0
+      out_channels: 256
+      resolution: 128
+    - in_channels: 256
+      model_channels: 256
+      num_blocks: 0
+      num_heads: 0
+      out_channels: 128
+      resolution: 256
+    # - in_channels: 64
+    #   model_channels: 64
+    #   num_blocks: 0
+    #   num_heads: 0
+    #   out_channels: 32
+    #   resolution: 512
+
+  decoder_blocks_vtx:
+    - in_channels: 512
+      model_channels: 512
+      num_blocks: 0
+      num_heads: 0
+      out_channels: 256
+      resolution: 128
+    - in_channels: 256
+      model_channels: 256
+      num_blocks: 0
+      num_heads: 0
+      out_channels: 128
+      resolution: 256
+    # - in_channels: 64
+    #   model_channels: 64
+    #   num_blocks: 0
+    #   num_heads: 0
+    #   out_channels: 32
+    #   resolution: 512
+
+"t_schedule":
+  "name": "logitNormal"
+  "args":
+    "mean": 1.0
+    "std": 1.0
+
+"sigma_min": 1.e-5
+
+training:
+  batch_size: 1
+  lr: 1.e-4
+  step_size: 20
+  gamma: 0.95
+  save_every: 500
+  start_epoch: 0
+  max_epochs: 300
+  num_workers: 32
+
+  output_dir: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope
+  clip_model_path: None
+  dinov2_model_path: None
+
+  vae_path: /home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch1_batch12000_loss0.1140.pt
+  denoiser_checkpoint_path: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512/checkpoint_step143000_loss0_736924.pt
+
+
+dataset:
+  path: /home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01
+  path: /home/tiger/yy/src/trellis_clean_mesh/mesh_data
+  cache_dir: /home/tiger/yy/src/dataset_cache/unique_files_glb_under6000face_2degree_30ratio_0.01
+
+  renders_dir: None
+  filter_active_voxels: false
+  cache_filter_path: /home/tiger/yy/src/40w_2000-100000edge_2000-100000active.txt
+
+  base_resolution: 1024
+  min_resolution: 128
+
+  n_train_samples: 1024
+  sample_type: dora
+
+flow:
+  "resolution": 128
+  "in_channels": 16
+  "out_channels": 16
+  "model_channels": 768
+  "cond_channels": 1024
+  "num_blocks": 12
+  "num_heads": 12
+  "mlp_ratio": 4
+  "patch_size": 2
+  "num_io_res_blocks": 2
+  "io_block_channels": [128]
+  "pe_mode": "rope"
+  "qk_rms_norm": true
+  "qk_rms_norm_cross": false
+  "use_fp16": false
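Note on the `dataset` section of this test config: it defines `path` twice. YAML expects mapping keys to be unique, and PyYAML keeps the last value without complaint, so the `trellis_clean_mesh/mesh_data` path silently wins. The following is a rough sketch of a check for this, assuming block-style YAML indented with spaces; `find_duplicate_keys` is a hypothetical helper written for illustration and is not a full YAML parser.

```python
def find_duplicate_keys(yaml_text):
    """Return (line_number, key) for every key repeated inside one block mapping.

    Simplified on purpose: handles block mappings and '- key: value' list items;
    flow style, multi-line scalars, and '#' inside quoted strings are not handled.
    """
    duplicates = []
    scopes = []  # stack of (indent, keys already seen at that indent)
    for lineno, raw in enumerate(yaml_text.splitlines(), start=1):
        line = raw.split("#", 1)[0].rstrip()  # naive comment stripping
        stripped = line.lstrip()
        if not stripped:
            continue
        indent = len(line) - len(stripped)
        new_item = stripped.startswith("- ")
        if new_item:
            # A list item opens a fresh mapping whose keys sit two columns further in.
            indent += 2
            stripped = stripped[2:].lstrip()
        if ":" not in stripped:
            continue
        key = stripped.split(":", 1)[0].strip().strip("'\"")
        # Leave every scope that is nested deeper than this line.
        while scopes and scopes[-1][0] > indent:
            scopes.pop()
        # A new list item must not inherit the keys of the previous item.
        if new_item and scopes and scopes[-1][0] == indent:
            scopes.pop()
        if not scopes or scopes[-1][0] < indent:
            scopes.append((indent, set()))
        seen = scopes[-1][1]
        if key in seen:
            duplicates.append((lineno, key))
        seen.add(key)
    return duplicates


sample = """\
dataset:
  path: /data/first
  path: /data/second
  cache_dir: /data/cache
flow:
  path: /unrelated
"""
print(find_duplicate_keys(sample))  # [(3, 'path')]
```

If the first `path` is only kept for reference, commenting it out makes the intent explicit.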
test_slat_flow_128to512_pointnet_head.py ADDED
@@ -0,0 +1,404 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import yaml
+import time
+from datetime import datetime
+from torch.utils.data import DataLoader
+from functools import partial
+import torch.nn.functional as F
+from torch.amp import GradScaler, autocast
+from typing import *
+from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig, Dinov2Model, AutoImageProcessor, Dinov2Config
+import torch
+import re
+from utils import load_pretrained_woself
+
+from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
+from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
+from vertex_encoder import VoxelFeatureEncoder_active_pointnet
+
+from trellis.models.structured_latent_flow import SLatFlowModel
+from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
+
+from trellis.pipelines.samplers import FlowEulerSampler
+from safetensors.torch import load_file
+import open3d as o3d
+from PIL import Image
+
+from triposf.modules.sparse.basic import SparseTensor
+from trellis.modules.sparse.basic import SparseTensor as SparseTensor_trellis
+
+from triposf.modules.utils import DiagonalGaussianDistribution
+
+from sklearn.decomposition import PCA
+import trimesh
+import torchvision.transforms as transforms
+
+# --- Helper Functions ---
+def save_colored_ply(points, colors, filename):
+    if len(points) == 0:
+        print(f"[Warning] No points to save for {filename}")
+        return
+    # Ensure colors are uint8
+    if colors.max() <= 1.0:
+        colors = (colors * 255).astype(np.uint8)
+    colors = colors.astype(np.uint8)
+
+    # Add Alpha if missing
+    if colors.shape[1] == 3:
+        colors = np.hstack([colors, np.full((len(colors), 1), 255, dtype=np.uint8)])
+
+    cloud = trimesh.PointCloud(points, colors=colors)
+    cloud.export(filename)
+    print(f"Saved colored point cloud to {filename}")
+
+def normalize_to_rgb(features_3d):
+    min_vals = features_3d.min(axis=0)
+    max_vals = features_3d.max(axis=0)
+    range_vals = max_vals - min_vals
+    range_vals[range_vals == 0] = 1
+    normalized = (features_3d - min_vals) / range_vals
+    return (normalized * 255).astype(np.uint8)
+
+class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.cfg = kwargs.pop('cfg', None)
+        if self.cfg is None:
+            raise ValueError("Configuration dictionary 'cfg' must be provided.")
+
+        self.sampler = FlowEulerSampler(sigma_min=1.e-5)
+        self.device = torch.device("cuda")
+
+        # Based on PointNet Encoder setting
+        self.resolution = 128
+
+        self.condition_type = 'image'
+        self.is_cond = False
+
+        self.img_res = 518
+        self.feature_dim = self.cfg['model']['latent_dim']
+
+        self._init_components(
+            clip_model_path=self.cfg['training'].get('clip_model_path', None),
+            dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None),
+            vae_path=self.cfg['training']['vae_path'],
+        )
+
+        # Classifier head removed as it is not part of the Active Voxel pipeline
+
+    def _load_denoiser(self, denoiser_checkpoint_path):
+        path = denoiser_checkpoint_path
+        if not path or not os.path.isfile(path):
+            print("No valid checkpoint path provided for fine-tuning. Starting from scratch.")
+            return
+
+        print(f"Loading checkpoint from: {path}")
+        checkpoint = torch.load(path, map_location=self.device)
+
+        try:
+            denoiser_state_dict = checkpoint['denoiser']
+            # Handle DDP prefix
+            if next(iter(denoiser_state_dict)).startswith('module.'):
+                denoiser_state_dict = {k[7:]: v for k, v in denoiser_state_dict.items()}
+
+            self.denoiser.load_state_dict(denoiser_state_dict)
+            print("Denoiser weights loaded successfully.")
+        except KeyError:
+            print("[WARN] 'denoiser' key not found in checkpoint. Skipping.")
+        except Exception as e:
+            print(f"[ERROR] Failed to load denoiser state_dict: {e}")
+
+    def _init_components(self,
+        clip_model_path=None,
+        dinov2_model_path=None,
+        vae_path=None,
+    ):
+
+        # 1. Initialize PointNet Voxel Encoder (Matches Training)
+        self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
+            in_channels=15,
+            hidden_dim=256,
+            out_channels=1024,
+            scatter_type='mean',
+            n_blocks=5,
+            resolution=128,
+            add_label=False,
+        ).to(self.device)
+
+        # 2. Initialize VAE
+        self.vae = VoxelVAE(
+            in_channels=self.cfg['model']['in_channels'],
+            latent_dim=self.cfg['model']['latent_dim'],
+            encoder_blocks=self.cfg['model']['encoder_blocks'],
+            decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
+            decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
+            num_heads=8,
+            num_head_channels=64,
+            mlp_ratio=4.0,
+            attn_mode="swin",
+            window_size=8,
+            pe_mode="ape",
+            use_fp16=False,
+            use_checkpoint=False,
+            qk_rms_norm=False,
+            using_subdivide=True,
+            using_attn=self.cfg['model']['using_attn'],
+            attn_first=self.cfg['model'].get('attn_first', True),
+            pred_direction=self.cfg['model'].get('pred_direction', False),
+        ).to(self.device)
+
+        # 3. Initialize Dataset with collate_fn_pointnet
+        self.dataset = VoxelVertexDataset_edge(
+            root_dir=self.cfg['dataset']['path'],
+            base_resolution=self.cfg['dataset']['base_resolution'],
+            min_resolution=self.cfg['dataset']['min_resolution'],
+            cache_dir=self.cfg['dataset']['cache_dir'],
+            renders_dir=self.cfg['dataset']['renders_dir'],
+
+            process_img=False,
+
+            active_voxel_res=128,
+            filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
+            cache_filter_path=self.cfg['dataset']['cache_filter_path'],
+            sample_type=self.cfg['dataset'].get('sample_type', 'dora'),
+        )
+
+        self.dataloader = DataLoader(
+            self.dataset,
+            batch_size=1,
+            shuffle=True,
+            collate_fn=partial(collate_fn_pointnet,),  # Critical Change
+            num_workers=0,
+            pin_memory=True,
+            persistent_workers=False,
+        )
+
+        # 4. Load Pretrained Weights
+        # Assuming vae_path contains 'voxel_encoder' and 'vae'
+        print(f"Loading VAE/Encoder from {vae_path}")
+        ckpt = torch.load(vae_path, map_location='cpu')
+
+        # Load VAE
+        if 'vae' in ckpt:
+            self.vae.load_state_dict(ckpt['vae'], strict=False)
+        else:
+            self.vae.load_state_dict(ckpt)  # Fallback
+
+        # Load Encoder
+        if 'voxel_encoder' in ckpt:
+            self.voxel_encoder.load_state_dict(ckpt['voxel_encoder'])
+        else:
+            print("[WARN] 'voxel_encoder' not found in checkpoint, random init (BAD for inference).")
+
+        self.voxel_encoder.eval()
+        self.vae.eval()
+
+        # 5. Initialize Conditioning Model
+        if self.condition_type == 'text':
+            self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
+            self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
+        elif self.condition_type == 'image':
+            model_name = 'dinov2_vitl14_reg'
+            local_repo_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main"
+            weights_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"
+
+            dinov2_model = torch.hub.load(
+                repo_or_dir=local_repo_path,
+                model=model_name,
+                source='local',
+                pretrained=False
+            )
+            self.condition_model = dinov2_model
+            self.condition_model.load_state_dict(torch.load(weights_path))
+
+            self.image_cond_model_transform = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            ])
+        else:
+            raise ValueError(f"Unsupported condition type: {self.condition_type}")
+
+        self.condition_model.to(self.device).eval()
+
+    @torch.no_grad()
+    def encode_image(self, images) -> torch.Tensor:
+        if isinstance(images, torch.Tensor):
+            batch_tensor = images.to(self.device)
+        elif isinstance(images, list):
+            assert all(isinstance(i, Image.Image) for i in images)
+            image = [i.resize((518, 518), Image.LANCZOS) for i in images]
+            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+            batch_tensor = torch.stack(image).to(self.device)
+        else:
+            raise ValueError(f"Unsupported type of image: {type(images)}")
+
+        if batch_tensor.shape[-2:] != (518, 518):
+            batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False)
+
+        features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
+        patchtokens = F.layer_norm(features, features.shape[-1:])
+        return patchtokens
+
+    def process_batch(self, batch):
+        preprocessed_images = batch['image']
+        cond_ = self.encode_image(preprocessed_images)
+        return cond_
+
+    def eval(self):
+        # Unconditional Setup
+        if not self.is_cond:
+            if self.condition_type == 'text':
+                txt = ['']
+                encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
+                tokens = encoding['input_ids'].to(self.device)
+                with torch.no_grad():
+                    cond_ = self.condition_model(input_ids=tokens).last_hidden_state
+            else:
+                blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
+                with torch.no_grad():
+                    dummy_cond = self.encode_image([blank_img])
+                cond_ = torch.zeros_like(dummy_cond)
+                print(f"Generated unconditional image prompt (zero tensor) with shape: {cond_.shape}")
+
+        self.denoiser.eval()
+
+        # Load Denoiser Checkpoint
+        # Update this path to your ACTIVE VOXEL trained checkpoint
+        checkpoint_path = '/home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope/checkpoint_step143500_loss0_766792.pt'
+
+        self._load_denoiser(checkpoint_path)
+
+        filename = os.path.basename(checkpoint_path)
+        match = re.search(r'step(\d+)', filename)
+        step_str = match.group(1) if match else "eval"
+        save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_trellis")
+        # save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_40w_train")
+        os.makedirs(save_dir, exist_ok=True)
+        print(f"Results will be saved to: {save_dir}")
+
+        for i, batch in enumerate(self.dataloader):
+            if i > 50: exit()  # Stop after the first 51 batches (indices 0-50)
+
+            if self.is_cond and self.condition_type == 'image':
+                cond_ = self.process_batch(batch)
+
+            if cond_.shape[0] != 1:
+                cond_ = cond_.expand(batch['active_voxels_128'].shape[0], -1, -1).contiguous().to(self.device)
+            else:
+                cond_ = cond_.to(self.device)
+
+            # --- Data Retrieval (Matches collate_fn_pointnet) ---
+            point_cloud = batch['point_cloud_128'].to(self.device)
+            active_coords = batch['active_voxels_128'].to(self.device)  # [N, 4]
+
+            with autocast(device_type='cuda', dtype=torch.bfloat16):
+                with torch.no_grad():
+                    # 1. Encode Ground Truth Latents
+                    active_voxel_feats = self.voxel_encoder(
+                        p=point_cloud,
+                        sparse_coords=active_coords,
+                        res=128,
+                        bbox_size=(-0.5, 0.5),
+                    )
+
+                    sparse_input = SparseTensor(
+                        feats=active_voxel_feats,
+                        coords=active_coords.int()
+                    )
+
+                    # Encode to get GT distribution
+                    gt_latents, posterior = self.vae.encode(sparse_input)
+
+                    print(f"Batch {i}: Active voxels: {active_coords.shape[0]}")
+
+                    # 2. Generation / Sampling
+                    # Generate noise on the SAME active coordinates
+                    noise = SparseTensor_trellis(
+                        coords=active_coords.int(),
+                        feats=torch.randn(
+                            active_coords.shape[0],
+                            self.feature_dim,
+                            device=self.device,
+                        ),
+                    )
+
+                    sample_results = self.sampler.sample(
+                        model=self.denoiser.float(),
+                        noise=noise.to(self.device).float(),
+                        cond=cond_.to(self.device).float(),
+                        steps=50,
+                        rescale_t=1.0,
+                        verbose=True,
+                    )
+
+                    generated_sparse_tensor = sample_results.samples
+                    generated_coords = generated_sparse_tensor.coords
+                    generated_features = generated_sparse_tensor.feats
+
+                    print('Gen features mean:', generated_features.mean().item(), 'std:', generated_features.std().item())
+                    print('GT features mean:', gt_latents.feats.mean().item(), 'std:', gt_latents.feats.std().item())
+                    print('MSE:', F.mse_loss(generated_features, gt_latents.feats).item())
+
+                    # --- Visualization (PCA) ---
+                    gt_feats_np = gt_latents.feats.detach().cpu().numpy()
+                    gen_feats_np = generated_features.detach().cpu().numpy()
+                    coords_np = active_coords[:, 1:4].detach().cpu().numpy()  # x, y, z
+
+                    print("Visualizing features using PCA...")
+                    pca = PCA(n_components=3)
+
+                    # Fit PCA on GT, transform both
+                    pca.fit(gt_feats_np)
+                    gt_feats_3d = pca.transform(gt_feats_np)
+                    gen_feats_3d = pca.transform(gen_feats_np)
+
+                    gt_colors = normalize_to_rgb(gt_feats_3d)
+                    gen_colors = normalize_to_rgb(gen_feats_3d)
+
+                    # Save PLYs
+                    save_colored_ply(coords_np, gt_colors, os.path.join(save_dir, f"batch_{i}_gt_pca.ply"))
+                    save_colored_ply(coords_np, gen_colors, os.path.join(save_dir, f"batch_{i}_gen_pca.ply"))
+
+                    # Save Tensors for further analysis
+                    torch.save(gt_latents, os.path.join(save_dir, f"gt_latent_{i}.pt"))
+
+                    torch.save(batch, os.path.join(save_dir, f"gt_data_batch_{i}.pt"))
+                    torch.save(sample_results.samples, os.path.join(save_dir, f"sample_latent_{i}.pt"))
+
+if __name__ == '__main__':
+    torch.manual_seed(42)
+    config_path = "/home/tiger/yy/src/Michelangelo-master/config_slat_flow_128to512_pointnet_head_test.yaml"
+    with open(config_path) as f:
+        cfg = yaml.safe_load(f)
+
+    # Build the flow model and move it to the GPU when one is available
+    diffusion_model = SLatFlowModel(
+        resolution=cfg['flow']['resolution'],
+        in_channels=cfg['flow']['in_channels'],
+        out_channels=cfg['flow']['out_channels'],
+        model_channels=cfg['flow']['model_channels'],
+        cond_channels=cfg['flow']['cond_channels'],
+        num_blocks=cfg['flow']['num_blocks'],
+        num_heads=cfg['flow']['num_heads'],
+        mlp_ratio=cfg['flow']['mlp_ratio'],
+        patch_size=cfg['flow']['patch_size'],
+        num_io_res_blocks=cfg['flow']['num_io_res_blocks'],
+        io_block_channels=cfg['flow']['io_block_channels'],
+        pe_mode=cfg['flow']['pe_mode'],
+        qk_rms_norm=cfg['flow']['qk_rms_norm'],
+        qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'],
+        use_fp16=cfg['flow'].get('use_fp16', False),
+    ).to("cuda" if torch.cuda.is_available() else "cpu")
+
+    trainer = SLatFlowMatchingTrainer(
+        denoiser=diffusion_model,
+        t_schedule=cfg['t_schedule'],
+        sigma_min=cfg['sigma_min'],
+        cfg=cfg,
+    )
+
+    trainer.eval()
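Note on the sampling step in `eval`: the loop draws Gaussian noise on the active voxel coordinates and passes it to `FlowEulerSampler.sample` with `steps=50`. As a rough illustration of what Euler sampling of a flow-matching model involves, the scalar sketch below integrates a velocity field from t = 1 (noise) down to t = 0 (data). It assumes the interpolation x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps with target velocity v = (1 - sigma_min) * eps - x_0, and it replaces the learned denoiser with an oracle that already knows x_0. The names `oracle_velocity` and `euler_sample` are hypothetical, and this is not the TRELLIS implementation.

```python
import random

SIGMA_MIN = 1e-5  # same value as `sigma_min` in the config


def oracle_velocity(x_t, t, x0):
    """Stand-in for the denoiser: recover eps from x_t, then return the target velocity."""
    eps = (x_t - (1.0 - t) * x0) / (SIGMA_MIN + (1.0 - SIGMA_MIN) * t)
    return (1.0 - SIGMA_MIN) * eps - x0


def euler_sample(noise, x0, steps=50):
    """Integrate dx/dt = v from t = 1 down to t = 0 with fixed-size Euler steps."""
    x = noise
    for k in range(steps):
        t = 1.0 - k / steps
        t_prev = 1.0 - (k + 1) / steps
        x = x - (t - t_prev) * oracle_velocity(x, t, x0)
    return x


random.seed(42)
target = 0.7
sample = euler_sample(random.gauss(0.0, 1.0), target)
print(abs(sample - target) < 1e-3)  # True
```

With an oracle velocity the path is a straight line, so Euler steps are exact and the result lands within about sigma_min of the target; a learned denoiser only approximates this velocity, which is why the script compares generated and ground-truth latents with an MSE.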
test_slat_flow_128to512_pointnet_head_tomesh.py ADDED
@@ -0,0 +1,1630 @@
+import os
+import yaml
+import torch
+import numpy as np
+import random
+from tqdm import tqdm
+from collections import defaultdict
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+from torch.utils.data import DataLoader, Subset
+from triposf.modules.sparse.basic import SparseTensor
+
+from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
+
+from vertex_encoder import VoxelFeatureEncoder_edge, VoxelFeatureEncoder_vtx, VoxelFeatureEncoder_active, VoxelFeatureEncoder_active_pointnet, ConnectionHead
+from utils import load_pretrained_woself
+
+from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
+
+
+from functools import partial
+import itertools
+from typing import List, Tuple, Set
+from collections import OrderedDict
+from scipy.spatial import cKDTree
+from sklearn.neighbors import KDTree
+
+import trimesh
+
+import torch
+import torch.nn.functional as F
+import time
+
+from sklearn.decomposition import PCA
+import matplotlib.pyplot as plt
+
+import networkx as nx
+
39
+ def predict_mesh_connectivity(
40
+ connection_head,
41
+ vtx_feats,
42
+ vtx_coords,
43
+ batch_size=10000,
44
+ threshold=0.5,
45
+ k_neighbors=64, # 限制每个点只检测最近的 K 个邻居,设为 -1 则全连接检测
46
+ device='cuda'
47
+ ):
48
+ """
49
+ Args:
50
+ connection_head: 训练好的 MLP 模型
51
+ vtx_feats: [N, C] 顶点特征
52
+ vtx_coords: [N, 3] 顶点坐标 (用于 KNN 筛选候选边)
53
+ batch_size: MLP 推理的 batch size
54
+ threshold: 判定连接的概率阈值
55
+ k_neighbors: K-NN 数量。如果是 None 或 -1,则检测所有 N*(N-1)/2 对。
56
+ """
57
+ num_verts = vtx_feats.shape[0]
58
+ if num_verts < 3:
59
+ return np.empty((0, 2), dtype=np.int64) # too few vertices to form a triangle; same [E, 2] convention as the normal return
60
+
61
+ connection_head.eval()
62
+
63
+ # --- 1. Generate candidate edges ---
64
+ if k_neighbors is not None and k_neighbors > 0 and k_neighbors < num_verts:
65
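+ # Strategy A: local KNN (recommended)
66
+ # A full distance matrix can run out of memory; chunking or a KDTree/Faiss index avoids that. This is the simple PyTorch cdist version,
67
+ # i.e. brute-force cdist, which is fine when N < 10000
68
+
69
+ # For simplicity this uses a plain cdist (watch GPU memory)
70
+ # For large N (> 5000), consider faiss or scipy.spatial.cKDTree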
71
+ dist_mat = torch.cdist(vtx_coords.float(), vtx_coords.float()) # [N, N]
72
+
73
+ # Take the k_neighbors + 1 smallest distances (the nearest is the vertex itself)
74
+ # values: [N, K + 1], indices: [N, K + 1]
75
+ _, indices = torch.topk(dist_mat, k=k_neighbors + 1, dim=1, largest=False)
76
+ neighbor_indices = indices[:, 1:] # drop the first column (the vertex itself)
77
+
78
+ # Build source and target indices (src is created on the same device as the neighbor indices)
79
+ src = torch.arange(num_verts, device=neighbor_indices.device).unsqueeze(1).repeat(1, k_neighbors).flatten()
80
+ dst = neighbor_indices.flatten()
81
+
82
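+ # The edges are directed at this point (both u->v and v->u may be present), so deduplicate for efficiency
83
+ # The symmetric MLP would work with both directions or with u < v only
84
+ # For simplicity, keep u < v; a pair survives only if the lower-indexed vertex lists the other as a neighbor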
85
+ mask = src < dst
86
+ u_indices = src[mask]
87
+ v_indices = dst[mask]
88
+
89
+ else:
90
+ # Strategy B: all pairs (O(N^2)); use only when N is small
91
+ u_indices, v_indices = torch.triu_indices(num_verts, num_verts, offset=1, device=device)
92
+
93
+ # --- 2. Batched inference ---
94
+ all_probs = []
95
+ num_candidates = u_indices.shape[0]
96
+
97
+ with torch.no_grad():
98
+ for i in range(0, num_candidates, batch_size):
99
+ end = min(i + batch_size, num_candidates)
100
+ batch_u = u_indices[i:end]
101
+ batch_v = v_indices[i:end]
102
+
103
+ feat_u = vtx_feats[batch_u]
104
+ feat_v = vtx_feats[batch_v]
105
+
106
+ # Symmetric forward pass (must match the training setup)
107
+ # A -> B
108
+ input_uv = torch.cat([feat_u, feat_v], dim=-1)
109
+ logits_uv = connection_head(input_uv)
110
+
111
+ # B -> A
112
+ input_vu = torch.cat([feat_v, feat_u], dim=-1)
113
+ logits_vu = connection_head(input_vu)
114
+
115
+ # Sum logits
116
+ logits = (logits_uv + logits_vu)
117
+ probs = torch.sigmoid(logits)
118
+ all_probs.append(probs)
119
+
120
+ all_probs = torch.cat(all_probs).reshape(-1) # [M]; reshape keeps a 1-D tensor even when M == 1
121
+
122
+ # --- 3. Keep the connected edges ---
123
+ connected_mask = all_probs > threshold
124
+ final_u = u_indices[connected_mask].cpu().numpy()
125
+ final_v = v_indices[connected_mask].cpu().numpy()
126
+
127
+ edges = np.stack([final_u, final_v], axis=1) # [E, 2]
128
+
129
+ return edges
130
+
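Review note on candidate generation: the `src < dst` mask above keeps a pair only when the lower-indexed vertex lists the other among its neighbors, so a pair found only from the higher-indexed side is dropped. The following is a dependency-free sketch of an undirected alternative. `knn_candidate_pairs` is a hypothetical helper, not part of this commit, and it uses brute-force distances purely for illustration.

```python
import math


def knn_candidate_pairs(points, k):
    """Return sorted undirected pairs (u, v) with u < v, where v is among
    u's k nearest neighbors or u is among v's."""
    n = len(points)
    pairs = set()
    for u in range(n):
        # Brute-force neighbor ranking; acceptable only for small inputs.
        order = sorted(
            (v for v in range(n) if v != u),
            key=lambda v: math.dist(points[u], points[v]),
        )
        for v in order[:k]:
            # Store each pair once, smallest index first.
            pairs.add((min(u, v), max(u, v)))
    return sorted(pairs)


example_points = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
example_pairs = knn_candidate_pairs(example_points, k=1)
```

In this example vertex 2 lists vertex 1 as its nearest neighbor, but vertex 1 lists vertex 0, so a `src < dst` filter alone would lose the pair (1, 2).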
131
+ def build_triangles_from_edges(edges, num_verts):
132
+ """
133
+ Build triangles from an edge list.
134
+ Finds all 3-cliques (triangles) in the graph.
135
+ This is a classic graph problem; networkx is used to build the graph.
136
+ """
137
+ if len(edges) == 0:
138
+ return np.empty((0, 3), dtype=int)
139
+
140
+ G = nx.Graph()
141
+ G.add_nodes_from(range(num_verts))
142
+ G.add_edges_from(edges)
143
+
144
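+ # Find all 3-cliques (triangles)
145
+ # nx.enumerate_all_cliques returns cliques of every size, so it would need filtering to size 3
146
+ # nx.triangles is not suitable: it only returns per-node triangle counts
147
+ # nx.enumerate_all_cliques can be slow, though acceptable on sparse graphs
148
+
149
+ # Faster: for each edge (u, v), look up the common neighbors w of u and v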
150
+ triangles = []
151
+ adj = [set(G.neighbors(n)) for n in range(num_verts)]
152
+
153
+ # To avoid duplicates such as (u, v, w) and (v, w, u), enforce u < v < w
154
+ # With u < v guaranteed below, it is enough to keep common neighbors w > v
155
+
156
+ # Enumeration:
157
+ for u, v in edges:
158
+ if u > v: u, v = v, u # ensure u < v
159
+
160
+ # Find common neighbors
161
+ common = adj[u].intersection(adj[v])
162
+ for w in common:
163
+ if w > v: # enforce u < v < w to avoid duplicates
164
+ triangles.append([u, v, w])
165
+
166
+ return np.array(triangles, dtype=int).reshape(-1, 3) # stays (0, 3) when no triangle is found
167
+
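Review note: the common-neighbor enumeration used in `build_triangles_from_edges` can be checked in isolation. The sketch below is a plain-Python version of the same idea, assuming integer vertex ids. `triangles_from_edges` is a hypothetical name used only here.

```python
def triangles_from_edges(edges):
    """Enumerate each triangle once as (u, v, w) with u < v < w."""
    adj = {}
    for u, v in edges:
        if u == v:
            continue  # ignore self-loops
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)

    triangles = set()
    for u, v in edges:
        if u == v:
            continue
        a, b = min(u, v), max(u, v)
        # Any common neighbor w of a and b closes a triangle.
        for w in adj[a] & adj[b]:
            if w > b:  # ordering a < b < w reports each triangle once
                triangles.add((a, b, w))
    return sorted(triangles)


# Complete graph on four vertices: every triple is a triangle.
k4_edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
k4_triangles = triangles_from_edges(k4_edges)
```

Collecting into a set also guards against repeated edges in the input, which a list-based version would report more than once.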
168
+ def downsample_voxels(
169
+ voxels: torch.Tensor,
170
+ input_resolution: int,
171
+ output_resolution: int
172
+ ) -> torch.Tensor:
173
+ if input_resolution % output_resolution != 0:
174
+ raise ValueError(f"input_resolution ({input_resolution}) must be divisible "
175
+ f"by output_resolution ({output_resolution}).")
176
+
177
+ factor = input_resolution // output_resolution
178
+
179
+ downsampled_voxels = voxels.clone().to(torch.long)
180
+
181
+ downsampled_voxels[:, 1:] //= factor
182
+
183
+ unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0)
184
+ return unique_downsampled_voxels
185
+
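Review note: `downsample_voxels` integer-divides the spatial columns and removes duplicate rows, leaving the batch index untouched. A plain-Python sketch of the same arithmetic follows, assuming rows of the form `(batch, x, y, z)`. `downsample_coords` is a hypothetical helper for illustration only.

```python
def downsample_coords(coords, input_resolution, output_resolution):
    """Map (batch, x, y, z) rows to a coarser grid and drop duplicates."""
    if input_resolution % output_resolution != 0:
        raise ValueError("input_resolution must be divisible by output_resolution")
    factor = input_resolution // output_resolution
    # The batch index is kept; only the spatial coordinates are divided.
    coarse = {(b, x // factor, y // factor, z // factor) for b, x, y, z in coords}
    return sorted(coarse)  # sorted for a deterministic result


fine = [(0, 0, 0, 0), (0, 1, 1, 1), (0, 4, 0, 0), (1, 0, 0, 0)]
coarse = downsample_coords(fine, input_resolution=8, output_resolution=2)
```

The first two rows fall into the same coarse cell, so four input rows become three output rows.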
186
+ def visualize_colored_points_ply(coords, vectors, filename):
187
+ """
188
+ Visualize a point cloud colored by vector direction and save it as a PLY file.
189
+
190
+ Args:
191
+ coords (torch.Tensor or np.ndarray): 3D coordinates, shape (N, 3).
192
+ vectors (torch.Tensor or np.ndarray): direction vectors, shape (N, 3).
193
+ filename (str): output file name; must end in .ply.
194
+ """
195
+ # Make sure the inputs are numpy arrays
196
+ if isinstance(coords, torch.Tensor):
197
+ coords = coords.detach().cpu().numpy()
198
+ if isinstance(vectors, torch.Tensor):
199
+ vectors = vectors.detach().cpu().to(torch.float32).numpy()
200
+
201
+ # Guard against empty input
202
+ if coords.size == 0 or vectors.size == 0:
203
+ print(f"Warning: input is empty; {filename} was not written.")
204
+ return
205
+
206
+ # Map vector components from [-1, 1] to [0, 255]
207
+ # (vectors + 1) shifts the range from [-1, 1] to [0, 2]
208
+ # * 127.5 scales [0, 2] to [0, 255]
209
+ # np.clip then bounds the result to [0, 255] so out-of-range components cannot overflow
210
+ colors = np.clip((vectors + 1) * 127.5, 0, 255).astype(np.uint8)
211
+
212
+ # Build a point cloud object with per-point colors
213
+ # trimesh's PointCloud accepts per-point colors directly
214
+ points = trimesh.points.PointCloud(coords, colors=colors)
215
+ # Export as a PLY file
216
+ points.export(filename, file_type='ply')
217
+ print(f"Visualization saved to: {filename}")
218
+
219
+
220
+ def compute_vertex_matching(pred_coords, gt_coords, threshold=1.0):
221
+ # Deduplicate the coordinates (kept as float32; no integer rounding happens here)
222
+ print('len(pred_coords)', len(pred_coords))
223
+
224
+ pred_array = np.unique(pred_coords.detach().to(torch.float32).cpu().numpy(), axis=0)
225
+ gt_array = np.unique(gt_coords.detach().cpu().to(torch.float32).numpy(), axis=0)
226
+ print('len(pred_array)', len(pred_array))
227
+ pred_total = len(pred_array)
228
+ gt_total = len(gt_array)
229
+
230
+ # Return early if either set is empty
231
+ if pred_total == 0 or gt_total == 0:
232
+ return 0, 0.0, pred_total, gt_total
233
+
234
+ # Build the KDTree on the GT points
235
+ tree = KDTree(gt_array)
236
+
237
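+ # For each predicted point, find the nearest GT point
238
+ dist, indices = tree.query(pred_array, k=1)
239
+ dist = dist.reshape(-1) # reshape keeps a 1-D array even with a single predicted point
240
+ indices = indices.reshape(-1)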
241
+
242
+ # Greedy deduplication: enforce one-to-one matches
243
+ matches = 0
244
+ used_gt = set()
245
+ for d, idx in zip(dist, indices):
246
+ if d <= threshold and idx not in used_gt:
247
+ matches += 1
248
+ used_gt.add(idx)
249
+
250
+ match_rate = matches / max(gt_total, 1)
251
+
252
+ return matches, match_rate, pred_total, gt_total
253
+
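Review note: `compute_vertex_matching` visits predictions in array order and consults only each prediction's single nearest GT point, so a prediction whose nearest GT point is already taken stays unmatched. The sketch below mirrors that behaviour with brute-force distances so it can be exercised without a KD-tree. `greedy_match` is a hypothetical helper, not code from this commit.

```python
import math


def greedy_match(pred, gt, threshold=1.0):
    """Count one-to-one matches using each prediction's nearest GT point."""
    if not pred or not gt:
        return 0, 0.0
    used_gt = set()
    matches = 0
    for p in pred:
        # Nearest GT point for this prediction (brute force).
        idx = min(range(len(gt)), key=lambda i: math.dist(p, gt[i]))
        if math.dist(p, gt[idx]) <= threshold and idx not in used_gt:
            used_gt.add(idx)
            matches += 1
    return matches, matches / len(gt)


pred_pts = [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (5.0, 5.0, 5.0)]
gt_pts = [(0.0, 0.0, 0.0), (5.0, 5.0, 5.5), (9.0, 9.0, 9.0)]
n_matches, match_rate = greedy_match(pred_pts, gt_pts)
```

The second prediction is within the threshold of the first GT point but loses it to the first prediction, which is the order dependence worth keeping in mind when reading the reported match rate.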
254
+ def flatten_coords_4d(coords_4d: torch.Tensor):
255
+ coords_4d_long = coords_4d.long()
256
+
257
+ base_x = 512
258
+ base_y = 512 * 512
259
+ base_z = 512 * 512 * 512
260
+
261
+ flat_coords = coords_4d_long[:, 0] * base_z + \
262
+ coords_4d_long[:, 1] * base_y + \
263
+ coords_4d_long[:, 2] * base_x + \
264
+ coords_4d_long[:, 3]
265
+ return flat_coords
266
+
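Review note: `flatten_coords_4d` packs `(batch, x, y, z)` into one integer using base 512. The packing is only collision-free while every spatial coordinate is below 512. The sketch below shows the same arithmetic in plain Python with a round trip and one collision. `flatten` and `unflatten` are hypothetical names for illustration.

```python
BASE = 512  # same base as flatten_coords_4d


def flatten(b, x, y, z):
    """Pack (batch, x, y, z) into a single integer key."""
    return ((b * BASE + x) * BASE + y) * BASE + z


def unflatten(flat):
    """Inverse of flatten, valid only when x, y, z are all below BASE."""
    flat, z = divmod(flat, BASE)
    flat, y = divmod(flat, BASE)
    b, x = divmod(flat, BASE)
    return b, x, y, z


key = flatten(1, 2, 3, 4)
roundtrip = unflatten(key)

# A coordinate equal to BASE spills into the next digit and collides.
collision = flatten(0, 0, 0, 512) == flatten(0, 0, 1, 0)
```

If coordinates at a resolution above 512 are ever passed in, a larger base would be needed to keep the keys unique.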
267
+ class Tester:
268
+ def __init__(self, ckpt_path, config_path=None, dataset_path=None):
269
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
270
+ self.ckpt_path = ckpt_path
271
+
272
+ self.config = self._load_config(config_path)
273
+ self.dataset_path = dataset_path # or self.config['dataset']['path']
274
+ checkpoint = torch.load(self.ckpt_path, map_location='cpu')
275
+ self.epoch = checkpoint.get('epoch', 0)
276
+
277
+ self._init_models()
278
+ self._init_dataset()
279
+
280
+ self.result_dir = os.path.join(os.path.dirname(ckpt_path), "evaluation_results")
281
+ os.makedirs(self.result_dir, exist_ok=True)
282
+
283
+ dataset_name_clean = os.path.basename(self.dataset_path).replace('.npz', '').replace('.npy', '')
284
+ self.output_voxel_dir = os.path.join(os.path.dirname(ckpt_path),
285
+ f"epoch_{self.epoch}_{dataset_name_clean}_voxels_0_gs")
286
+ os.makedirs(self.output_voxel_dir, exist_ok=True)
287
+
288
+ self.output_obj_dir = os.path.join(os.path.dirname(ckpt_path),
289
+ f"epoch_{self.epoch}_{dataset_name_clean}_obj_0_gs")
290
+ os.makedirs(self.output_obj_dir, exist_ok=True)
291
+
292
+ def _save_logit_visualization(self, dense_vol, name, sample_name, ply_threshold=0.01):
293
+ """
294
+ Save 2D max-projection heatmaps and a colored RGBA 3D .ply point cloud of the probabilities (the raw .npy dump is currently disabled)
295
+
296
+ Args:
297
+ dense_vol: (H, W, D) numpy array, values in [0, 1]
298
+ name: str (e.g., "edge" or "vertex")
299
+ sample_name: str
300
+ ply_threshold: float, only points with probability above this value are saved
301
+ """
302
+ # 1. Save the raw dense data (optional)
303
+ npy_path = os.path.join(self.output_voxel_dir, f"{sample_name}_{name}_logits.npy")
304
+ # np.save(npy_path, dense_vol)
305
+
306
+ # 2. Generate 2D projection heatmaps (unchanged)
307
+ proj_x = np.max(dense_vol, axis=0)
308
+ proj_y = np.max(dense_vol, axis=1)
309
+ proj_z = np.max(dense_vol, axis=2)
310
+
311
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
312
+ im0 = axes[0].imshow(proj_x, cmap='turbo', vmin=0, vmax=1, origin='lower')
313
+ axes[0].set_title(f"{name} Max-Proj (YZ)")
314
+ im1 = axes[1].imshow(proj_y, cmap='turbo', vmin=0, vmax=1, origin='lower')
315
+ axes[1].set_title(f"{name} Max-Proj (XZ)")
316
+ im2 = axes[2].imshow(proj_z, cmap='turbo', vmin=0, vmax=1, origin='lower')
317
+ axes[2].set_title(f"{name} Max-Proj (XY)")
318
+
319
+ fig.colorbar(im2, ax=axes, orientation='vertical', fraction=0.02, pad=0.04)
320
+ plt.suptitle(f"{sample_name} - {name} Occupancy Probability")
321
+
322
+ png_path = os.path.join(self.output_voxel_dir, f"{sample_name}_{name}_heatmap.png")
323
+ plt.savefig(png_path, dpi=150)
324
+ plt.close(fig)
325
+
326
+ # ------------------------------------------------------------------
327
+ # 3. Save a PLY point cloud with color and alpha (RGBA)
328
+ # ------------------------------------------------------------------
329
+ # Select coordinates whose probability exceeds the threshold
330
+ indices = np.argwhere(dense_vol > ply_threshold)
331
+
332
+ if len(indices) > 0:
333
+ # Probability values of those points, in [0, 1]
334
+ values = dense_vol[indices[:, 0], indices[:, 1], indices[:, 2]]
335
+
336
+ # Map values through a matplotlib colormap
337
+ # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
338
+ cmap = plt.get_cmap('turbo')
339
+
340
+ # map values [0, 1] to RGBA [0, 1] (N, 4)
341
+ colors_float = cmap(values)
342
+
343
+ # -------------------------------------------------------
344
+ # Key change: set the alpha channel (opacity)
345
+ # -------------------------------------------------------
346
+ # Set alpha directly to the probability value.
347
+ # Probability 1.0 -> alpha 1.0 (fully opaque, strong color)
348
+ # Probability 0.1 -> alpha 0.1 (nearly transparent, faint color)
349
+ colors_float[:, 3] = values
350
+
351
+ # Convert to uint8 [0, 255], keeping all 4 channels (R, G, B, A)
352
+ colors_uint8 = (colors_float * 255).astype(np.uint8)
353
+
354
+ # Coordinates (voxel indices are used as-is)
355
+ vertices = indices
356
+
357
+ ply_filename = f"{sample_name}_{name}_logits_colored.ply"
358
+ ply_save_path = os.path.join(self.output_voxel_dir, ply_filename)
359
+
360
+ try:
361
+ # Save with trimesh (it accepts (N, 4) colors)
362
+ pcd = trimesh.points.PointCloud(vertices=vertices, colors=colors_uint8)
363
+ pcd.export(ply_save_path)
364
+ print(f"Saved colored RGBA logit PLY to {ply_save_path}")
365
+ except Exception as e:
366
+ print(f"Failed to save PLY with trimesh: {e}")
367
+ # Fallback: write the PLY by hand (including the alpha property)
368
+ with open(ply_save_path, 'w') as f:
369
+ f.write("ply\n")
370
+ f.write("format ascii 1.0\n")
371
+ f.write(f"element vertex {len(vertices)}\n")
372
+ f.write("property float x\n")
373
+ f.write("property float y\n")
374
+ f.write("property float z\n")
375
+ f.write("property uchar red\n")
376
+ f.write("property uchar green\n")
377
+ f.write("property uchar blue\n")
378
+ f.write("property uchar alpha\n") # alpha property
379
+ f.write("end_header\n")
380
+ for i in range(len(vertices)):
381
+ v = vertices[i]
382
+ c = colors_uint8[i] # c is now (R, G, B, A)
383
+ f.write(f"{v[0]} {v[1]} {v[2]} {c[0]} {c[1]} {c[2]} {c[3]}\n")
384
+
385
+ def _point_line_segment_distance(self, px, py, pz, x1, y1, z1, x2, y2, z2):
386
+ """
387
+ Squared shortest distance from point (px,py,pz) to segment (x1,y1,z1)-(x2,y2,z2).
388
+ All inputs are tensors; broadcasting is supported.
389
+ """
390
+ # Segment vector AB
391
+ ABx = x2 - x1
392
+ ABy = y2 - y1
393
+ ABz = z2 - z1
394
+
395
+ # Vector AP
396
+ APx = px - x1
397
+ APy = py - y1
398
+ APz = pz - z1
399
+
400
+ # Squared length of AB
401
+ AB_sq = ABx**2 + ABy**2 + ABz**2
402
+
403
+ # Avoid division by zero (when the two endpoints coincide)
404
+ AB_sq = torch.clamp(AB_sq, min=1e-6)
405
+
406
+ # Projection coefficient t = (AP · AB) / |AB|^2
407
+ t = (APx * ABx + APy * ABy + APz * ABz) / AB_sq
408
+
409
+ # Clamp t to [0, 1] (segment constraint)
410
+ t = torch.clamp(t, 0.0, 1.0)
411
+
412
+ # Closest point on the segment (projection)
413
+ closestX = x1 + t * ABx
414
+ closestY = y1 + t * ABy
415
+ closestZ = z1 + t * ABz
416
+
417
+ # Squared distance
418
+ dx = px - closestX
419
+ dy = py - closestY
420
+ dz = pz - closestZ
421
+
422
+ return dx**2 + dy**2 + dz**2
423
+
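Review note: the clamped-projection formula in `_point_line_segment_distance` is easy to verify on scalars. The sketch below is a plain-Python version for a single point and segment, assuming 3-tuples as input. `point_segment_dist_sq` is a hypothetical name used only here.

```python
def point_segment_dist_sq(p, a, b, eps=1e-6):
    """Squared distance from point p to segment a-b (all 3-tuples)."""
    ab = [bi - ai for ai, bi in zip(a, b)]
    ap = [pi - ai for ai, pi in zip(a, p)]
    # Clamp the squared length so a degenerate segment does not divide by zero.
    ab_sq = max(sum(c * c for c in ab), eps)
    # Projection coefficient, clamped so the closest point stays on the segment.
    t = sum(x * y for x, y in zip(ap, ab)) / ab_sq
    t = min(max(t, 0.0), 1.0)
    closest = [ai + t * c for ai, c in zip(a, ab)]
    return sum((pi - ci) ** 2 for pi, ci in zip(p, closest))


origin = (0.0, 0.0, 0.0)
unit_x = (1.0, 0.0, 0.0)
above_middle = point_segment_dist_sq((0.5, 1.0, 0.0), origin, unit_x)
past_the_end = point_segment_dist_sq((2.0, 0.0, 0.0), origin, unit_x)
degenerate = point_segment_dist_sq((0.0, 3.0, 4.0), origin, origin)
```

The three cases cover an interior projection, a projection clamped to an endpoint, and a zero-length segment, where the result falls back to the distance to that single point.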
424
+ def _extract_mesh_projection_based(
425
+ self,
426
+ vtx_result: dict,
427
+ edge_result: dict,
428
+ resolution: int = 1024,
429
+ vtx_prob_threshold: float = 0.5,
430
+
431
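+ # --- Parameters for the projection-based logic ---
432
+ search_radius: float = 128.0, # 1. maximum candidate edge length
433
+ project_dist_thresh: float = 1.5, # 2. projection distance threshold (tube radius, in voxels)
434
+ dir_align_threshold: float = 0.6, # 3. direction similarity threshold (cos theta)
435
+ connect_ratio_threshold: float = 0.4, # 4. final connection threshold (matched voxels / theoretical length)
436
+
437
+ edge_prob_threshold: float = 0.1, # only used to select which edge voxels count as present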
438
+ ):
439
+ t_start = time.perf_counter()
440
+
441
+ # ---------------------------------------------------------------------
442
+ # 1. Prepare global data: collect all active edge voxels (treated as a point cloud)
443
+ # ---------------------------------------------------------------------
444
+ e_probs = torch.sigmoid(edge_result['occ_probs'][:, 0])
445
+ e_coords = edge_result['coords_4d'][:, 1:].float() # (N, 3)
446
+
447
+ # Get direction vectors
448
+ if 'predicted_direction_feats' in edge_result:
449
+ e_dirs = edge_result['predicted_direction_feats'] # (N, 3)
450
+ # Normalize directions
451
+ e_dirs = F.normalize(e_dirs, p=2, dim=1)
452
+ else:
453
+ print("Warning: No direction features, using dummy.")
454
+ e_dirs = torch.zeros_like(e_coords)
455
+
456
+ # Keep the valid edge voxels (global point cloud)
457
+ valid_mask = e_probs > edge_prob_threshold
458
+
459
+ cloud_coords = e_coords[valid_mask] # (M, 3)
460
+ cloud_dirs = e_dirs[valid_mask] # (M, 3)
461
+
462
+ num_cloud = cloud_coords.shape[0]
463
+ print(f"[Projection] Global active edge voxels: {num_cloud}")
464
+
465
+ if num_cloud == 0:
466
+ return [], []
467
+
468
+ # ---------------------------------------------------------------------
469
+ # 2. Prepare vertices and candidate edges
470
+ # ---------------------------------------------------------------------
471
+ v_probs = torch.sigmoid(vtx_result['occ_probs'][:, 0])
472
+ v_coords = vtx_result['coords_4d'][:, 1:].float()
473
+ v_mask = v_probs > vtx_prob_threshold
474
+ valid_v_coords = v_coords[v_mask] # (V, 3)
475
+
476
+ if valid_v_coords.shape[0] < 2:
477
+ return valid_v_coords.cpu().numpy() / resolution, []
478
+
479
+ # Generate all candidate edges (coarse distance filter)
480
+ dists = torch.cdist(valid_v_coords, valid_v_coords)
481
+ triu_mask = torch.triu(torch.ones_like(dists), diagonal=1).bool()
482
+ cand_mask = (dists < search_radius) & triu_mask
483
+ cand_indices = torch.nonzero(cand_mask, as_tuple=False) # (E_cand, 2)
484
+
485
+ p1s = valid_v_coords[cand_indices[:, 0]] # (E, 3)
486
+ p2s = valid_v_coords[cand_indices[:, 1]] # (E, 3)
487
+
488
+ num_candidates = p1s.shape[0]
489
+ print(f"[Projection] Checking {num_candidates} candidate pairs...")
490
+
491
+ # ---------------------------------------------------------------------
492
+ # 3. Loop over candidate edges (bounding-box pre-filter)
493
+ # ---------------------------------------------------------------------
494
+ final_edges = []
495
+
496
+ # Precompute direction and length for every candidate edge
497
+ edge_vecs = p2s - p1s
498
+ edge_lengths = torch.norm(edge_vecs, dim=1)
499
+ edge_dirs = F.normalize(edge_vecs, p=2, dim=1)
500
+
501
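+ # Avoid blowing up GPU memory, but also avoid a slow Python loop
502
+ # Operating per point is too slow, and so is scanning the whole cloud for every edge.
503
+ # Strategy:
504
+ # loop over edges, and inside the loop use masks to select the relevant points quickly.
505
+ # A Python loop of 10000 iterations is slow, so only promising edges should be processed.
506
+ # To keep the logic easy to verify, this uses a simple loop with the math done on the GPU.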
507
+
508
+ # Split the global cloud into per-axis tensors for fast bounding-box tests
509
+ cx, cy, cz = cloud_coords[:, 0], cloud_coords[:, 1], cloud_coords[:, 2]
510
+
511
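+ # If there are too many candidate edges, process them in chunks. This assumes at most about 50k edges and 100k points.
512
+
513
+ # This step is the bottleneck: a Python loop is used, with distances computed only for nearby points
514
+ # A hash grid would be faster; a simple bounding-box check is used here.
515
+
516
+ # Simple logic: for each edge, find the points inside its bounding box and compute their distances.
517
+ # batch_size is the number of edges sliced off per outer iteration
518
+
519
+ batch_size = 128 # slice 128 edges at a time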
520
+
521
+ for i in range(0, num_candidates, batch_size):
522
+ end = min(i + batch_size, num_candidates)
523
+
524
+ # Edge data for the current batch
525
+ b_p1 = p1s[i:end] # (B, 3)
526
+ b_p2 = p2s[i:end] # (B, 3)
527
+ b_dirs = edge_dirs[i:end] # (B, 3)
528
+ b_lens = edge_lengths[i:end] # (B,)
529
+
530
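+ # --- Step A: projection and distance check ---
531
+ # A full (B, M) matrix computation would easily run out of memory,
532
+ # since M (the number of cloud points) can be large.
533
+ # Alternative: turn the problem around.
534
+ # Skip the matrix and loop over single edges? That is slow.
535
+
536
+ # Practical optimization: compute distances only for points inside the bounding box.
537
+ # GPUs handle dynamic indexing of ragged data poorly, so a per-edge loop is the safer option,
538
+ # with as much vectorization as possible to offset Python overhead.
539
+
540
+ # Compromise used here: process edges one by one, selecting points with boolean masks.
541
+ # In practice a Python for loop of a few thousand iterations is acceptable.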
542
+
543
+ current_edges_indices = cand_indices[i:end]
544
+
545
+ for j in range(len(b_p1)):
546
+ # Process a single edge
547
+ p1 = b_p1[j]
548
+ p2 = b_p2[j]
549
+ e_dir = b_dirs[j]
550
+ e_len = b_lens[j].item()
551
+
552
+ # 1. Bounding-box filter (fast, coarse culling)
553
+ # Find all points inside this edge's bounding box (plus padding)
554
+ padding = project_dist_thresh + 2.0
555
+ min_xyz = torch.min(p1, p2) - padding
556
+ max_xyz = torch.max(p1, p2) + padding
557
+
558
+ # Select with boolean masks
559
+ mask_x = (cx >= min_xyz[0]) & (cx <= max_xyz[0])
560
+ mask_y = (cy >= min_xyz[1]) & (cy <= max_xyz[1])
561
+ mask_z = (cz >= min_xyz[2]) & (cz <= max_xyz[2])
562
+ bbox_mask = mask_x & mask_y & mask_z
563
+
564
+ subset_coords = cloud_coords[bbox_mask]
565
+ subset_dirs = cloud_dirs[bbox_mask]
566
+
567
+ if subset_coords.shape[0] == 0:
568
+ continue
569
+
570
+ # 2. Exact distance (projection distance)
571
+ # Squared distance from each point in the subset to segment p1-p2
572
+ dist_sq = self._point_line_segment_distance(
573
+ subset_coords[:, 0], subset_coords[:, 1], subset_coords[:, 2],
574
+ p1[0], p1[1], p1[2],
575
+ p2[0], p2[1], p2[2]
576
+ )
577
+
578
+ # 3. Distance threshold (keep voxels inside the tube)
579
+ dist_mask = dist_sq < (project_dist_thresh ** 2)
580
+
581
+ # Voxels inside the tube
582
+ tube_dirs = subset_dirs[dist_mask]
583
+
584
+ if tube_dirs.shape[0] == 0:
585
+ continue
586
+
587
+ # 4. Direction consistency check
588
+ # Dot product (cos theta)
589
+ # e_dir is (3,), tube_dirs is (K, 3)
590
+ dot_prod = torch.matmul(tube_dirs, e_dir)
591
+
592
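+ # abs is used because edges may be undirected, or the network may predict the reversed direction
593
+ # If the network strictly predicts an orientation, the abs can be removed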
594
+ dir_sim = torch.abs(dot_prod)
595
+
596
+ # Count voxels whose direction is aligned well enough
597
+ valid_voxel_count = (dir_sim > dir_align_threshold).sum().item()
598
+
599
+ # 5. Ratio test
600
+ # The number of voxels along a rasterized edge is roughly its length (e_len)
601
+ # If e_len is very small (e.g. < 1), use 1 to avoid dividing by zero
602
+ theoretical_count = max(e_len, 1.0)
603
+
604
+ ratio = valid_voxel_count / theoretical_count
605
+
606
+ if ratio > connect_ratio_threshold:
607
+ # Accept this edge
608
+ global_idx = i + j
609
+ edge_tuple = cand_indices[global_idx].cpu().numpy().tolist()
610
+ final_edges.append(edge_tuple)
611
+
612
+ t_end = time.perf_counter()
613
+ print(f"[Projection] Logic finished. Accepted {len(final_edges)} edges. Time={t_end - t_start:.4f}s")
614
+
615
+ out_vertices = valid_v_coords.cpu().numpy() / resolution
616
+ return out_vertices, final_edges
617
+
618
+ def _save_voxel_ply(self, coords: torch.Tensor, labels: torch.Tensor, filename: str):
619
+ if coords.numel() == 0:
620
+ return
621
+
622
+ coords_np = coords.cpu().to(torch.float32).numpy()
623
+ labels_np = labels.cpu().to(torch.float32).numpy()
624
+
625
+ colors = np.zeros((coords_np.shape[0], 3), dtype=np.uint8)
626
+ colors[labels_np == 0] = [255, 0, 0]
627
+ colors[labels_np == 1] = [0, 0, 255]
628
+
629
+ try:
630
+ import trimesh
631
+ point_cloud = trimesh.PointCloud(vertices=coords_np, colors=colors)
632
+ ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
633
+ point_cloud.export(ply_path)
634
+ except ImportError:
635
+ ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply")
636
+ with open(ply_path, 'w') as f:
637
+ f.write("ply\n")
638
+ f.write("format ascii 1.0\n")
639
+ f.write(f"element vertex {coords_np.shape[0]}\n")
640
+ f.write("property float x\n")
641
+ f.write("property float y\n")
642
+ f.write("property float z\n")
643
+ f.write("property uchar red\n")
644
+ f.write("property uchar green\n")
645
+ f.write("property uchar blue\n")
646
+ f.write("end_header\n")
647
+ for i in range(coords_np.shape[0]):
648
+ f.write(f"{coords_np[i,0]} {coords_np[i,1]} {coords_np[i,2]} {colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
649
+
650
+ def _load_config(self, config_path=None):
651
+ if config_path and os.path.exists(config_path):
652
+ with open(config_path) as f:
653
+ return yaml.safe_load(f)
654
+
655
+ ckpt_dir = os.path.dirname(self.ckpt_path)
656
+ possible_configs = [
657
+ os.path.join(ckpt_dir, "config.yaml"),
658
+ os.path.join(os.path.dirname(ckpt_dir), "config.yaml")
659
+ ]
660
+
661
+ for config_file in possible_configs:
662
+ if os.path.exists(config_file):
663
+ with open(config_file) as f:
664
+ print(f"Loaded config from: {config_file}")
665
+ return yaml.safe_load(f)
666
+
667
+ checkpoint = torch.load(self.ckpt_path, map_location='cpu')
668
+ if 'config' in checkpoint:
669
+ print("Loaded config from checkpoint")
670
+ return checkpoint['config']
671
+
672
+ raise FileNotFoundError("Could not find config.yaml in checkpoint directory or parent, and config not saved in checkpoint.")
673
+
674
+ def _init_models(self):
675
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
676
+ in_channels=15,
677
+ hidden_dim=256,
678
+ out_channels=1024,
679
+ scatter_type='mean',
680
+ n_blocks=5,
681
+ resolution=128,
682
+
683
+ ).to(self.device)
684
+
685
+ self.connection_head = ConnectionHead(
686
+ channels=128 * 2,
687
+ out_channels=1,
688
+ mlp_ratio=4,
689
+ ).to(self.device)
690
+
691
+ self.vae = VoxelVAE( # ablation: VoxelVAE_1volume_dilation
692
+ in_channels=self.config['model']['in_channels'],
693
+ latent_dim=self.config['model']['latent_dim'],
694
+ encoder_blocks=self.config['model']['encoder_blocks'],
695
+ # decoder_blocks=self.config['model']['decoder_blocks'],
696
+ decoder_blocks_vtx=self.config['model']['decoder_blocks_vtx'],
697
+ decoder_blocks_edge=self.config['model']['decoder_blocks_edge'],
698
+ num_heads=8,
699
+ num_head_channels=64,
700
+ mlp_ratio=4.0,
701
+ attn_mode="swin",
702
+ window_size=8,
703
+ pe_mode="ape",
704
+ use_fp16=False,
705
+ use_checkpoint=False,
706
+ qk_rms_norm=False,
707
+ using_subdivide=True,
708
+ using_attn=self.config['model']['using_attn'],
709
+ attn_first=self.config['model'].get('attn_first', True),
710
+ pred_direction=self.config['model'].get('pred_direction', False),
711
+ ).to(self.device)
712
+
713
+ load_pretrained_woself(
714
+ checkpoint_path=self.ckpt_path,
715
+ voxel_encoder=self.voxel_encoder,
716
+ connection_head=self.connection_head,
717
+ vae=self.vae,
718
+ )
719
+ # --- Added: weight sanity check ---
720
+ print("--- Checking checkpoint weights for NaN/Inf values... ---")
721
+ has_nan_inf = False
722
+ if self._check_weights_for_nan_inf(self.vae, "VoxelVAE"):
723
+ has_nan_inf = True
724
+
725
+ if self._check_weights_for_nan_inf(self.voxel_encoder, "Vertex Encoder"):
726
+ has_nan_inf = True
727
+
728
+ if self._check_weights_for_nan_inf(self.connection_head, "Connection Head"):
729
+ has_nan_inf = True
730
+
731
+ if not has_nan_inf:
732
+ print("--- Weight check passed. No NaN/Inf values found. ---")
733
+ else:
734
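+ # Bad values make evaluation meaningless, so raise immediately
735
+ raise ValueError(f"NaN or Inf values found in checkpoint '{self.ckpt_path}'. Inspect the weights for signs of unstable training.")
736
+ # --- End of weight check ---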
737
+
738
+ self.vae.eval()
739
+ self.voxel_encoder.eval()
740
+ self.connection_head.eval()
741
+
742
+ def _init_dataset(self):
743
+ self.dataset = VoxelVertexDataset_edge(
744
+ root_dir=self.dataset_path,
745
+ base_resolution=self.config['dataset']['base_resolution'],
746
+ min_resolution=self.config['dataset']['min_resolution'],
747
+ cache_dir='/home/tiger/yy/src/dataset_cache/test_15c_dora',
748
+ # cache_dir=self.config['dataset']['cache_dir'],
749
+ renders_dir=self.config['dataset']['renders_dir'],
750
+
751
+ # filter_active_voxels=self.config['dataset']['filter_active_voxels'],
752
+ filter_active_voxels=False,
753
+ cache_filter_path=self.config['dataset']['cache_filter_path'],
754
+
755
+ sample_type=self.config['dataset']['sample_type'],
756
+ active_voxel_res=128,
757
+ pc_sample_number=819200,
758
+
759
+ )
760
+
761
+ self.dataloader = DataLoader(
762
+ self.dataset,
763
+ batch_size=1,
764
+ shuffle=False,
765
+ collate_fn=partial(collate_fn_pointnet),
766
+ num_workers=0,
767
+ pin_memory=True,
768
+ # prefetch_factor=4,
769
+ )
770
+
771
+ def _check_weights_for_nan_inf(self, model: torch.nn.Module, model_name: str) -> bool:
772
+ """
773
+ Check all parameters of a model for NaN or Inf values.
774
+
775
+ Args:
776
+ model (torch.nn.Module): model to check.
777
+ model_name (str): model name, used in log messages.
778
+
779
+ Returns:
780
+ bool: True if any NaN or Inf is found, otherwise False.
781
+ """
782
+ found_issue = False
783
+ for name, param in model.named_parameters():
784
+ if torch.isnan(param.data).any():
785
+ print(f"[!!!] Critical: NaN found in parameter '{name}' of model '{model_name}'!")
786
+ found_issue = True
787
+ if torch.isinf(param.data).any():
788
+ print(f"[!!!] Critical: Inf found in parameter '{name}' of model '{model_name}'!")
789
+ found_issue = True
790
+ return found_issue
791
+
792
+
793
+ def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
794
+ """
795
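+ Revised version: one-to-one matching that prefers the closest pairs. Note: the later method of the same name overrides this definition.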
796
+ """
797
+ pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
798
+ gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
799
+
800
+ pred_total = len(pred_array)
801
+ gt_total = len(gt_array)
802
+
803
+ if pred_total == 0 or gt_total == 0:
804
+ return {
805
+ 'recall': 0.0,
806
+ 'precision': 0.0,
807
+ 'f1': 0.0,
808
+ 'matches': 0,
809
+ 'pred_count': pred_total,
810
+ 'gt_count': gt_total
811
+ }
812
+
813
+ # Build the KD-tree on predicted points and find the nearest prediction for each GT point
814
+ tree = cKDTree(pred_array)
815
+ dists, pred_idxs = tree.query(gt_array, k=1)
816
+
817
+ # --- Core change ---
818
+
819
+ # 1. Collect (distance, GT index, prediction index) tuples
820
+ # so that all possible matches can be sorted by distance
821
+ possible_matches = []
822
+ for gt_idx, (dist, pred_idx) in enumerate(zip(dists, pred_idxs)):
823
+ if dist <= threshold:
824
+ possible_matches.append((dist, gt_idx, pred_idx))
825
+
826
+ # 2. Sort by ascending distance (greedy strategy)
827
+ possible_matches.sort(key=lambda x: x[0])
828
+
829
+ matches = 0
830
+ # Track used prediction and GT indices in sets to enforce one-to-one matching
831
+ used_pred_indices = set()
832
+ used_gt_indices = set() # GT indices cannot repeat with the current logic; tracked for rigor
833
+
834
+ # 3. Walk the sorted candidates and assign one-to-one
835
+ for dist, gt_idx, pred_idx in possible_matches:
836
+ # Accept only if neither the prediction nor the GT point has been used
837
+ if pred_idx not in used_pred_indices and gt_idx not in used_gt_indices:
838
+ matches += 1
839
+ used_pred_indices.add(pred_idx)
840
+ used_gt_indices.add(gt_idx)
841
+
842
+ # --- End of change ---
843
+
844
+ # matches is now the true-positive count; it never exceeds pred_total or gt_total
845
+ recall = matches / gt_total if gt_total > 0 else 0.0
846
+ precision = matches / pred_total if pred_total > 0 else 0.0
847
+
848
+ # F1 from the standard precision and recall definitions
849
+ if (precision + recall) == 0:
850
+ f1 = 0.0
851
+ else:
852
+ f1 = 2 * (precision * recall) / (precision + recall)
853
+
854
+ return {
855
+ 'recall': recall,
856
+ 'precision': precision,
857
+ 'f1': f1,
858
+ 'matches': matches,
859
+ 'pred_count': pred_total,
860
+ 'gt_count': gt_total
861
+ }
862
+
863
+ def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0):
864
+ """
865
+ A compromise scheme for vertex metrics.
866
+ It keeps the "nearest prediction for each GT point" logic,
867
+ but clamps the match count so that precision and F1 cannot exceed 1.0.
868
+ """
869
+ # pred_coords and gt_coords are assumed to be PyTorch tensors
870
+ pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0)
871
+ gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0)
872
+
873
+ pred_total = len(pred_array)
874
+ gt_total = len(gt_array)
875
+
876
+ if pred_total == 0 or gt_total == 0:
877
+ return {
878
+ 'recall': 0.0,
879
+ 'precision': 0.0,
880
+ 'f1': 0.0,
881
+ 'matches': 0,
882
+ 'pred_count': pred_total,
883
+ 'gt_count': gt_total
884
+ }
885
+
886
+ # Build the KD-tree on predicted points and find the nearest prediction for each GT point
887
+ tree = cKDTree(pred_array)
888
+ dists, _ = tree.query(gt_array, k=1) # the prediction indices are not needed here
889
+
890
+ # 1. Count matches from the GT side (true positives for recall)
891
+ # This is the same count as in the original version of this function.
892
+ # It is the number of GT points that were found.
893
+ matches_from_gt = np.sum(dists <= threshold)
894
+
895
+ # 2. Recall
896
+ # The denominator is the number of GT points, so this ratio is well defined.
897
+ recall = matches_from_gt / gt_total if gt_total > 0 else 0.0
898
+
899
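+ # 3. Precision (the core correction)
900
+ # The denominator is the number of predicted points.
901
+ # The numerator (true positives) cannot exceed the number of predicted points,
902
+ # so take the smaller of matches_from_gt and pred_total.
903
+ # This prevents precision > 1.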
904
+ tp_for_precision = min(matches_from_gt, pred_total)
905
+ precision = tp_for_precision / pred_total if pred_total > 0 else 0.0
906
+
907
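+ # 4. F1 as the harmonic mean of precision and recall
908
+ # The earlier formula `2 * matches / (pred + gt)` is the same quantity when one match count
909
+ # feeds both ratios; with the clamped precision above the two can differ, so the explicit form is used.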
910
+ if (precision + recall) == 0:
911
+ f1 = 0.0
912
+ else:
913
+ f1 = 2 * (precision * recall) / (precision + recall)
914
+
915
+ return {
916
+ 'recall': recall,
917
+ 'precision': precision,
918
+ 'f1': f1,
919
+ 'matches': matches_from_gt, # still report the raw match count for inspection
920
+ 'pred_count': pred_total,
921
+ 'gt_count': gt_total
922
+ }
923
+
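Review note: when a single match count feeds both precision and recall, the harmonic-mean F1 and the shortcut `2 * matches / (pred + gt)` give the same number. The sketch below checks this on plain integers. `precision_recall_f1` is a hypothetical helper, not part of this commit.

```python
import math


def precision_recall_f1(matches, pred_total, gt_total):
    """Standard precision, recall and F1 from a single true-positive count."""
    precision = matches / pred_total if pred_total else 0.0
    recall = matches / gt_total if gt_total else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


p, r, f1 = precision_recall_f1(matches=3, pred_total=4, gt_total=6)
# Shortcut form with the same match count in numerator and both totals.
shortcut = 2 * 3 / (4 + 6)
same_value = math.isclose(f1, shortcut)
```

The two forms only diverge when precision and recall are computed from different counts, as happens when the precision numerator is clamped separately.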
924
+ def _compute_chamfer_distance(self, p1: torch.Tensor, p2: torch.Tensor, one_sided: bool = False):
925
+ if len(p1) == 0 or len(p2) == 0:
926
+ return float('nan')
927
+
928
+ dist_p1_p2 = torch.min(torch.cdist(p1, p2), dim=1)[0].mean()
929
+
930
+ if one_sided:
931
+ return dist_p1_p2.item()
932
+ else:
933
+ dist_p2_p1 = torch.min(torch.cdist(p2, p1), dim=1)[0].mean()
934
+ return (dist_p1_p2 + dist_p2_p1).item() / 2
935
+
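Review note: `_compute_chamfer_distance` averages nearest-neighbor distances in each direction and then averages the two directions. The sketch below is a plain-Python version for small point lists, useful as a reference when checking the tensor code. `chamfer` is a hypothetical name used only for this illustration.

```python
import math


def chamfer(p1, p2, one_sided=False):
    """Mean nearest-neighbor distance from p1 to p2, optionally symmetrized."""
    if not p1 or not p2:
        return float("nan")
    d12 = sum(min(math.dist(a, b) for b in p2) for a in p1) / len(p1)
    if one_sided:
        return d12
    d21 = sum(min(math.dist(b, a) for a in p1) for b in p2) / len(p2)
    # Average of the two directed terms, as in the method above.
    return (d12 + d21) / 2


cloud_a = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
cloud_b = [(0.0, 0.0, 0.0)]
directed = chamfer(cloud_a, cloud_b, one_sided=True)
symmetric = chamfer(cloud_a, cloud_b)
```

The directed value is larger here because `cloud_a` has a point with no close counterpart in `cloud_b`, while every point of `cloud_b` is covered exactly.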
936
+ def visualize_latent_space_pca(self, sample_idx: int):
937
+ """
938
+ Encodes a sample, performs PCA on its latent features, and saves a
939
+ colored PLY file for visualization.
940
+
941
+ The position of each point in the PLY file corresponds to the spatial
942
+ location in the latent grid.
943
+
944
+ The color of each point represents the first three principal components
945
+ of its feature vector.
946
+ """
947
+ print(f"--- Starting Latent Space PCA Visualization for Sample {sample_idx} ---")
948
+ self.vae.eval()
949
+
950
+ try:
951
+ # 1. Get the latent representation for the sample
952
+ latent = self._get_latent_for_sample(sample_idx)
953
+ except ValueError as e:
954
+ print(f"Error: {e}")
955
+ return
956
+
957
+ latent_coords = latent.coords.detach().cpu().numpy()
958
+ latent_feats = latent.feats.detach().cpu().numpy()
959
+
960
+ if latent_feats.shape[0] < 3:
961
+ print(f"Warning: Not enough latent points ({latent_feats.shape[0]}) to perform PCA. Skipping.")
962
+ return
963
+
964
+ print(f"--> Performing PCA on {latent_feats.shape[0]} latent vectors of dimension {latent_feats.shape[1]}...")
965
+
966
+ # 2. Perform PCA to reduce feature dimensions to 3
967
+ pca = PCA(n_components=3)
968
+ pca_features = pca.fit_transform(latent_feats)
969
+
970
+ print(f" Explained variance ratio by 3 components: {pca.explained_variance_ratio_}")
971
+ print(f" Total explained variance: {np.sum(pca.explained_variance_ratio_):.4f}")
972
+
973
+ # 3. Normalize the PCA components to be used as RGB colors [0, 255]
974
+ # We normalize each component independently to maximize color contrast
975
+ normalized_colors = np.zeros_like(pca_features)
976
+ for i in range(3):
977
+ min_val = pca_features[:, i].min()
978
+ max_val = pca_features[:, i].max()
979
+ if max_val - min_val > 1e-6:
980
+ normalized_colors[:, i] = (pca_features[:, i] - min_val) / (max_val - min_val)
981
+ else:
982
+ normalized_colors[:, i] = 0.5 # Handle case of constant value
983
+
984
+ colors_uint8 = (normalized_colors * 255).astype(np.uint8)
985
+
986
+ # 4. Prepare spatial coordinates for the point cloud
987
+ # latent_coords is (batch_idx, x, y, z), we want the xyz part
988
+ spatial_coords = latent_coords[:, 1:]
989
+
990
+ # 5. Create and save the colored PLY file
991
+ try:
992
+ # Create a Trimesh PointCloud object
993
+ point_cloud = trimesh.points.PointCloud(vertices=spatial_coords, colors=colors_uint8)
994
+
995
+ # Define the output filename
996
+ filename = f"sample_{sample_idx}_latent_pca.ply"
997
+ ply_path = os.path.join(self.output_voxel_dir, filename)
998
+
999
+ # Export the file
1000
+ point_cloud.export(ply_path)
1001
+ print(f"--> Successfully saved PCA visualization to: {ply_path}")
1002
+
1003
+ except Exception as e:
1004
+ print(f"Error during Trimesh export: {e}")
1005
+ print("Please ensure 'trimesh' is installed correctly.")
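Reviewer sketch (not part of the commit): the PCA-to-RGB step in `visualize_latent_space_pca` can be checked in isolation. The version below is a minimal sketch that assumes NumPy only and swaps scikit-learn's `PCA` for a plain SVD projection; the helper name `pca_to_rgb` is hypothetical. Component signs may differ from scikit-learn's output, which only flips colors along a channel.

```python
import numpy as np


def pca_to_rgb(feats: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Project (N, D) features onto their top 3 principal axes and map them to uint8 RGB."""
    if feats.ndim != 2 or feats.shape[0] < 3 or feats.shape[1] < 3:
        raise ValueError("need at least 3 points with at least 3 feature dimensions")
    centered = feats - feats.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    proj = centered @ vt[:3].T
    # Constant channels fall back to mid-gray, as in the diff.
    colors = np.full_like(proj, 0.5)
    for i in range(3):
        lo, hi = proj[:, i].min(), proj[:, i].max()
        if hi - lo > eps:
            # Normalize each component independently to maximize contrast.
            colors[:, i] = (proj[:, i] - lo) / (hi - lo)
    return (colors * 255).astype(np.uint8)
```

Each channel is stretched to its own full range, so colors show relative position along a component and are not comparable across samples.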
1006
+
1007
+ def _get_latent_for_sample(self, sample_idx: int) -> SparseTensor:
1008
+ """
1009
+ Encodes a single sample and returns its latent representation.
1010
+ """
1011
+ print(f"--> Encoding sample {sample_idx} to get its latent vector...")
1012
+ # Get data for the specified sample
1013
+ batch_data = self.dataset[sample_idx]
1014
+ if batch_data is None:
1015
+ raise ValueError(f"Sample at index {sample_idx} could not be loaded.")
1016
+
1017
+ # Use the collate function to form a batch
1018
+ batch_data = collate_fn_pointnet([batch_data])
1019
+
1020
+ with torch.no_grad():
1021
+ active_coords = batch_data['active_voxels_128'].to(self.device)
1022
+ point_cloud = batch_data['point_cloud_128'].to(self.device)
1023
+
1024
+ active_voxel_feats = self.voxel_encoder(
1025
+ p=point_cloud,
1026
+ sparse_coords=active_coords,
1027
+ res=128,
1028
+ bbox_size=(-0.5, 0.5),
1029
+ )
1030
+
1031
+ sparse_input = SparseTensor(
1032
+ feats=active_voxel_feats,
1033
+ coords=active_coords.int()
1034
+ )
1035
+
1036
+ # 2. Encode to get the latent representation
1037
+ latent_128, posterior = self.vae.encode(sparse_input, sample_posterior=True,)
1038
+ print(f" Latent for sample {sample_idx} obtained. Shape: {latent_128.feats.shape}")
1039
+ return latent_128
1040
+
1041
+
1042
+
1043
+ def evaluate(self, num_samples=None, visualize=False, chamfer_threshold=0.9, threshold=1.):
1044
+ total_samples = len(self.dataset)
1045
+ eval_samples = min(num_samples or total_samples, total_samples)
1046
+ sample_indices = random.sample(range(total_samples), eval_samples) if num_samples else range(total_samples)
1047
+ # sample_indices = range(eval_samples)
1048
+
1049
+ eval_dataset = Subset(self.dataset, sample_indices)
1050
+ eval_loader = DataLoader(
1051
+ eval_dataset,
1052
+ batch_size=1,
1053
+ shuffle=False,
1054
+ collate_fn=partial(collate_fn_pointnet),
1055
+ num_workers=self.config['training']['num_workers'],
1056
+ pin_memory=True,
1057
+ )
1058
+
1059
+ per_sample_metrics = {
1060
+ 'vertex': {res: [] for res in [128, 256, 512]},
1061
+ 'edge': {res: [] for res in [128, 256, 512]},
1062
+ 'sample_names': []
1063
+ }
1064
+ avg_metrics = {
1065
+ 'vertex': {res: defaultdict(list) for res in [128, 256, 512]},
1066
+ 'edge': {res: defaultdict(list) for res in [128, 256, 512]},
1067
+ }
1068
+
1069
+ self.vae.eval()
1070
+
1071
+ for batch_idx, batch_data in enumerate(tqdm(eval_loader, desc="Evaluating")):
1072
+ if batch_data is None:
1073
+ continue
1074
+ sample_idx = sample_indices[batch_idx]
1075
+ sample_name = f'sample_{sample_idx}'
1076
+ per_sample_metrics['sample_names'].append(sample_name)
1077
+
1078
+ batch_save_path = f"/home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope/143500_sample_active_vis_42seed_trellis/gt_data_batch_{batch_idx}.pt"
1079
+ if not os.path.exists(batch_save_path):
1080
+ print(f"Warning: Saved batch file not found: {batch_save_path}")
1081
+ continue
1082
+ batch_data = torch.load(batch_save_path, map_location=self.device)
1083
+
1084
+ with torch.no_grad():
1085
+ # 1. Get input data
1086
+ combined_voxels_512 = batch_data['combined_voxels_512'].to(self.device)
1087
+ combined_voxel_labels_512 = batch_data['combined_voxel_labels_512'].to(self.device)
1088
+ gt_combined_endpoints_512 = batch_data['gt_combined_endpoints_512'].to(self.device)
1089
+ gt_combined_errors_512 = batch_data['gt_combined_errors_512'].to(self.device)
1090
+
1091
+ edge_mask = (combined_voxel_labels_512 == 1)
1092
+
1093
+ gt_edge_endpoints_512 = gt_combined_endpoints_512[edge_mask].to(self.device)
1094
+
1095
+ gt_edge_voxels_512 = combined_voxels_512[edge_mask].to(self.device)
1096
+
1097
+ p1 = gt_edge_endpoints_512[:, 1:4].float()
1098
+ p2 = gt_edge_endpoints_512[:, 4:7].float()
1099
+
1100
+ mask = ( (p1[:,0] < p2[:,0]) |
1101
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
1102
+ ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )
1103
+
1104
+ pA = torch.where(mask[:, None], p1, p2) # smaller one
1105
+ pB = torch.where(mask[:, None], p2, p1) # larger one
1106
+
1107
+ d = pB - pA
1108
+ dir_gt = F.normalize(d, dim=-1, eps=1e-6)
1109
+
1110
+ gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'].to(self.device).int()
1111
+
1112
+ vtx_128 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=128)
1113
+ vtx_256 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=256)
1114
+
1115
+ edge_128 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=128)
1116
+ edge_256 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=256)
1117
+ edge_512 = combined_voxels_512
1118
+
1119
+
1120
+ gt_edge_voxels_list = [
1121
+ edge_128,
1122
+ edge_256,
1123
+ edge_512,
1124
+ ]
1125
+
1126
+ active_coords = batch_data['active_voxels_128'].to(self.device)
1127
+ point_cloud = batch_data['point_cloud_128'].to(self.device)
1128
+
1129
+
1130
+ active_voxel_feats = self.voxel_encoder(
1131
+ p=point_cloud,
1132
+ sparse_coords=active_coords,
1133
+ res=128,
1134
+ bbox_size=(-0.5, 0.5),
1135
+ )
1136
+
1137
+ sparse_input = SparseTensor(
1138
+ feats=active_voxel_feats,
1139
+ coords=active_coords.int()
1140
+ )
1141
+
1142
+ latent_128, posterior = self.vae.encode(sparse_input)
1143
+
1144
+
1145
+ load_path = f'/home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope/143500_sample_active_vis_42seed_trellis/sample_latent_{batch_idx}.pt'
1146
+ latent_128 = torch.load(load_path, map_location=self.device)
1147
+
1148
+ print('latent_128.feats.mean()', latent_128.feats.mean(), 'latent_128.feats.std()', latent_128.feats.std())
1149
+ print('posterior.mean', posterior.mean.mean(), 'posterior.std', posterior.std.mean(), 'posterior.var', posterior.var.mean())
1150
+ print('latent_128.coords.shape', latent_128.coords.shape)
1151
+
1152
+
1153
+ latent_128 = SparseTensor(
1154
+ coords=latent_128.coords,
1155
+ feats=latent_128.feats + 0. * torch.randn_like(latent_128.feats),
1156
+ )
1157
+
1158
+ self.output_voxel_dir = os.path.dirname(load_path)
1159
+ self.output_obj_dir = os.path.dirname(load_path)
1160
+
1161
+ # 7. Decoding with separate vertex and edge processing
1162
+ decoded_results = self.vae.decode(
1163
+ latent_128,
1164
+ gt_vertex_voxels_list=[],
1165
+ gt_edge_voxels_list=[],
1166
+ training=False,
1167
+
1168
+ inference_threshold=0.5,
1169
+ vis_last_layer=False,
1170
+ )
1171
+
1172
+ error = 0 # decoded_results[-1]['edge']['predicted_offset_feats']
1173
+
1174
+ if self.config['model'].get('pred_direction', False):
1175
+ pred_dir = decoded_results[-1]['edge']['predicted_direction_feats']
1176
+ zero_mask = (pred_dir == 0).all(dim=1) # [N]; True means the row is all zeros
1177
+ num_zeros = zero_mask.sum().item()
1178
+ print("Number of zero vectors:", num_zeros)
1179
+
1180
+ pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
1181
+ print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
1182
+ print('pred_dir.shape', pred_dir.shape)
1183
+ if pred_edge_coords_3d.shape[-1] == 4:
1184
+ pred_edge_coords_3d = pred_edge_coords_3d[:, 1:]
1185
+
1186
+ save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction.ply")
1187
+ visualize_colored_points_ply(pred_edge_coords_3d, pred_dir, save_pth)
1188
+
1189
+ save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction_gt.ply")
1190
+ visualize_colored_points_ply((gt_edge_voxels_512[:, 1:]), dir_gt, save_pth)
1191
+
1192
+
1193
+ pred_vtx_coords_3d = decoded_results[-1]['vertex']['coords']
1194
+ pred_edge_coords_3d = decoded_results[-1]['edge']['coords']
1195
+
1196
+
1197
+ gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'][:, 1:].to(self.device)
1198
+ gt_edge_voxels_512 = batch_data['gt_edge_voxels_512'][:, 1:].to(self.device)
1199
+
1200
+
1201
+ # Calculate metrics and save results
1202
+ matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_vtx_coords_3d, gt_vertex_voxels_512, threshold=threshold,)
1203
+ print(f"\n----- Resolution {512} vtx -----")
1204
+ print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
1205
+ print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
1206
+
1207
+ self._save_voxel_ply(pred_vtx_coords_3d / 512., torch.zeros(len(pred_vtx_coords_3d)), f"{sample_name}_pred_vtx")
1208
+ self._save_voxel_ply((pred_edge_coords_3d) / 512, torch.zeros(len(pred_edge_coords_3d)), f"{sample_name}_pred_edge")
1209
+
1210
+ self._save_voxel_ply(gt_vertex_voxels_512 / 512, torch.zeros(len(gt_vertex_voxels_512)), f"{sample_name}_gt_vertex")
1211
+ self._save_voxel_ply((combined_voxels_512[:, 1:]) / 512., torch.zeros(len(gt_combined_errors_512)), f"{sample_name}_gt_edge")
1212
+
1213
+
1214
+ # Calculate edge-voxel matching metrics (reuses the vertex matching routine)
1215
+ matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_edge_coords_3d, combined_voxels_512[:, 1:], threshold=threshold,)
1216
+ print(f"\n----- Resolution {512} edge -----")
1217
+ print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape)
1218
+ print('gt_edge_voxels_512.shape', gt_edge_voxels_512.shape)
1219
+
1220
+ print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}")
1221
+ print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}")
1222
+
1223
+ pred_vertex_coords_np = np.round(pred_vtx_coords_3d.cpu().numpy()).astype(int)
1224
+ pred_edges = []
1225
+ gt_vertex_coords_np = np.round(gt_vertex_voxels_512.cpu().numpy()).astype(int)
1226
+ if visualize:
1227
+ if pred_vtx_coords_3d.shape[-1] == 4:
1228
+ pred_vtx_coords_float = pred_vtx_coords_3d[:, 1:].float()
1229
+ else:
1230
+ pred_vtx_coords_float = pred_vtx_coords_3d.float()
1231
+
1232
+ pred_vtx_feats = decoded_results[-1]['vertex']['feats']
1233
+
1234
+ # ==========================================
1235
+ # Link Prediction & Mesh Generation
1236
+ # ==========================================
1237
+ print("Predicting connectivity...")
1238
+
1239
+ # 1. Predict edges
1240
+ # Note on k_neighbors: for a single object, 64 is enough.
1241
+ # If the points are very sparse, a larger value may be needed.
1242
+ pred_edges = predict_mesh_connectivity(
1243
+ connection_head=self.connection_head, # adjust to wherever the connection head is defined
1244
+ vtx_feats=pred_vtx_feats,
1245
+ vtx_coords=pred_vtx_coords_float,
1246
+ batch_size=4096,
1247
+ threshold=0.5,
1248
+ k_neighbors=None,
1249
+ device=self.device
1250
+ )
1251
+ print(f"Predicted {len(pred_edges)} edges.")
1252
+
1253
+ # 2. Build triangles
1254
+ num_verts = pred_vtx_coords_float.shape[0]
1255
+ pred_faces = build_triangles_from_edges(pred_edges, num_verts)
1256
+ print(f"Constructed {len(pred_faces)} triangles.")
1257
+
1258
+ # 3. Save the OBJ
1259
+ import trimesh
1260
+
1261
+ # Normalize/restore coordinates (adjust as needed; voxel coordinates in the 0-512 range are assumed here)
1262
+ # To save normalized coordinates:
1263
+ mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0
1264
+
1265
+ # If there is an error offset, remember to add it!
1266
+ # The earlier code does not seem to add an offset to vertices, only to edges
1267
+ # If vertices also have an offset (e.g. dual contouring), add it here
1268
+
1269
+ # Center at the origin (optional)
1270
+ mesh_verts = mesh_verts - 0.5
1271
+
1272
+ mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
1273
+
1274
+ trimesh.repair.fix_normals(mesh)
1275
+
1276
+ output_obj_path = os.path.join(self.output_voxel_dir, f"{sample_name}_recon.obj")
1277
+ mesh.export(output_obj_path)
1278
+ print(f"Saved mesh to {output_obj_path}")
1279
+
1280
+ # Save edge lines (for debugging)
1281
+ # Triangles are sometimes hard to form, so inspecting the edges alone is also useful
1282
+ edges_path = os.path.join(self.output_voxel_dir, f"{sample_name}_edges.ply")
1283
+ # self._visualize_vertices(pred_edge_coords_np, gt_edge_coords_np, f"{sample_name}_edge_comparison")
1284
+
1285
+
1286
+ # Process results at different resolutions
1287
+ for i, res in enumerate([128, 256, 512]):
1288
+ if i >= len(decoded_results):
1289
+ continue
1290
+
1291
+ gt_key = f'gt_vertex_voxels_{res}'
1292
+ if gt_key not in batch_data:
1293
+ continue
1294
+ if i == 0:
1295
+ pred_coords_res = decoded_results[i]['vtx_sp'].coords[:, 1:].float()
1296
+ gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
1297
+ else:
1298
+ pred_coords_res = decoded_results[i]['vertex']['coords'].float()
1299
+ gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device)
1300
+
1301
+
1302
+ v_metrics = self._compute_vertex_metrics(pred_coords_res, gt_coords_res, threshold=threshold)
1303
+
1304
+ per_sample_metrics['vertex'][res].append({
1305
+ 'recall': v_metrics['recall'],
1306
+ 'precision': v_metrics['precision'],
1307
+ 'f1': v_metrics['f1'],
1308
+ 'num_pred': len(pred_coords_res),
1309
+ 'num_gt': len(gt_coords_res)
1310
+ })
1311
+
1312
+ avg_metrics['vertex'][res]['recall'].append(v_metrics['recall'])
1313
+ avg_metrics['vertex'][res]['precision'].append(v_metrics['precision'])
1314
+ avg_metrics['vertex'][res]['f1'].append(v_metrics['f1'])
1315
+
1316
+ gt_edge_key = f'gt_edge_voxels_{res}'
1317
+ if gt_edge_key not in batch_data:
1318
+ continue
1319
+
1320
+ if i == 0:
1321
+ pred_edge_coords_res = decoded_results[i]['edge_sp'].coords[:, 1:].float()
1322
+ # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1323
+ idx = i
1324
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1325
+ elif i == 1:
1326
+ idx = i
1327
+ #################################
1328
+ # pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() - error / 2. + 0.5
1329
+ # # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1330
+ # gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) - gt_combined_errors_512[:, 1:].to(self.device) + 0.5
1331
+
1332
+
1333
+ pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
1334
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1335
+
1336
+ # self._save_voxel_ply(gt_edge_voxels_list[idx][:, 1:].float().to(self.device) / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res_wooffset")
1337
+ # self._save_voxel_ply(decoded_results[i]['edge']['coords'].float() / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res_wooffset")
1338
+
1339
+ else:
1340
+ idx = i
1341
+ pred_edge_coords_res = decoded_results[i]['edge']['coords'].float()
1342
+ # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device)
1343
+ gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device)
1344
+
1345
+ # self._save_voxel_ply(gt_edge_coords_res / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res")
1346
+ # self._save_voxel_ply(pred_edge_coords_res / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res")
1347
+
1348
+ e_metrics = self._compute_vertex_metrics(pred_edge_coords_res, gt_edge_coords_res, threshold=threshold)
1349
+
1350
+ per_sample_metrics['edge'][res].append({
1351
+ 'recall': e_metrics['recall'],
1352
+ 'precision': e_metrics['precision'],
1353
+ 'f1': e_metrics['f1'],
1354
+ 'num_pred': len(pred_edge_coords_res),
1355
+ 'num_gt': len(gt_edge_coords_res)
1356
+ })
1357
+
1358
+ avg_metrics['edge'][res]['recall'].append(e_metrics['recall'])
1359
+ avg_metrics['edge'][res]['precision'].append(e_metrics['precision'])
1360
+ avg_metrics['edge'][res]['f1'].append(e_metrics['f1'])
1361
+
1362
+ avg_metrics_processed = {}
1363
+ for category, res_dict in avg_metrics.items():
1364
+ avg_metrics_processed[category] = {}
1365
+ for resolution, metric_dict in res_dict.items():
1366
+ avg_metrics_processed[category][resolution] = {
1367
+ metric_name: np.mean(values) if values else float('nan')
1368
+ for metric_name, values in metric_dict.items()
1369
+ }
1370
+
1371
+ result_data = {
1372
+ 'config': self.config,
1373
+ 'checkpoint': self.ckpt_path,
1374
+ 'dataset': self.dataset_path,
1375
+ 'num_samples': eval_samples,
1376
+ 'per_sample_metrics': per_sample_metrics,
1377
+ 'avg_metrics': avg_metrics_processed
1378
+ }
1379
+
1380
+ results_file_path = os.path.join(self.result_dir, f"evaluation_results_epoch{self.epoch}.yaml")
1381
+ with open(results_file_path, 'w') as f:
1382
+ yaml.dump(result_data, f, default_flow_style=False)
1383
+
1384
+ return result_data
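Reviewer sketch (not part of the commit): `evaluate` orders each edge's endpoints lexicographically on (x, y, z) before normalizing `pB - pA`, so the ground-truth direction does not depend on which endpoint was stored first. The NumPy version below is a minimal sketch of that idea; the function name `canonical_directions` is hypothetical, and the `eps` clamp is meant to mirror how `F.normalize` guards against zero-length vectors.

```python
import numpy as np


def canonical_directions(p1, p2, eps: float = 1e-6) -> np.ndarray:
    """Return unit directions from the lexicographically smaller endpoint to the larger one."""
    p1 = np.asarray(p1, dtype=np.float64)
    p2 = np.asarray(p2, dtype=np.float64)
    # p1 is "first" when it is lexicographically <= p2 on (x, y, z).
    first = (
        (p1[:, 0] < p2[:, 0])
        | ((p1[:, 0] == p2[:, 0]) & (p1[:, 1] < p2[:, 1]))
        | ((p1[:, 0] == p2[:, 0]) & (p1[:, 1] == p2[:, 1]) & (p1[:, 2] <= p2[:, 2]))
    )
    pa = np.where(first[:, None], p1, p2)  # smaller endpoint
    pb = np.where(first[:, None], p2, p1)  # larger endpoint
    d = pb - pa
    norm = np.linalg.norm(d, axis=-1, keepdims=True)
    # Degenerate (zero-length) edges map to the zero vector instead of NaN.
    return d / np.maximum(norm, eps)
```

Swapping the two endpoint arrays leaves the result unchanged, which is the property the direction target relies on.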
1385
+
1386
+ def _generate_line_voxels(
1387
+ self,
1388
+ p1: torch.Tensor,
1389
+ p2: torch.Tensor
1390
+ ) -> Tuple[
1391
+ List[Tuple[int, int, int]],
1392
+ List[Tuple[torch.Tensor, torch.Tensor]],
1393
+ List[np.ndarray]
1394
+ ]:
1395
+ """
1396
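+ Voxelize the segment p1-p2 with 3D Bresenham. Returns the voxel coordinates, the (p1, p2) endpoint pair for each voxel, and the error vector from each voxel center to its closest point on the segment.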
1397
+ """
1398
+ p1_np = p1 #.cpu().numpy()
1399
+ p2_np = p2 #.cpu().numpy()
1400
+ voxel_dict = OrderedDict()
1401
+
1402
+ # Use proper 3D line voxelization algorithm
1403
+ def bresenham_3d(p1, p2):
1404
+ """3D Bresenham's line algorithm"""
1405
+ x1, y1, z1 = np.round(p1).astype(int)
1406
+ x2, y2, z2 = np.round(p2).astype(int)
1407
+
1408
+ points = []
1409
+ dx = abs(x2 - x1)
1410
+ dy = abs(y2 - y1)
1411
+ dz = abs(z2 - z1)
1412
+
1413
+ xs = 1 if x2 > x1 else -1
1414
+ ys = 1 if y2 > y1 else -1
1415
+ zs = 1 if z2 > z1 else -1
1416
+
1417
+ # Driving axis is X
1418
+ if dx >= dy and dx >= dz:
1419
+ err_1 = 2 * dy - dx
1420
+ err_2 = 2 * dz - dx
1421
+ for i in range(dx + 1):
1422
+ points.append((x1, y1, z1))
1423
+ if err_1 > 0:
1424
+ y1 += ys
1425
+ err_1 -= 2 * dx
1426
+ if err_2 > 0:
1427
+ z1 += zs
1428
+ err_2 -= 2 * dx
1429
+ err_1 += 2 * dy
1430
+ err_2 += 2 * dz
1431
+ x1 += xs
1432
+
1433
+ # Driving axis is Y
1434
+ elif dy >= dx and dy >= dz:
1435
+ err_1 = 2 * dx - dy
1436
+ err_2 = 2 * dz - dy
1437
+ for i in range(dy + 1):
1438
+ points.append((x1, y1, z1))
1439
+ if err_1 > 0:
1440
+ x1 += xs
1441
+ err_1 -= 2 * dy
1442
+ if err_2 > 0:
1443
+ z1 += zs
1444
+ err_2 -= 2 * dy
1445
+ err_1 += 2 * dx
1446
+ err_2 += 2 * dz
1447
+ y1 += ys
1448
+
1449
+ # Driving axis is Z
1450
+ else:
1451
+ err_1 = 2 * dx - dz
1452
+ err_2 = 2 * dy - dz
1453
+ for i in range(dz + 1):
1454
+ points.append((x1, y1, z1))
1455
+ if err_1 > 0:
1456
+ x1 += xs
1457
+ err_1 -= 2 * dz
1458
+ if err_2 > 0:
1459
+ y1 += ys
1460
+ err_2 -= 2 * dz
1461
+ err_1 += 2 * dx
1462
+ err_2 += 2 * dy
1463
+ z1 += zs
1464
+
1465
+ return points
1466
+
1467
+ # Get all voxels using Bresenham algorithm
1468
+ voxel_coords = bresenham_3d(p1_np, p2_np)
1469
+
1470
+ # Add all voxels to dictionary
1471
+ for coord in voxel_coords:
1472
+ voxel_dict[tuple(coord)] = (p1, p2)
1473
+
1474
+ voxel_coords = list(voxel_dict.keys())
1475
+ endpoint_pairs = list(voxel_dict.values())
1476
+
1477
+ # --- compute error vectors ---
1478
+ error_vectors = []
1479
+ diff = p2_np - p1_np
1480
+ d_norm_sq = np.dot(diff, diff)
1481
+
1482
+ for v in voxel_coords:
1483
+ v_center = np.array(v, dtype=float) + 0.5
1484
+ if d_norm_sq == 0: # degenerate line
1485
+ closest = p1_np
1486
+ else:
1487
+ t = np.dot(v_center - p1_np, diff) / d_norm_sq
1488
+ t = np.clip(t, 0.0, 1.0)
1489
+ closest = p1_np + t * diff
1490
+ error_vectors.append(v_center - closest)
1491
+
1492
+ return voxel_coords, endpoint_pairs, error_vectors
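Reviewer sketch (not part of the commit): the error vectors returned by `_generate_line_voxels` are the offsets from each voxel center to its closest point on the segment, with the projection parameter clamped to [0, 1]. The stand-alone helper below is a minimal sketch of that projection for one voxel; the name `voxel_error_vector` is hypothetical and NumPy inputs are assumed.

```python
import numpy as np


def voxel_error_vector(voxel, p1, p2) -> np.ndarray:
    """Offset from the center of an integer voxel to the closest point on segment p1-p2."""
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    center = np.asarray(voxel, dtype=float) + 0.5
    diff = p2 - p1
    denom = float(np.dot(diff, diff))
    if denom == 0.0:
        # Degenerate segment: both endpoints coincide.
        closest = p1
    else:
        # Project the center onto the line, then clamp to stay on the segment.
        t = np.clip(np.dot(center - p1, diff) / denom, 0.0, 1.0)
        closest = p1 + t * diff
    return center - closest
```

For example, `voxel_error_vector((2, 0, 0), (0, 0, 0), (4, 0, 0))` gives `[0.0, 0.5, 0.5]`: the voxel center sits half a voxel off the segment in y and z.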
1493
+
1494
+
1495
+ # Usage example
+ def set_seed(seed: int):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
1505
+
1506
+ def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
+     set_seed(42)
+     tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
+     result_data = tester.evaluate(num_samples=NUM_SAMPLES, visualize=VISUALIZE, chamfer_threshold=CHAMFER_EDGE_THRESHOLD, threshold=THRESHOLD)
+
+     # Build the output file name
+     epoch_str = os.path.basename(ckpt_path).split('_')[1].split('.')[0]
+     dataset_name = os.path.basename(os.path.normpath(dataset_path))
+
+     # Save the short report (TXT)
+     summary_path = os.path.join(eval_dir, f"epoch{epoch_str}_{dataset_name}_summary_threshold{THRESHOLD}_one2one.txt")
+     with open(summary_path, 'w') as f:
+         # Header
+         f.write(f"Checkpoint: {os.path.basename(ckpt_path)}\n")
+         f.write(f"Dataset: {dataset_name}\n")
+         f.write(f"Evaluation Samples: {result_data['num_samples']}\n\n")
+
+         # Average metrics
+         f.write("=== Average Metrics ===\n")
+         for category, data in result_data['avg_metrics'].items():
+             if isinstance(data, dict): # multi-resolution case
+                 f.write(f"\n{category.upper()}:\n")
+                 for res, metrics in data.items():
+                     f.write(f" Resolution {res}:\n")
+                     for k, v in metrics.items():
+                         # Make sure the value is numeric before formatting it
+                         if isinstance(v, (int, float)):
+                             f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
+                         else:
+                             f.write(f" {str(k).ljust(15)}: {str(v)}\n")
+             else: # single-resolution case
+                 f.write(f"\n{category.upper()}:\n")
+                 for k, v in data.items():
+                     if isinstance(v, (int, float)):
+                         f.write(f" {str(k).ljust(15)}: {v:.4f}\n")
+                     else:
+                         f.write(f" {str(k).ljust(15)}: {str(v)}\n")
+
+         # Detailed per-sample statistics
+         f.write("\n\n=== Detailed Per-Sample Metrics ===\n")
+         for name, vertex_metrics, edge_metrics in zip(
+             result_data['per_sample_metrics']['sample_names'],
+             zip(*[result_data['per_sample_metrics']['vertex'][res] for res in [128, 256, 512]]),
+             zip(*[result_data['per_sample_metrics']['edge'][res] for res in [128, 256, 512]])
+         ):
+             # Sample title
+             f.write(f"\n◆ Sample: {name}\n")
+
+             # Vertex metrics
+             f.write(f"[Vertex Prediction]\n")
+             f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
+             for res, metrics in zip([128, 256, 512], vertex_metrics):
+                 f.write(f" {str(res).ljust(10)} "
+                         f"{metrics['recall']:.4f} "
+                         f"{metrics['precision']:.4f} "
+                         f"{metrics['f1']:.4f} "
+                         f"{metrics['num_pred']}/{metrics['num_gt']}\n")
+
+             # Edge metrics
+             f.write(f"[Edge Prediction]\n")
+             f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n")
+             for res, metrics in zip([128, 256, 512], edge_metrics):
+                 f.write(f" {str(res).ljust(10)} "
+                         f"{metrics['recall']:.4f} "
+                         f"{metrics['precision']:.4f} "
+                         f"{metrics['f1']:.4f} "
+                         f"{metrics['num_pred']}/{metrics['num_gt']}\n")
+
+             f.write("-"*60 + "\n")
+
+     print(f"Saved summary to: {summary_path}")
+     return result_data
1578
+
1579
+
1580
+ if __name__ == '__main__':
1581
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
1582
+ evaluate_all_checkpoints = False # set to True to enable epoch-range filtering
1583
+ EPOCH_START = 1
1584
+ EPOCH_END = 12
1585
+ CHAMFER_EDGE_THRESHOLD=0.5
1586
+ NUM_SAMPLES=50
1587
+ VISUALIZE=True
1588
+ THRESHOLD=1.5
1589
+ VISUAL_FIELD=False
1590
+
1591
+ ckpt_path = '/home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch1_batch12000_loss0.1140.pt'
1592
+ dataset_path = '/home/tiger/yy/src/trellis_clean_mesh/mesh_data'
1593
+
1594
+ if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
1595
+ RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
1596
+ else:
1597
+ RENDERS_DIR = ''
1598
+
1599
+
1600
+ ckpt_dir = os.path.dirname(ckpt_path)
1601
+ eval_dir = os.path.join(ckpt_dir, "evaluate")
1602
+ os.makedirs(eval_dir, exist_ok=True)
1603
+
1604
+ if False:
1605
+ for i in range(NUM_SAMPLES):
1606
+ print("--- Starting Latent Space PCA Visualization ---")
1607
+ tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path)
1608
+ tester.visualize_latent_space_pca(sample_idx=i)
1609
+ print("--- PCA Visualization Finished ---")
1610
+
1611
+ if not evaluate_all_checkpoints:
1612
+ evaluate_checkpoint(ckpt_path, dataset_path, eval_dir)
1613
+ else:
1614
+ pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.endswith('.pt')])
1615
+
1616
+ filtered_pt_files = []
1617
+ for f in pt_files:
1618
+ try:
1619
+ parts = f.split('_')
1620
+ epoch_str = parts[1].replace('epoch', '')
1621
+ epoch = int(epoch_str)
1622
+ if EPOCH_START <= epoch <= EPOCH_END:
1623
+ filtered_pt_files.append(f)
1624
+ except Exception as e:
1625
+ print(f"Warning: Could not parse epoch from {f}: {e}")
1626
+ continue
1627
+
1628
+ for pt_file in filtered_pt_files:
1629
+ full_ckpt_path = os.path.join(ckpt_dir, pt_file)
1630
+ evaluate_checkpoint(full_ckpt_path, dataset_path, eval_dir)
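Reviewer sketch (not part of the commit): the `__main__` block extracts the epoch by splitting the checkpoint file name on `_` and taking the second field, which only works for names shaped like `checkpoint_epoch1_batch12000_loss0.1140.pt`. The sketch below shows a regex-based alternative that tolerates other name layouts; `parse_epoch` and `filter_checkpoints` are hypothetical helpers, not functions from this repository.

```python
import re
from typing import Iterable, List, Optional

_EPOCH_RE = re.compile(r"epoch(\d+)")


def parse_epoch(filename: str) -> Optional[int]:
    """Return the integer after the first 'epoch' in the name, or None if absent."""
    match = _EPOCH_RE.search(filename)
    return int(match.group(1)) if match else None


def filter_checkpoints(names: Iterable[str], start: int, end: int) -> List[str]:
    """Keep '.pt' files whose epoch lies in the inclusive range [start, end], sorted by name."""
    kept = []
    for name in sorted(names):
        if not name.endswith(".pt"):
            continue
        epoch = parse_epoch(name)
        if epoch is not None and start <= epoch <= end:
            kept.append(name)
    return kept
```

Files without a parsable epoch are skipped silently here; the original loop prints a warning instead, which may be preferable when debugging a checkpoint directory.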
test_slat_vae_128to512_pointnet_vae_head.py CHANGED
@@ -744,7 +744,7 @@ class Tester:
744
  root_dir=self.dataset_path,
745
  base_resolution=self.config['dataset']['base_resolution'],
746
  min_resolution=self.config['dataset']['min_resolution'],
747
- cache_dir='/gemini/user/private/zhaotianhao/dataset_cache/test_15c_dora',
748
  # cache_dir=self.config['dataset']['cache_dir'],
749
  renders_dir=self.config['dataset']['renders_dir'],
750
 
@@ -1262,15 +1262,12 @@ class Tester:
1262
  # To save normalized coordinates:
1263
  mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0
1264
 
1265
- # If there is an error offset, remember to add it!
1266
- # The earlier code does not seem to add an offset to vertices, only to edges
1267
- # If vertices also have an offset (e.g. dual contouring), add it here
1268
-
1269
  # Center at the origin (optional)
1270
  mesh_verts = mesh_verts - 0.5
1271
 
1272
  mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
1273
 
 
1274
  # Remove isolated vertices (optional)
1275
  # mesh.remove_unreferenced_vertices()
1276
 
@@ -1581,22 +1578,23 @@ def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir):
1581
  if __name__ == '__main__':
1582
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
1583
  evaluate_all_checkpoints = True # set to True to enable epoch-range filtering
1584
- EPOCH_START = 0
1585
  EPOCH_END = 12
1586
  CHAMFER_EDGE_THRESHOLD=0.5
1587
- NUM_SAMPLES=20
1588
  VISUALIZE=True
1589
  THRESHOLD=1.5
1590
  VISUAL_FIELD=False
1591
 
1592
- ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch10433_loss1.2657.pt'
1593
- ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch2000_loss0.3315.pt'
1594
-
1595
- dataset_path = '/gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/test'
1596
- dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized'
1597
- # dataset_path = '/gemini/user/private/zhaotianhao/data/trellis500k_compress_glb'
1598
- dataset_path = '/gemini/user/private/zhaotianhao/data/unique_files_glb_under6000face_2degree_30ratio_0.01'
1599
 
 
 
 
 
 
 
1600
  if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
1601
  RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
1602
  else:
 
744
  root_dir=self.dataset_path,
745
  base_resolution=self.config['dataset']['base_resolution'],
746
  min_resolution=self.config['dataset']['min_resolution'],
747
+ cache_dir='/home/tiger/yy/src/dataset_cache',
748
  # cache_dir=self.config['dataset']['cache_dir'],
749
  renders_dir=self.config['dataset']['renders_dir'],
750
 
 
1262
  # To save normalized coordinates:
1263
  mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0
1264
 
 
 
 
 
1265
  # Center at the origin (optional)
1266
  mesh_verts = mesh_verts - 0.5
1267
 
1268
  mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces)
1269
 
1270
+ trimesh.repair.fix_normals(mesh)
1271
  # 过滤孤立点 (可选)
1272
  # mesh.remove_unreferenced_vertices()
1273
 
 
1578
  if __name__ == '__main__':
1579
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
1580
  evaluate_all_checkpoints = True # set to True to enable epoch-range filtering
1581
+ EPOCH_START = 1
1582
  EPOCH_END = 12
1583
  CHAMFER_EDGE_THRESHOLD=0.5
1584
+ NUM_SAMPLES=50
1585
  VISUALIZE=True
1586
  THRESHOLD=1.5
1587
  VISUAL_FIELD=False
1588
 
1589
+ # ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch10433_loss1.2657.pt'
1590
+ ckpt_path = '/home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch0_batch6000_loss0.1150.pt'
 
 
 
 
 
1591
 
1592
+ dataset_path = '/home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01'
1593
+ # dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized'
1594
+ # # dataset_path = '/gemini/user/private/zhaotianhao/data/trellis500k_compress_glb'
1595
+ # dataset_path = '/gemini/user/private/zhaotianhao/data/unique_files_glb_under6000face_2degree_30ratio_0.01'
1596
+ dataset_path = '/home/tiger/yy/src/trellis_clean_mesh/mesh_data'
1597
+
1598
  if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000':
1599
  RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond'
1600
  else:
train_slat_flow_128to512_pointnet_head.py ADDED
@@ -0,0 +1,507 @@
1
+ import os
2
+ # os.environ['ATTN_BACKEND'] = 'xformers' # xformers is generally compatible with DDP
3
+ # os.environ["OMP_NUM_THREADS"] = "1"
4
+ # os.environ["MKL_NUM_THREADS"] = "1"
5
+ import torch
6
+ import numpy as np
7
+ import yaml
8
+ from torch.utils.data import DataLoader, DistributedSampler
9
+ from functools import partial
10
+ import torch.nn.functional as F
11
+ from torch.optim import AdamW
12
+ from torch.amp import GradScaler, autocast
13
+ from typing import *
14
+ from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+
18
+ # --- Updated Imports based on VAE script ---
19
+ from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet
20
+ from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE
21
+ from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead
22
+ from triposf.modules.sparse.basic import SparseTensor
23
+
24
+ from trellis.models.structured_latent_flow import SLatFlowModel
25
+ from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer
26
+ from safetensors.torch import load_file
27
+ import torch.multiprocessing as mp
28
+ from PIL import Image
29
+ import torch.nn as nn
30
+
31
+ from triposf.modules.utils import DiagonalGaussianDistribution
32
+ import torchvision.transforms as transforms
33
+ import re
34
+ import contextlib
35
+
36
+ # --- Distributed Setup Functions ---
37
+ def setup_distributed(backend="nccl"):
38
+ """Initializes the distributed environment."""
39
+ if not dist.is_initialized():
40
+ rank = int(os.environ["RANK"])
41
+ world_size = int(os.environ["WORLD_SIZE"])
42
+ local_rank = int(os.environ["LOCAL_RANK"])
43
+
44
+ torch.cuda.set_device(local_rank)
45
+ dist.init_process_group(backend=backend)
46
+
47
+ return int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
48
+
49
+ def cleanup_distributed():
50
+ dist.destroy_process_group()
51
+
52
+ # --- Modified Trainer Class ---
53
+ class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer):
54
+ def __init__(self, *args, rank: int, local_rank: int, world_size: int, **kwargs):
55
+ super().__init__(*args, **kwargs)
56
+ self.cfg = kwargs.pop('cfg', None)
57
+ if self.cfg is None:
58
+ raise ValueError("Configuration dictionary 'cfg' must be provided.")
59
+
60
+ # --- Distributed-related attributes ---
61
+ self.rank = rank
62
+ self.local_rank = local_rank
63
+ self.world_size = world_size
64
+ self.device = torch.device(f"cuda:{self.local_rank}")
65
+ self.is_master = (self.rank == 0)
66
+ self.gradient_accumulation_steps = 8
67
+
68
+ self.i_save = self.cfg['training']['save_every']
69
+ self.save_dir = self.cfg['training']['output_dir']
70
+
71
+ self.resolution = 128
72
+ self.condition_type = 'image'
73
+ self.is_cond = False
74
+ self.img_res = 518
75
+
76
+ if self.is_master:
77
+ os.makedirs(self.save_dir, exist_ok=True)
78
+ print(f"Checkpoints and logs will be saved to: {self.save_dir}")
79
+
80
+ # Initialize components and set up for DDP
81
+ self._init_components(
82
+ clip_model_path=self.cfg['training'].get('clip_model_path', None),
83
+ dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None),
84
+ vae_path=self.cfg['training']['vae_path'],
85
+ )
86
+
87
+ self._setup_ddp()
88
+
89
+ self.denoiser_checkpoint_path = self.cfg['training'].get('denoiser_checkpoint_path', None)
90
+
91
+ trainable_params = list(self.denoiser.parameters())
92
+ self.optimizer = AdamW(trainable_params, lr=self.cfg['training'].get('lr', 0.0001), weight_decay=0.0)
93
+
94
+ self.scaler = GradScaler()
95
+
96
+ if self.is_master:
97
+ print("Using Automatic Mixed Precision (AMP) with GradScaler.")
98
+
99
+ def _init_components(self,
100
+ clip_model_path=None,
101
+ dinov2_model_path=None,
102
+ vae_path=None,
103
+ ):
104
+ """
105
+ Initializes VAE, VoxelEncoder (PointNet), and condition models.
106
+ """
107
+ # Use the Dataset from the VAE script
108
+ self.dataset = VoxelVertexDataset_edge(
109
+ root_dir=self.cfg['dataset']['path'],
110
+ base_resolution=self.cfg['dataset']['base_resolution'],
111
+ min_resolution=self.cfg['dataset']['min_resolution'],
112
+ cache_dir=self.cfg['dataset']['cache_dir'],
113
+ renders_dir=self.cfg['dataset']['renders_dir'],
114
+
115
+ filter_active_voxels=self.cfg['dataset']['filter_active_voxels'],
116
+ cache_filter_path=self.cfg['dataset']['cache_filter_path'],
117
+
118
+ active_voxel_res=128,
119
+ pc_sample_number=819200,
120
+
121
+ sample_type=self.cfg['dataset']['sample_type'],
122
+
123
+ )
124
+
125
+ self.sampler = DistributedSampler(
126
+ self.dataset,
127
+ num_replicas=self.world_size,
128
+ rank=self.rank,
129
+ shuffle=True
130
+ )
131
+
132
+ # Use collate_fn_pointnet
133
+ self.dataloader = DataLoader(
134
+ self.dataset,
135
+ batch_size=self.cfg['training']['batch_size'],
136
+ shuffle=False,
137
+ collate_fn=partial(collate_fn_pointnet,),
138
+ num_workers=self.cfg['training']['num_workers'],
139
+ pin_memory=True,
140
+ sampler=self.sampler,
141
+ prefetch_factor=4,
142
+ persistent_workers=True,
143
+ drop_last=True,
144
+ )
145
+
146
+ def load_file_func(path, device='cpu'):
147
+ return torch.load(path, map_location=device)
148
+
149
+ def _load_and_broadcast(model, load_fn=None, path=None, strict=True):
150
+ if self.is_master:
151
+ try:
152
+ state = load_fn(path) if load_fn else model.state_dict()
153
+ except Exception as e:
154
+ raise RuntimeError(f"Failed to load weights from {path}: {e}")
155
+ else:
156
+ state = None
157
+
158
+ dist.barrier()
159
+ state_b = [state] if self.is_master else [None]
160
+ dist.broadcast_object_list(state_b, src=0)
161
+
162
+ try:
163
+ # Handle potential key mismatches (e.g. 'module.' prefix)
164
+ model.load_state_dict(state_b[0], strict=strict)
165
+ except Exception as e:
166
+ if self.is_master: print(f"Strict loading failed for {model.__class__.__name__}, trying non-strict: {e}")
167
+ model.load_state_dict(state_b[0], strict=False)
168
+
169
+ # ------------------------- Voxel Encoder (PointNet) -------------------------
170
+ # Matching the VAE script configuration
171
+ self.voxel_encoder = VoxelFeatureEncoder_active_pointnet(
172
+ in_channels=15,
173
+ hidden_dim=256,
174
+ out_channels=1024,
175
+ scatter_type='mean',
176
+ n_blocks=5,
177
+ resolution=128,
178
+ add_label=False,
179
+ ).to(self.device)
180
+
181
+ # ------------------------- VAE -------------------------
182
+ self.vae = VoxelVAE(
183
+ in_channels=self.cfg['model']['in_channels'],
184
+ latent_dim=self.cfg['model']['latent_dim'],
185
+ encoder_blocks=self.cfg['model']['encoder_blocks'],
186
+ decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'],
187
+ decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'],
188
+ num_heads=8,
189
+ num_head_channels=64,
190
+ mlp_ratio=4.0,
191
+ attn_mode="swin",
192
+ window_size=8,
193
+ pe_mode="ape",
194
+ use_fp16=False,
195
+ use_checkpoint=True,
196
+ qk_rms_norm=False,
197
+ using_subdivide=True,
198
+ using_attn=self.cfg['model']['using_attn'],
199
+ attn_first=self.cfg['model'].get('attn_first', True),
200
+ pred_direction=self.cfg['model'].get('pred_direction', False),
201
+ ).to(self.device)
202
+
203
+
204
+ # ------------------------- Conditioning -------------------------
205
+ if self.condition_type == 'text':
206
+ self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path)
207
+ if self.is_master:
208
+ self.condition_model = CLIPTextModel.from_pretrained(clip_model_path)
209
+ else:
210
+ config = CLIPTextConfig.from_pretrained(clip_model_path)
211
+ self.condition_model = CLIPTextModel(config)
212
+ _load_and_broadcast(self.condition_model)
213
+
214
+ elif self.condition_type == 'image':
215
+ if self.is_master:
216
+ print("Initializing for IMAGE conditioning (DINOv2).")
217
+
218
+ # Update paths as per your environment
219
+ local_repo_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main"
220
+ weights_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth"
221
+
222
+ dinov2_model = torch.hub.load(
223
+ repo_or_dir=local_repo_path,
224
+ model='dinov2_vitl14_reg',
225
+ source='local',
226
+ pretrained=False
227
+ )
228
+ self.condition_model = dinov2_model
229
+
230
+ _load_and_broadcast(self.condition_model, load_fn=torch.load, path=weights_path)
231
+
232
+ self.image_cond_model_transform = transforms.Compose([
233
+ transforms.ToTensor(),
234
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
235
+ ])
236
+ else:
237
+ raise ValueError(f"Unsupported condition type: {self.condition_type}")
238
+
239
+ self.condition_model.to(self.device).eval()
240
+ for p in self.condition_model.parameters(): p.requires_grad = False
241
+
242
+ # ------------------------- Load VAE/Encoder Weights -------------------------
243
+ # Load weights corresponding to the logic in VAE script's `load_pretrained_woself`
244
+ # Assuming checkpoint contains 'vae' and 'voxel_encoder' keys
245
+ _load_and_broadcast(self.vae,
246
+ load_fn=lambda p: load_file_func(p)['vae'],
247
+ path=vae_path)
248
+
249
+ _load_and_broadcast(self.voxel_encoder,
250
+ load_fn=lambda p: load_file_func(p)['voxel_encoder'],
251
+ path=vae_path)
252
+
253
+ self.vae.eval()
254
+ self.voxel_encoder.eval()
255
+ for p in self.vae.parameters(): p.requires_grad = False
256
+ for p in self.voxel_encoder.parameters(): p.requires_grad = False
257
+
258
+ def _load_denoiser(self):
259
+ """Loads a checkpoint for the denoiser."""
260
+ path = self.denoiser_checkpoint_path
261
+ if not path or not os.path.isfile(path):
262
+ if self.is_master:
263
+ print("No valid checkpoint path provided for denoiser. Starting from scratch.")
264
+ return
265
+
266
+ if self.is_master:
267
+ print(f"Loading denoiser checkpoint from: {path}")
268
+ checkpoint = torch.load(path, map_location=self.device)
269
+ else:
270
+ checkpoint = None
271
+
272
+ dist.barrier()
273
+ dist_list = [checkpoint] if self.is_master else [None]
274
+ dist.broadcast_object_list(dist_list, src=0)
275
+ checkpoint = dist_list[0]
276
+
277
+ try:
278
+ self.denoiser.module.load_state_dict(checkpoint['denoiser'])
279
+ if self.is_master: print("Denoiser weights loaded successfully.")
280
+ except Exception as e:
281
+ if self.is_master: print(f"[ERROR] Failed to load denoiser state_dict: {e}")
282
+
283
+ if 'step' in checkpoint and self.is_master:
284
+ print(f"Checkpoint from step {checkpoint['step']}.")
285
+
286
+ dist.barrier()
287
+
288
+ def _setup_ddp(self):
289
+ """Sets up DDP and DataLoaders."""
290
+ self.denoiser = self.denoiser.to(self.device)
291
+ self.denoiser = DDP(self.denoiser, device_ids=[self.local_rank])
292
+
293
+ for param in self.denoiser.parameters():
294
+ param.requires_grad = True
295
+
296
+ @torch.no_grad()
297
+ def encode_image(self, images) -> torch.Tensor:
298
+ if isinstance(images, torch.Tensor):
299
+ batch_tensor = images.to(self.device)
300
+ elif isinstance(images, list):
301
+ assert all(isinstance(i, Image.Image) for i in images), "Image list should be list of PIL images"
302
+ image = [i.resize((518, 518), Image.LANCZOS) for i in images]
303
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
304
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
305
+ batch_tensor = torch.stack(image).to(self.device)
306
+ else:
307
+ raise ValueError(f"Unsupported type of image: {type(images)}")
308
+
309
+ if batch_tensor.shape[-2:] != (518, 518):
310
+ batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False)
311
+
312
+ features = self.condition_model(batch_tensor, is_training=True)['x_prenorm']
313
+ patchtokens = F.layer_norm(features, features.shape[-1:])
314
+ return patchtokens
315
+
316
+ def process_batch(self, batch):
317
+ preprocessed_images = batch['image']
318
+ cond_ = self.encode_image(preprocessed_images)
319
+ return cond_
320
+
321
+ def train(self, num_epochs=1000):
322
+ # 1. Prepare the unconditional prompt (same as before)
323
+ if not self.is_cond:
324
+ if self.condition_type == 'text':
325
+ txt = ['']
326
+ encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
327
+ tokens = encoding['input_ids'].to(self.device)
328
+ with torch.no_grad():
329
+ cond_ = self.condition_model(input_ids=tokens).last_hidden_state
330
+ else:
331
+ blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8))
332
+ with torch.no_grad():
333
+ dummy_cond = self.encode_image([blank_img])
334
+ cond_ = torch.zeros_like(dummy_cond)
335
+ if self.is_master: print(f"Generated unconditional image prompt with shape: {cond_.shape}")
336
+
337
+ self._load_denoiser()
338
+ self.denoiser.train()
339
+
340
+ # Get the global step (parsed from the checkpoint filename, if any)
341
+ global_step = 0
342
+ if self.denoiser_checkpoint_path:
343
+ match = re.search(r'step(\d+)', self.denoiser_checkpoint_path)
344
+ if match:
345
+ global_step = int(match.group(1))
346
+
347
+ accum_steps = self.gradient_accumulation_steps
348
+ if self.is_master:
349
+ print(f"Training with Gradient Accumulation Steps: {accum_steps}")
350
+
351
+ # Make sure gradients are zeroed before the loop starts
352
+ self.optimizer.zero_grad(set_to_none=True)
353
+
354
+ for epoch in range(num_epochs):
355
+ self.dataloader.sampler.set_epoch(epoch)
356
+ epoch_losses = []
357
+
358
+ # Iterate over the data
359
+ for i, batch in enumerate(self.dataloader):
360
+
361
+ # --- A. Data preparation ---
362
+ if self.is_cond and self.condition_type == 'image':
363
+ cond_ = self.process_batch(batch)
364
+
365
+ point_cloud = batch['point_cloud_128'].to(self.device)
366
+ active_coords = batch['active_voxels_128'].to(self.device)
367
+
368
+ batch_size = int(active_coords[:, 0].max().item() + 1)
369
+ if cond_.shape[0] != batch_size:
370
+ cond_ = cond_.expand(batch_size, -1, -1).contiguous().to(self.device)
371
+ else:
372
+ cond_ = cond_.to(self.device)
373
+
374
+ # --- B. Forward pass and loss computation ---
375
+ with autocast(device_type='cuda', dtype=torch.bfloat16):
376
+ with torch.no_grad():
377
+ active_voxel_feats = self.voxel_encoder(
378
+ p=point_cloud,
379
+ sparse_coords=active_coords,
380
+ res=128,
381
+ bbox_size=(-0.5, 0.5),
382
+ )
383
+ sparse_input = SparseTensor(
384
+ feats=active_voxel_feats,
385
+ coords=active_coords.int()
386
+ )
387
+ latent_128, posterior = self.vae.encode(sparse_input)
388
+
389
+ terms, _ = self.training_losses(x_0=latent_128, cond=cond_)
390
+ loss = terms['loss']
391
+
392
+ # [Important] Divide the loss by the number of accumulation steps
393
+ loss = loss / accum_steps
394
+
395
+ # --- C. 反向传播 ---
396
+ # Note: no_sync is not used here, so gradients are synchronized on every backward pass
397
+ self.scaler.scale(loss).backward()
398
+
399
+ # Record the unscaled loss for display
400
+ current_real_loss = loss.item() * accum_steps
401
+ epoch_losses.append(current_real_loss)
402
+
403
+ # --- D. Gradient accumulation check and optimizer update ---
404
+ if (i + 1) % accum_steps == 0:
405
+ # If clip_grad_norm is needed, add it here:
406
+ # self.scaler.unscale_(self.optimizer)
407
+ # torch.nn.utils.clip_grad_norm_(self.denoiser.parameters(), max_norm=1.0)
408
+
409
+ self.scaler.step(self.optimizer)
410
+ self.scaler.update()
411
+ self.optimizer.zero_grad(set_to_none=True)
412
+
413
+ global_step += 1
414
+
415
+ # --- Logging (only on update steps) ---
416
+ if self.is_master:
417
+ if global_step % 10 == 0:
418
+ print(f"Epoch {epoch+1} Step {global_step}: "
419
+ f"Batch_Loss = {current_real_loss:.4f}, "
420
+ f"Epoch_Mean = {np.mean(epoch_losses):.4f}")
421
+
422
+ if global_step % self.i_save == 0:
423
+ checkpoint = {
424
+ 'denoiser': self.denoiser.module.state_dict(),
425
+ 'step': global_step
426
+ }
427
+ loss_str = f"{np.mean(epoch_losses):.6f}".replace('.', '_')
428
+ save_path = os.path.join(self.save_dir, f"checkpoint_step{global_step}_loss{loss_str}.pt")
429
+ torch.save(checkpoint, save_path)
430
+ print(f"Saved checkpoint to {save_path}")
431
+
432
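+ # --- E. Handle leftover batches at the end of the epoch ---
433
+ # If the dataloader length is not divisible by accum_steps and drop_last=False,
434
+ # the gradients accumulated so far still need to be applied here.
435
+ # (With drop_last=True and a batch count divisible by accum_steps this branch never fires, but it is kept as a safeguard.)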
436
+ if (i + 1) % accum_steps != 0:
437
+ self.scaler.step(self.optimizer)
438
+ self.scaler.update()
439
+ self.optimizer.zero_grad(set_to_none=True)
440
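+ # Note: global_step is usually not incremented here, though that is a matter of preference,
441
+ # because this is an incomplete step whose gradients are scaled differently (the divisor is still accum_steps).
442
+ # Many implementations set drop_last=True to avoid this case for stability.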
443
+
444
+ if self.is_master:
445
+ avg_loss = np.mean(epoch_losses) if epoch_losses else 0
446
+ log_path = os.path.join(self.save_dir, "loss_log.txt")
447
+ with open(log_path, "a") as f:
448
+ f.write(f"Epoch {epoch+1}, Step {global_step}, AvgLoss {avg_loss:.6f}\n")
449
+
450
+ dist.barrier()
451
+
452
+ if self.is_master:
453
+ print("Training complete.")
454
+
455
+ def main():
456
+ # if mp.get_start_method(allow_none=True) != 'spawn':
457
+ # mp.set_start_method('spawn', force=True)
458
+
459
+ # if mp.get_start_method(allow_none=True) != 'forkserver':
460
+ # mp.set_start_method('forkserver', force=True)
461
+
462
+ rank, local_rank, world_size = setup_distributed()
463
+ torch.manual_seed(42+rank)
464
+ np.random.seed(42+rank)
465
+
466
+ # Path to your config
467
+ config_path = "/home/tiger/yy/src/Michelangelo-master/config_slat_flow_128to512_pointnet_head.yaml"
468
+ with open(config_path) as f:
469
+ cfg = yaml.safe_load(f)
470
+
471
+ # Initialize Flow Model (on CPU first)
472
+ diffusion_model = SLatFlowModel(
473
+ resolution=cfg['flow']['resolution'],
474
+ in_channels=cfg['flow']['in_channels'],
475
+ out_channels=cfg['flow']['out_channels'],
476
+ model_channels=cfg['flow']['model_channels'],
477
+ cond_channels=cfg['flow']['cond_channels'],
478
+ num_blocks=cfg['flow']['num_blocks'],
479
+ num_heads=cfg['flow']['num_heads'],
480
+ mlp_ratio=cfg['flow']['mlp_ratio'],
481
+ patch_size=cfg['flow']['patch_size'],
482
+ num_io_res_blocks=cfg['flow']['num_io_res_blocks'],
483
+ io_block_channels=cfg['flow']['io_block_channels'],
484
+ pe_mode=cfg['flow']['pe_mode'],
485
+ qk_rms_norm=cfg['flow']['qk_rms_norm'],
486
+ qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'],
487
+ use_fp16=cfg['flow'].get('use_fp16', False),
488
+ )
489
+
490
+ torch.manual_seed(42 + rank)
491
+ np.random.seed(42 + rank)
492
+
493
+ trainer = SLatFlowMatchingTrainer(
494
+ denoiser=diffusion_model,
495
+ t_schedule=cfg['t_schedule'],
496
+ sigma_min=cfg['sigma_min'],
497
+ cfg=cfg,
498
+ rank=rank,
499
+ local_rank=local_rank,
500
+ world_size=world_size,
501
+ )
502
+
503
+ trainer.train()
504
+ cleanup_distributed()
505
+
506
+ if __name__ == '__main__':
507
+ main()
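The comments in section E of the training loop note that a trailing, incomplete accumulation group is still divided by `accum_steps`, so its gradients are scaled differently from those of a full group. A minimal sketch of one way to avoid that is shown below. It is not part of this commit: `accumulation_plan` is a hypothetical helper, and it assumes the number of batches per epoch is known in advance (for example via `len(dataloader)`).

```python
from typing import Iterator, Tuple


def accumulation_plan(num_batches: int, accum_steps: int) -> Iterator[Tuple[int, float, bool]]:
    """Yield (batch_index, loss_scale, do_step) for each micro-batch.

    loss_scale is 1 / (size of the accumulation group the batch belongs to),
    so a short trailing group is averaged over its own length instead of
    over accum_steps. do_step marks the last micro-batch of each group.
    """
    if accum_steps < 1:
        raise ValueError("accum_steps must be >= 1")
    full_groups, leftover = divmod(num_batches, accum_steps)
    for i in range(num_batches):
        # Batches past the last full group belong to the short trailing group.
        in_leftover = i >= full_groups * accum_steps
        group_size = leftover if in_leftover else accum_steps
        is_last_in_group = (i + 1) % accum_steps == 0 or i == num_batches - 1
        yield i, 1.0 / group_size, is_last_in_group


# Example: 10 micro-batches with 4 accumulation steps.
plan = list(accumulation_plan(num_batches=10, accum_steps=4))
steps = [i for i, _, do_step in plan if do_step]
scales = [scale for _, scale, _ in plan]
```

In a loop like the one above, the loss would be multiplied by `loss_scale` instead of being divided by `accum_steps`, and the optimizer step would run whenever `do_step` is true, which would make the separate leftover branch unnecessary.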
trellis/__init__.py CHANGED
@@ -2,5 +2,5 @@ from . import models
2
  from . import modules
3
  from . import pipelines
4
  from . import renderers
5
- from . import representations
5
+ # from . import representations
6
  from . import utils
trellis/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/__pycache__/__init__.cpython-310.pyc and b/trellis/__pycache__/__init__.cpython-310.pyc differ
 
trellis/models/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/models/__pycache__/__init__.cpython-310.pyc and b/trellis/models/__pycache__/__init__.cpython-310.pyc differ
 
trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc CHANGED
Binary files a/trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc and b/trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc differ
 
trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc CHANGED
Binary files a/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc and b/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc differ
 
trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc CHANGED
Binary files a/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc and b/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc differ
 
trellis/modules/__pycache__/norm.cpython-310.pyc CHANGED
Binary files a/trellis/modules/__pycache__/norm.cpython-310.pyc and b/trellis/modules/__pycache__/norm.cpython-310.pyc differ
 
trellis/modules/__pycache__/spatial.cpython-310.pyc CHANGED
Binary files a/trellis/modules/__pycache__/spatial.cpython-310.pyc and b/trellis/modules/__pycache__/spatial.cpython-310.pyc differ
 
trellis/modules/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/trellis/modules/__pycache__/utils.cpython-310.pyc and b/trellis/modules/__pycache__/utils.cpython-310.pyc differ
 
trellis/modules/attention/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc differ
 
trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc CHANGED
Binary files a/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc and b/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc differ
 
trellis/modules/attention/__pycache__/modules.cpython-310.pyc CHANGED
Binary files a/trellis/modules/attention/__pycache__/modules.cpython-310.pyc and b/trellis/modules/attention/__pycache__/modules.cpython-310.pyc differ
 
trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc differ
 
trellis/modules/sparse/__pycache__/basic.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc differ
 
trellis/modules/sparse/__pycache__/linear.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc differ
 
trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc differ
 
trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc differ
 
trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc differ
 
trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc differ
 
trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc differ
 
trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc differ
 
trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc differ
 
trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc differ
 
trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc and b/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc differ
 
trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc differ
 
trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc differ
 
trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc CHANGED
Binary files a/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc differ
 
trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc differ
 
trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc CHANGED
Binary files a/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc differ
 
trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc CHANGED
Binary files a/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc differ
 
trellis/pipelines/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/pipelines/__pycache__/__init__.cpython-310.pyc and b/trellis/pipelines/__pycache__/__init__.cpython-310.pyc differ
 
trellis/pipelines/__pycache__/base.cpython-310.pyc CHANGED
Binary files a/trellis/pipelines/__pycache__/base.cpython-310.pyc and b/trellis/pipelines/__pycache__/base.cpython-310.pyc differ
 
trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc CHANGED
Binary files a/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc and b/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc differ
 
trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc CHANGED
Binary files a/trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc and b/trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc differ
 
trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc differ
 
trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc CHANGED
Binary files a/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc differ
 
trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc CHANGED
Binary files a/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc differ
 
trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc CHANGED
Binary files a/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc differ
 
trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc CHANGED
Binary files a/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc differ
 
trellis/renderers/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/renderers/__pycache__/__init__.cpython-310.pyc and b/trellis/renderers/__pycache__/__init__.cpython-310.pyc differ
 
trellis/representations/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/representations/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/__pycache__/__init__.cpython-310.pyc differ
 
trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc differ
 
trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc CHANGED
Binary files a/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc differ
 
trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc CHANGED
Binary files a/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc differ