diff --git a/config_slat_flow_128to512_pointnet_head.yaml b/config_slat_flow_128to512_pointnet_head.yaml new file mode 100644 index 0000000000000000000000000000000000000000..968970cf10bcb8da84c5a9588105673861fc0760 --- /dev/null +++ b/config_slat_flow_128to512_pointnet_head.yaml @@ -0,0 +1,122 @@ +model: + pred_direction: false + relative_embed: true + using_attn: false + add_block_embed: true + multires: 12 + + embed_dim: 1024 + in_channels: 1024 + model_channels: 384 + latent_dim: 16 + + block_size: 16 + pos_encoding: 'nerf' + attn_first: false + + add_edge_glb_feats: true + add_direction: false + + encoder_blocks: + - in_channels: 1024 + model_channels: 512 + num_blocks: 8 + num_heads: 8 + out_channels: 512 + + decoder_blocks_edge: + - in_channels: 512 + model_channels: 512 + num_blocks: 0 + num_heads: 0 + out_channels: 256 + resolution: 128 + - in_channels: 256 + model_channels: 256 + num_blocks: 0 + num_heads: 0 + out_channels: 128 + resolution: 256 + # - in_channels: 64 + # model_channels: 64 + # num_blocks: 0 + # num_heads: 0 + # out_channels: 32 + # resolution: 512 + + decoder_blocks_vtx: + - in_channels: 512 + model_channels: 512 + num_blocks: 0 + num_heads: 0 + out_channels: 256 + resolution: 128 + - in_channels: 256 + model_channels: 256 + num_blocks: 0 + num_heads: 0 + out_channels: 128 + resolution: 256 + # - in_channels: 64 + # model_channels: 64 + # num_blocks: 0 + # num_heads: 0 + # out_channels: 32 + # resolution: 512 + +"t_schedule": + "name": "logitNormal" + "args": + "mean": 1.0 + "std": 1.0 + +"sigma_min": 1.e-5 + +training: + batch_size: 1 + lr: 1.e-4 + step_size: 20 + gamma: 0.95 + save_every: 500 + start_epoch: 0 + max_epochs: 300 + num_workers: 32 + + output_dir: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope + clip_model_path: None + dinov2_model_path: None + + vae_path: /home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch1_batch12000_loss0.1140.pt + denoiser_checkpoint_path: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512/checkpoint_step143000_loss0_736924.pt + + +dataset: + path: /home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01 + cache_dir: /home/tiger/yy/src/dataset_cache/unique_files_glb_under6000face_2degree_30ratio_0.01 + + renders_dir: None + filter_active_voxels: true + cache_filter_path: /home/tiger/yy/src/40w_2000-100000edge_2000-100000active.txt + + base_resolution: 1024 + min_resolution: 128 + + n_train_samples: 1024 + sample_type: dora + +flow: + "resolution": 128 + "in_channels": 16 + "out_channels": 16 + "model_channels": 768 + "cond_channels": 1024 + "num_blocks": 12 + "num_heads": 12 + "mlp_ratio": 4 + "patch_size": 2 + "num_io_res_blocks": 2 + "io_block_channels": [128] + "pe_mode": "rope" + "qk_rms_norm": true + "qk_rms_norm_cross": false + "use_fp16": false \ No newline at end of file diff --git a/config_slat_flow_128to512_pointnet_head_test.yaml b/config_slat_flow_128to512_pointnet_head_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a73bdc3cb13d41abc6fedb889239a9383ac009bb --- /dev/null +++ b/config_slat_flow_128to512_pointnet_head_test.yaml @@ -0,0 +1,123 @@ +model: + pred_direction: false + relative_embed: true + using_attn: false + add_block_embed: true + multires: 12 + + embed_dim: 1024 + in_channels: 1024 + model_channels: 384 + latent_dim: 16 + + block_size: 16 + pos_encoding: 'nerf' + attn_first: false + + add_edge_glb_feats: true + add_direction: false + + encoder_blocks: + - in_channels: 1024 + model_channels: 512 + num_blocks: 8 + num_heads: 8 + out_channels: 512 + + decoder_blocks_edge: + - in_channels: 512 + model_channels: 512 + num_blocks: 0 + num_heads: 0 + out_channels: 256 + resolution: 128 + - in_channels: 256 + model_channels: 256 + num_blocks: 0 + num_heads: 0 + out_channels: 128 + resolution: 256 + # - in_channels: 64 + # model_channels: 64 + # num_blocks: 0 + # num_heads: 0 + # out_channels: 32 + # resolution: 512 + + decoder_blocks_vtx: + - in_channels: 512 + model_channels: 512 + num_blocks: 0 + num_heads: 0 + out_channels: 256 + resolution: 128 + - in_channels: 256 + model_channels: 256 + num_blocks: 0 + num_heads: 0 + out_channels: 128 + resolution: 256 + # - in_channels: 64 + # model_channels: 64 + # num_blocks: 0 + # num_heads: 0 + # out_channels: 32 + # resolution: 512 + +"t_schedule": + "name": "logitNormal" + "args": + "mean": 1.0 + "std": 1.0 + +"sigma_min": 1.e-5 + +training: + batch_size: 1 + lr: 1.e-4 + step_size: 20 + gamma: 0.95 + save_every: 500 + start_epoch: 0 + max_epochs: 300 + num_workers: 32 + + output_dir: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope + clip_model_path: None + dinov2_model_path: None + + vae_path: /home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch1_batch12000_loss0.1140.pt + denoiser_checkpoint_path: /home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512/checkpoint_step143000_loss0_736924.pt + + +dataset: + path: /home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01 + path: /home/tiger/yy/src/trellis_clean_mesh/mesh_data + cache_dir: /home/tiger/yy/src/dataset_cache/unique_files_glb_under6000face_2degree_30ratio_0.01 + + renders_dir: None + filter_active_voxels: false + cache_filter_path: /home/tiger/yy/src/40w_2000-100000edge_2000-100000active.txt + + base_resolution: 1024 + min_resolution: 128 + + n_train_samples: 1024 + sample_type: dora + +flow: + "resolution": 128 + "in_channels": 16 + "out_channels": 16 + "model_channels": 768 + "cond_channels": 1024 + "num_blocks": 12 + "num_heads": 12 + "mlp_ratio": 4 + "patch_size": 2 + "num_io_res_blocks": 2 + "io_block_channels": [128] + "pe_mode": "rope" + "qk_rms_norm": true + "qk_rms_norm_cross": false + "use_fp16": false \ No newline at end of file diff --git a/test_slat_flow_128to512_pointnet_head.py b/test_slat_flow_128to512_pointnet_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f07339164e1a6622b396623940a537ec936467 --- /dev/null +++ b/test_slat_flow_128to512_pointnet_head.py @@ -0,0 +1,404 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import yaml +import time +from datetime import datetime +from torch.utils.data import DataLoader +from functools import partial +import torch.nn.functional as F +from torch.amp import GradScaler, autocast +from typing import * +from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig, Dinov2Model, AutoImageProcessor, Dinov2Config +import torch +import re +from utils import load_pretrained_woself + +from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet +from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE +from vertex_encoder import VoxelFeatureEncoder_active_pointnet + +from trellis.models.structured_latent_flow import SLatFlowModel +from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer + +from trellis.pipelines.samplers import FlowEulerSampler +from safetensors.torch import load_file +import open3d as o3d +from PIL import Image + +from triposf.modules.sparse.basic import SparseTensor +from trellis.modules.sparse.basic import SparseTensor as SparseTensor_trellis + +from triposf.modules.utils import DiagonalGaussianDistribution + +from sklearn.decomposition import PCA +import trimesh +import torchvision.transforms as transforms + +# --- Helper Functions --- +def save_colored_ply(points, colors, filename): + if len(points) == 0: + print(f"[Warning] No points to save for {filename}") + return + # Ensure colors are uint8 + if colors.max() <= 1.0: + colors = (colors * 255).astype(np.uint8) + colors = colors.astype(np.uint8) + + # Add Alpha if missing + if colors.shape[1] == 3: + colors = np.hstack([colors, np.full((len(colors), 1), 255, dtype=np.uint8)]) + + cloud = trimesh.PointCloud(points, colors=colors) + cloud.export(filename) + print(f"Saved colored point cloud to {filename}") + +def normalize_to_rgb(features_3d): + min_vals = features_3d.min(axis=0) + max_vals = features_3d.max(axis=0) + range_vals = max_vals - min_vals + range_vals[range_vals == 0] = 1 + normalized = (features_3d - min_vals) / range_vals + return (normalized * 255).astype(np.uint8) + +class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cfg = kwargs.pop('cfg', None) + if self.cfg is None: + raise ValueError("Configuration dictionary 'cfg' must be provided.") + + self.sampler = FlowEulerSampler(sigma_min=1.e-5) + self.device = torch.device("cuda") + + # Based on PointNet Encoder setting + self.resolution = 128 + + self.condition_type = 'image' + self.is_cond = False + + self.img_res = 518 + self.feature_dim = self.cfg['model']['latent_dim'] + + self._init_components( + clip_model_path=self.cfg['training'].get('clip_model_path', None), + dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None), + vae_path=self.cfg['training']['vae_path'], + ) + + # Classifier head removed as it is not part of the Active Voxel pipeline + + def _load_denoiser(self, denoiser_checkpoint_path): + path = denoiser_checkpoint_path + if not path or not os.path.isfile(path): + print("No valid checkpoint path provided for fine-tuning. Starting from scratch.") + return + + print(f"Loading checkpoint from: {path}") + checkpoint = torch.load(path, map_location=self.device) + + try: + denoiser_state_dict = checkpoint['denoiser'] + # Handle DDP prefix + if next(iter(denoiser_state_dict)).startswith('module.'): + denoiser_state_dict = {k[7:]: v for k, v in denoiser_state_dict.items()} + + self.denoiser.load_state_dict(denoiser_state_dict) + print("Denoiser weights loaded successfully.") + except KeyError: + print("[WARN] 'denoiser' key not found in checkpoint. Skipping.") + except Exception as e: + print(f"[ERROR] Failed to load denoiser state_dict: {e}") + + def _init_components(self, + clip_model_path=None, + dinov2_model_path=None, + vae_path=None, + ): + + # 1. Initialize PointNet Voxel Encoder (Matches Training) + self.voxel_encoder = VoxelFeatureEncoder_active_pointnet( + in_channels=15, + hidden_dim=256, + out_channels=1024, + scatter_type='mean', + n_blocks=5, + resolution=128, + add_label=False, + ).to(self.device) + + # 2. Initialize VAE + self.vae = VoxelVAE( + in_channels=self.cfg['model']['in_channels'], + latent_dim=self.cfg['model']['latent_dim'], + encoder_blocks=self.cfg['model']['encoder_blocks'], + decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'], + decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'], + num_heads=8, + num_head_channels=64, + mlp_ratio=4.0, + attn_mode="swin", + window_size=8, + pe_mode="ape", + use_fp16=False, + use_checkpoint=False, + qk_rms_norm=False, + using_subdivide=True, + using_attn=self.cfg['model']['using_attn'], + attn_first=self.cfg['model'].get('attn_first', True), + pred_direction=self.cfg['model'].get('pred_direction', False), + ).to(self.device) + + # 3. Initialize Dataset with collate_fn_pointnet + self.dataset = VoxelVertexDataset_edge( + root_dir=self.cfg['dataset']['path'], + base_resolution=self.cfg['dataset']['base_resolution'], + min_resolution=self.cfg['dataset']['min_resolution'], + cache_dir=self.cfg['dataset']['cache_dir'], + renders_dir=self.cfg['dataset']['renders_dir'], + + process_img=False, + + active_voxel_res=128, + filter_active_voxels=self.cfg['dataset']['filter_active_voxels'], + cache_filter_path=self.cfg['dataset']['cache_filter_path'], + sample_type=self.cfg['dataset'].get('sample_type', 'dora'), + ) + + self.dataloader = DataLoader( + self.dataset, + batch_size=1, + shuffle=True, + collate_fn=partial(collate_fn_pointnet,), # Critical Change + num_workers=0, + pin_memory=True, + persistent_workers=False, + ) + + # 4. Load Pretrained Weights + # Assuming vae_path contains 'voxel_encoder' and 'vae' + print(f"Loading VAE/Encoder from {vae_path}") + ckpt = torch.load(vae_path, map_location='cpu') + + # Load VAE + if 'vae' in ckpt: + self.vae.load_state_dict(ckpt['vae'], strict=False) + else: + self.vae.load_state_dict(ckpt) # Fallback + + # Load Encoder + if 'voxel_encoder' in ckpt: + self.voxel_encoder.load_state_dict(ckpt['voxel_encoder']) + else: + print("[WARN] 'voxel_encoder' not found in checkpoint, random init (BAD for inference).") + + self.voxel_encoder.eval() + self.vae.eval() + + # 5. Initialize Conditioning Model + if self.condition_type == 'text': + self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path) + self.condition_model = CLIPTextModel.from_pretrained(clip_model_path) + elif self.condition_type == 'image': + model_name = 'dinov2_vitl14_reg' + local_repo_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main" + weights_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth" + + dinov2_model = torch.hub.load( + repo_or_dir=local_repo_path, + model=model_name, + source='local', + pretrained=False + ) + self.condition_model = dinov2_model + self.condition_model.load_state_dict(torch.load(weights_path)) + + self.image_cond_model_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + else: + raise ValueError(f"Unsupported condition type: {self.condition_type}") + + self.condition_model.to(self.device).eval() + + @torch.no_grad() + def encode_image(self, images) -> torch.Tensor: + if isinstance(images, torch.Tensor): + batch_tensor = images.to(self.device) + elif isinstance(images, list): + assert all(isinstance(i, Image.Image) for i in images) + image = [i.resize((518, 518), Image.LANCZOS) for i in images] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + batch_tensor = torch.stack(image).to(self.device) + else: + raise ValueError(f"Unsupported type of image: {type(images)}") + + if batch_tensor.shape[-2:] != (518, 518): + batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False) + + features = self.condition_model(batch_tensor, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + def process_batch(self, batch): + preprocessed_images = batch['image'] + cond_ = self.encode_image(preprocessed_images) + return cond_ + + def eval(self): + # Unconditional Setup + if self.is_cond == False: + if self.condition_type == 'text': + txt = [''] + encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt') + tokens = encoding['input_ids'].to(self.device) + with torch.no_grad(): + cond_ = self.condition_model(input_ids=tokens).last_hidden_state + else: + blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)) + with torch.no_grad(): + dummy_cond = self.encode_image([blank_img]) + cond_ = torch.zeros_like(dummy_cond) + print(f"Generated unconditional image prompt (zero tensor) with shape: {cond_.shape}") + + self.denoiser.eval() + + # Load Denoiser Checkpoint + # Update this path to your ACTIVE VOXEL trained checkpoint + checkpoint_path = '/home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope/checkpoint_step143500_loss0_766792.pt' + + self._load_denoiser(checkpoint_path) + + filename = os.path.basename(checkpoint_path) + match = re.search(r'step(\d+)', filename) + step_str = match.group(1) if match else "eval" + save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_trellis") + # save_dir = os.path.join(os.path.dirname(checkpoint_path), f"{step_str}_sample_active_vis_42seed_40w_train") + os.makedirs(save_dir, exist_ok=True) + print(f"Results will be saved to: {save_dir}") + + for i, batch in enumerate(self.dataloader): + if i > 50: exit() # Visualize first 10 + + if self.is_cond and self.condition_type == 'image': + cond_ = self.process_batch(batch) + + if cond_.shape[0] != 1: + cond_ = cond_.expand(batch['active_voxels_128'].shape[0], -1, -1).contiguous().to(self.device) + else: + cond_ = cond_.to(self.device) + + # --- Data Retrieval (Matches collate_fn_pointnet) --- + point_cloud = batch['point_cloud_128'].to(self.device) + active_coords = batch['active_voxels_128'].to(self.device) # [N, 4] + + with autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.no_grad(): + # 1. Encode Ground Truth Latents + active_voxel_feats = self.voxel_encoder( + p=point_cloud, + sparse_coords=active_coords, + res=128, + bbox_size=(-0.5, 0.5), + ) + + sparse_input = SparseTensor( + feats=active_voxel_feats, + coords=active_coords.int() + ) + + # Encode to get GT distribution + gt_latents, posterior = self.vae.encode(sparse_input) + + print(f"Batch {i}: Active voxels: {active_coords.shape[0]}") + + # 2. Generation / Sampling + # Generate noise on the SAME active coordinates + noise = SparseTensor_trellis( + coords=active_coords.int(), + feats=torch.randn( + active_coords.shape[0], + self.feature_dim, + device=self.device, + ), + ) + + sample_results = self.sampler.sample( + model=self.denoiser.float(), + noise=noise.to(self.device).float(), + cond=cond_.to(self.device).float(), + steps=50, + rescale_t=1.0, + verbose=True, + ) + + generated_sparse_tensor = sample_results.samples + generated_coords = generated_sparse_tensor.coords + generated_features = generated_sparse_tensor.feats + + print('Gen features mean:', generated_features.mean().item(), 'std:', generated_features.std().item()) + print('GT features mean:', gt_latents.feats.mean().item(), 'std:', gt_latents.feats.std().item()) + print('MSE:', F.mse_loss(generated_features, gt_latents.feats).item()) + + # --- Visualization (PCA) --- + gt_feats_np = gt_latents.feats.detach().cpu().numpy() + gen_feats_np = generated_features.detach().cpu().numpy() + coords_np = active_coords[:, 1:4].detach().cpu().numpy() # x, y, z + + print("Visualizing features using PCA...") + pca = PCA(n_components=3) + + # Fit PCA on GT, transform both + pca.fit(gt_feats_np) + gt_feats_3d = pca.transform(gt_feats_np) + gen_feats_3d = pca.transform(gen_feats_np) + + gt_colors = normalize_to_rgb(gt_feats_3d) + gen_colors = normalize_to_rgb(gen_feats_3d) + + # Save PLYs + save_colored_ply(coords_np, gt_colors, os.path.join(save_dir, f"batch_{i}_gt_pca.ply")) + save_colored_ply(coords_np, gen_colors, os.path.join(save_dir, f"batch_{i}_gen_pca.ply")) + + # Save Tensors for further analysis + torch.save(gt_latents, os.path.join(save_dir, f"gt_latent_{i}.pt")) + + torch.save(batch, os.path.join(save_dir, f"gt_data_batch_{i}.pt")) + torch.save(sample_results.samples, os.path.join(save_dir, f"sample_latent_{i}.pt")) + +if __name__ == '__main__': + torch.manual_seed(42) + config_path = "/home/tiger/yy/src/Michelangelo-master/config_slat_flow_128to512_pointnet_head_test.yaml" + with open(config_path) as f: + cfg = yaml.safe_load(f) + + # Initialize Model on CPU first + diffusion_model = SLatFlowModel( + resolution=cfg['flow']['resolution'], + in_channels=cfg['flow']['in_channels'], + out_channels=cfg['flow']['out_channels'], + model_channels=cfg['flow']['model_channels'], + cond_channels=cfg['flow']['cond_channels'], + num_blocks=cfg['flow']['num_blocks'], + num_heads=cfg['flow']['num_heads'], + mlp_ratio=cfg['flow']['mlp_ratio'], + patch_size=cfg['flow']['patch_size'], + num_io_res_blocks=cfg['flow']['num_io_res_blocks'], + io_block_channels=cfg['flow']['io_block_channels'], + pe_mode=cfg['flow']['pe_mode'], + qk_rms_norm=cfg['flow']['qk_rms_norm'], + qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'], + use_fp16=cfg['flow'].get('use_fp16', False), + ).to("cuda" if torch.cuda.is_available() else "cpu") + + trainer = SLatFlowMatchingTrainer( + denoiser=diffusion_model, + t_schedule=cfg['t_schedule'], + sigma_min=cfg['sigma_min'], + cfg=cfg, + ) + + trainer.eval() \ No newline at end of file diff --git a/test_slat_flow_128to512_pointnet_head_tomesh.py b/test_slat_flow_128to512_pointnet_head_tomesh.py new file mode 100644 index 0000000000000000000000000000000000000000..2193643e8c237c35447a36074ea58fe19a9c1093 --- /dev/null +++ b/test_slat_flow_128to512_pointnet_head_tomesh.py @@ -0,0 +1,1630 @@ +import os +import yaml +import torch +import numpy as np +import random +from tqdm import tqdm +from collections import defaultdict +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from torch.utils.data import DataLoader, Subset +from triposf.modules.sparse.basic import SparseTensor + +from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE + +from vertex_encoder import VoxelFeatureEncoder_edge, VoxelFeatureEncoder_vtx, VoxelFeatureEncoder_active, VoxelFeatureEncoder_active_pointnet, ConnectionHead +from utils import load_pretrained_woself + +from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet + + +from functools import partial +import itertools +from typing import List, Tuple, Set +from collections import OrderedDict +from scipy.spatial import cKDTree +from sklearn.neighbors import KDTree + +import trimesh + +import torch +import torch.nn.functional as F +import time + +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt + +import networkx as nx + +def predict_mesh_connectivity( + connection_head, + vtx_feats, + vtx_coords, + batch_size=10000, + threshold=0.5, + k_neighbors=64, # 限制每个点只检测最近的 K 个邻居,设为 -1 则全连接检测 + device='cuda' +): + """ + Args: + connection_head: 训练好的 MLP 模型 + vtx_feats: [N, C] 顶点特征 + vtx_coords: [N, 3] 顶点坐标 (用于 KNN 筛选候选边) + batch_size: MLP 推理的 batch size + threshold: 判定连接的概率阈值 + k_neighbors: K-NN 数量。如果是 None 或 -1,则检测所有 N*(N-1)/2 对。 + """ + num_verts = vtx_feats.shape[0] + if num_verts < 3: + return [], [] # 无法构成三角形 + + connection_head.eval() + + # --- 1. 生成候选边 (Candidate Edges) --- + if k_neighbors is not None and k_neighbors > 0 and k_neighbors < num_verts: + # 策略 A: 局部 KNN (推荐) + # 计算距离矩阵可能会 OOM,使用分块或 KDTree/Faiss,这里用 PyTorch 的 cdist 分块简化版 + # 或者直接暴力 cdist 如果 N < 10000 + + # 为了简单且高效,这里演示简单的 cdist (注意显存) + # 如果 N 很大 (>5000),建议使用 faiss 或 scipy.spatial.cKDTree + dist_mat = torch.cdist(vtx_coords.float(), vtx_coords.float()) # [N, N] + + # 取 topk (smallest distance),排除自己 + # values: [N, K], indices: [N, K] + _, indices = torch.topk(dist_mat, k=k_neighbors + 1, dim=1, largest=False) + neighbor_indices = indices[:, 1:] # 去掉第一列(自己) + + # 构建 source, target 索引 + src = torch.arange(num_verts, device=device).unsqueeze(1).repeat(1, k_neighbors).flatten() + dst = neighbor_indices.flatten() + + # 此时得到的边是双向的 (u->v 和 v->u 可能都存在),为了效率可以去重 + # 但为了利用你的 symmetric MLP,保留双向或者只保留 u < v 均可 + # 这里为了简单,我们生成 u < v 的 mask + mask = src < dst + u_indices = src[mask] + v_indices = dst[mask] + + else: + # 策略 B: 全连接 (O(N^2)) - 仅当 N 较小时使用 + u_indices, v_indices = torch.triu_indices(num_verts, num_verts, offset=1, device=device) + + # --- 2. 批量推理 --- + all_probs = [] + num_candidates = u_indices.shape[0] + + with torch.no_grad(): + for i in range(0, num_candidates, batch_size): + end = min(i + batch_size, num_candidates) + batch_u = u_indices[i:end] + batch_v = v_indices[i:end] + + feat_u = vtx_feats[batch_u] + feat_v = vtx_feats[batch_v] + + # Symmetric Forward (和你训练时保持一致) + # A -> B + input_uv = torch.cat([feat_u, feat_v], dim=-1) + logits_uv = connection_head(input_uv) + + # B -> A + input_vu = torch.cat([feat_v, feat_u], dim=-1) + logits_vu = connection_head(input_vu) + + # Sum logits + logits = (logits_uv + logits_vu) + probs = torch.sigmoid(logits) + all_probs.append(probs) + + all_probs = torch.cat(all_probs).squeeze() # [M] + + # --- 3. 筛选连接边 --- + connected_mask = all_probs > threshold + final_u = u_indices[connected_mask].cpu().numpy() + final_v = v_indices[connected_mask].cpu().numpy() + + edges = np.stack([final_u, final_v], axis=1) # [E, 2] + + return edges + +def build_triangles_from_edges(edges, num_verts): + """ + 从边列表构建三角形。 + 寻找图中所有的 3-Cliques (三元环)。 + 这在图论中是一个经典问题,可以使用 networkx 库。 + """ + if len(edges) == 0: + return np.empty((0, 3), dtype=int) + + G = nx.Graph() + G.add_nodes_from(range(num_verts)) + G.add_edges_from(edges) + + # 寻找所有的 3-cliques (三角形) + # enumerate_all_cliques 返回所有大小的 clique,我们需要过滤大小为 3 的 + # 或者使用 nx.triangles ? 不,那个只返回数量 + # 使用 nx.enumerate_all_cliques 效率可能较低,对于稀疏图还可以 + + # 更快的方法:迭代每条边 (u, v),查找 u 和 v 的公共邻居 w + triangles = [] + adj = [set(G.neighbors(n)) for n in range(num_verts)] + + # 为了避免重复 (u, v, w), (v, w, u)... 我们可以强制 u < v < w + # 既然 edges 已经是 u < v (如果我们之前做了 triu),则只需要找 w > v 且 w in adj[u] + + # 优化算法: + for u, v in edges: + if u > v: u, v = v, u # 确保有序 + + # 找公共邻居 + common = adj[u].intersection(adj[v]) + for w in common: + if w > v: # 强制顺序 u < v < w 防止重复 + triangles.append([u, v, w]) + + return np.array(triangles) + +def downsample_voxels( + voxels: torch.Tensor, + input_resolution: int, + output_resolution: int +) -> torch.Tensor: + if input_resolution % output_resolution != 0: + raise ValueError(f"input_resolution ({input_resolution}) must be divisible " + f"by output_resolution ({output_resolution}).") + + factor = input_resolution // output_resolution + + downsampled_voxels = voxels.clone().to(torch.long) + + downsampled_voxels[:, 1:] //= factor + + unique_downsampled_voxels = torch.unique(downsampled_voxels, dim=0) + return unique_downsampled_voxels + +def visualize_colored_points_ply(coords, vectors, filename): + """ + 可视化点云,并用向量方向的颜色来表示,保存为 PLY 文件。 + + Args: + coords (torch.Tensor or np.ndarray): 3D坐标,形状为 (N, 3)。 + vectors (torch.Tensor or np.ndarray): 方向向量,形状为 (N, 3)。 + filename (str): 保存输出文件的名称,必须是 .ply 格式。 + """ + # 确保输入是 numpy 数组 + if isinstance(coords, torch.Tensor): + coords = coords.detach().cpu().numpy() + if isinstance(vectors, torch.Tensor): + vectors = vectors.detach().cpu().to(torch.float32).numpy() + + # 检查输入数据是否为空,防止崩溃 + if coords.size == 0 or vectors.size == 0: + print(f"警告:输入数据为空,未生成 {filename} 文件。") + return + + # 将向量分量从 [-1, 1] 映射到 [0, 255] + # np.clip 用于将数值限制在 -1 和 1 之间,防止颜色溢出 + # (vectors + 1) 将范围从 [-1, 1] 移动到 [0, 2] + # * 127.5 将范围从 [0, 2] 缩放到 [0, 255] + colors = np.clip((vectors + 1) * 127.5, 0, 255).astype(np.uint8) + + # 创建一个点云对象,并传入颜色信息 + # trimesh.PointCloud 能够自动处理带颜色的点 + points = trimesh.points.PointCloud(coords, colors=colors) + # 导出为 PLY 文件 + points.export(filename, file_type='ply') + print(f"可视化文件已成功保存为: {filename}") + + +def compute_vertex_matching(pred_coords, gt_coords, threshold=1.0): + # 转换为整数坐标并去重 + print('len(pred_coords)', len(pred_coords)) + + pred_array = np.unique(pred_coords.detach().to(torch.float32).cpu().numpy(), axis=0) + gt_array = np.unique(gt_coords.detach().cpu().to(torch.float32).numpy(), axis=0) + print('len(pred_array)', len(pred_array)) + pred_total = len(pred_array) + gt_total = len(gt_array) + + # 如果没有点,直接返回 + if pred_total == 0 or gt_total == 0: + return 0, 0.0, pred_total, gt_total + + # 建立 KDTree(以 gt 为基准) + tree = KDTree(gt_array) + + # 查找预测点到最近的 gt 点 + dist, indices = tree.query(pred_array, k=1) + dist = dist.squeeze() + indices = indices.squeeze() + + # 贪心去重:确保 1 对 1 + matches = 0 + used_gt = set() + for d, idx in zip(dist, indices): + if d <= threshold and idx not in used_gt: + matches += 1 + used_gt.add(idx) + + match_rate = matches / max(gt_total, 1) + + return matches, match_rate, pred_total, gt_total + +def flatten_coords_4d(coords_4d: torch.Tensor): + coords_4d_long = coords_4d.long() + + base_x = 512 + base_y = 512 * 512 + base_z = 512 * 512 * 512 + + flat_coords = coords_4d_long[:, 0] * base_z + \ + coords_4d_long[:, 1] * base_y + \ + coords_4d_long[:, 2] * base_x + \ + coords_4d_long[:, 3] + return flat_coords + +class Tester: + def __init__(self, ckpt_path, config_path=None, dataset_path=None): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.ckpt_path = ckpt_path + + self.config = self._load_config(config_path) + self.dataset_path = dataset_path # or self.config['dataset']['path'] + checkpoint = torch.load(self.ckpt_path, map_location='cpu') + self.epoch = checkpoint.get('epoch', 0) + + self._init_models() + self._init_dataset() + + self.result_dir = os.path.join(os.path.dirname(ckpt_path), "evaluation_results") + os.makedirs(self.result_dir, exist_ok=True) + + dataset_name_clean = os.path.basename(self.dataset_path).replace('.npz', '').replace('.npy', '') + self.output_voxel_dir = os.path.join(os.path.dirname(ckpt_path), + f"epoch_{self.epoch}_{dataset_name_clean}_voxels_0_gs") + os.makedirs(self.output_voxel_dir, exist_ok=True) + + self.output_obj_dir = os.path.join(os.path.dirname(ckpt_path), + f"epoch_{self.epoch}_{dataset_name_clean}_obj_0_gs") + os.makedirs(self.output_obj_dir, exist_ok=True) + + def _save_logit_visualization(self, dense_vol, name, sample_name, ply_threshold=0.01): + """ + 保存 Logit 的 3D .npy 文件、2D 最大投影热力图,以及带颜色和透明度的 3D .ply 点云 + + Args: + dense_vol: (H, W, D) numpy array, values in [0, 1] + name: str (e.g., "edge" or "vertex") + sample_name: str + ply_threshold: float, 只有概率大于此值的点才会被保存 + """ + # 1. 保存原始 Dense 数据 (可选) + npy_path = os.path.join(self.output_voxel_dir, f"{sample_name}_{name}_logits.npy") + # np.save(npy_path, dense_vol) + + # 2. 生成 2D 投影热力图 (保持不变) + proj_x = np.max(dense_vol, axis=0) + proj_y = np.max(dense_vol, axis=1) + proj_z = np.max(dense_vol, axis=2) + + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + im0 = axes[0].imshow(proj_x, cmap='turbo', vmin=0, vmax=1, origin='lower') + axes[0].set_title(f"{name} Max-Proj (YZ)") + im1 = axes[1].imshow(proj_y, cmap='turbo', vmin=0, vmax=1, origin='lower') + axes[1].set_title(f"{name} Max-Proj (XZ)") + im2 = axes[2].imshow(proj_z, cmap='turbo', vmin=0, vmax=1, origin='lower') + axes[2].set_title(f"{name} Max-Proj (XY)") + + fig.colorbar(im2, ax=axes, orientation='vertical', fraction=0.02, pad=0.04) + plt.suptitle(f"{sample_name} - {name} Occupancy Probability") + + png_path = os.path.join(self.output_voxel_dir, f"{sample_name}_{name}_heatmap.png") + plt.savefig(png_path, dpi=150) + plt.close(fig) + + # ------------------------------------------------------------------ + # 3. 保存为带颜色和透明度(RGBA)的 PLY 点云 + # ------------------------------------------------------------------ + # 筛选出概率大于阈值的点坐标 + indices = np.argwhere(dense_vol > ply_threshold) + + if len(indices) > 0: + # 获取这些点的概率值 [0, 1] + values = dense_vol[indices[:, 0], indices[:, 1], indices[:, 2]] + + # 使用 matplotlib 的 colormap 进行颜色映射 + import matplotlib.cm as cm + cmap = cm.get_cmap('turbo') + + # map values [0, 1] to RGBA [0, 1] (N, 4) + colors_float = cmap(values) + + # ------------------------------------------------------- + # 【核心修改】:修改 Alpha 通道 (透明度) + # ------------------------------------------------------- + # 让透明度直接等于概率值。 + # 概率 1.0 -> Alpha 1.0 (完全不透明/颜色深) + # 概率 0.1 -> Alpha 0.1 (非常透明/颜色浅) + colors_float[:, 3] = values + + # 转换为 uint8 [0, 255],保留 4 个通道 (R, G, B, A) + colors_uint8 = (colors_float * 255).astype(np.uint8) + + # 坐标转换 + vertices = indices + + ply_filename = f"{sample_name}_{name}_logits_colored.ply" + ply_save_path = os.path.join(self.output_voxel_dir, ply_filename) + + try: + # 使用 Trimesh 保存 (Trimesh 支持 (N, 4) 的 colors) + pcd = trimesh.points.PointCloud(vertices=vertices, colors=colors_uint8) + pcd.export(ply_save_path) + print(f"Saved colored RGBA logit PLY to {ply_save_path}") + except Exception as e: + print(f"Failed to save PLY with trimesh: {e}") + # Fallback: 手动写入 PLY (需要添加 alpha 属性) + with open(ply_save_path, 'w') as f: + f.write("ply\n") + f.write("format ascii 1.0\n") + f.write(f"element vertex {len(vertices)}\n") + f.write("property float x\n") + f.write("property float y\n") + f.write("property float z\n") + f.write("property uchar red\n") + f.write("property uchar green\n") + f.write("property uchar blue\n") + f.write("property uchar alpha\n") # 新增 Alpha 属性 + f.write("end_header\n") + for i in range(len(vertices)): + v = vertices[i] + c = colors_uint8[i] # c is now (R, G, B, A) + f.write(f"{v[0]} {v[1]} {v[2]} {c[0]} {c[1]} {c[2]} {c[3]}\n") + + def _point_line_segment_distance(self, px, py, pz, x1, y1, z1, x2, y2, z2): + """ + 计算点 (px,py,pz) 到线段 (x1,y1,z1)-(x2,y2,z2) 的最短距离的平方。 + 全部输入为 Tensor,支持广播。 + """ + # 线段向量 AB + ABx = x2 - x1 + ABy = y2 - y1 + ABz = z2 - z1 + + # 向量 AP + APx = px - x1 + APy = py - y1 + APz = pz - z1 + + # AB 的长度平方 + AB_sq = ABx**2 + ABy**2 + ABz**2 + + # 避免除以0 (如果两端点重合) + AB_sq = torch.clamp(AB_sq, min=1e-6) + + # 投影系数 t = (AP · AB) / |AB|^2 + t = (APx * ABx + APy * ABy + APz * ABz) / AB_sq + + # 限制 t 在 [0, 1] 之间(线段约束) + t = torch.clamp(t, 0.0, 1.0) + + # 最近点 (Projection) + closestX = x1 + t * ABx + closestY = y1 + t * ABy + closestZ = z1 + t * ABz + + # 距离平方 + dx = px - closestX + dy = py - closestY + dz = pz - closestZ + + return dx**2 + dy**2 + dz**2 + + def _extract_mesh_projection_based( + self, + vtx_result: dict, + edge_result: dict, + resolution: int = 1024, + vtx_prob_threshold: float = 0.5, + + # --- 你的新逻辑参数 --- + search_radius: float = 128.0, # 1. 候选边最大长度 + project_dist_thresh: float = 1.5, # 2. 投影距离阈值 (管子半径,单位:voxel) + dir_align_threshold: float = 0.6, # 3. 方向相似度阈值 (cos theta) + connect_ratio_threshold: float = 0.4, # 4. 最终连接阈值 (匹配点数 / 理论长度) + + edge_prob_threshold: float = 0.1, # 仅仅用于提取"存在的"体素 + ): + t_start = time.perf_counter() + + # --------------------------------------------------------------------- + # 1. 准备全局数据:提取所有"活着"的 Edge Voxels (作为点云处理) + # --------------------------------------------------------------------- + e_probs = torch.sigmoid(edge_result['occ_probs'][:, 0]) + e_coords = edge_result['coords_4d'][:, 1:].float() # (N, 3) + + # 获取方向向量 + if 'predicted_direction_feats' in edge_result: + e_dirs = edge_result['predicted_direction_feats'] # (N, 3) + # 归一化方向 + e_dirs = F.normalize(e_dirs, p=2, dim=1) + else: + print("Warning: No direction features, using dummy.") + e_dirs = torch.zeros_like(e_coords) + + # 筛选有效的 Edge Voxels (Global Point Cloud) + valid_mask = e_probs > edge_prob_threshold + + cloud_coords = e_coords[valid_mask] # (M, 3) + cloud_dirs = e_dirs[valid_mask] # (M, 3) + + num_cloud = cloud_coords.shape[0] + print(f"[Projection] Global active edge voxels: {num_cloud}") + + if num_cloud == 0: + return [], [] + + # --------------------------------------------------------------------- + # 2. 准备顶点和候选边 + # --------------------------------------------------------------------- + v_probs = torch.sigmoid(vtx_result['occ_probs'][:, 0]) + v_coords = vtx_result['coords_4d'][:, 1:].float() + v_mask = v_probs > vtx_prob_threshold + valid_v_coords = v_coords[v_mask] # (V, 3) + + if valid_v_coords.shape[0] < 2: + return valid_v_coords.cpu().numpy() / resolution, [] + + # 生成所有可能的候选边 (基于距离粗筛) + dists = torch.cdist(valid_v_coords, valid_v_coords) + triu_mask = torch.triu(torch.ones_like(dists), diagonal=1).bool() + cand_mask = (dists < search_radius) & triu_mask + cand_indices = torch.nonzero(cand_mask, as_tuple=False) # (E_cand, 2) + + p1s = valid_v_coords[cand_indices[:, 0]] # (E, 3) + p2s = valid_v_coords[cand_indices[:, 1]] # (E, 3) + + num_candidates = p1s.shape[0] + print(f"[Projection] Checking {num_candidates} candidate pairs...") + + # --------------------------------------------------------------------- + # 3. 循环处理候选边 (使用 Bounding Box 快速裁剪) + # --------------------------------------------------------------------- + final_edges = [] + + # 预计算所有候选边的方向和长度 + edge_vecs = p2s - p1s + edge_lengths = torch.norm(edge_vecs, dim=1) + edge_dirs = F.normalize(edge_vecs, p=2, dim=1) + + # 为了避免显存爆炸,也不要在 Python 里做太慢的循环 + # 我们对点云进行操作太慢,对每一条边去遍历整个点云也太慢。 + # 策略: + # 我们循环“边”,但在循环内部利用 mask 快速筛选点云。 + # 由于 Python 循环 10000 次会很慢,我们只处理那些有希望的边。 + # 这里为了演示逻辑的准确性,我们使用简单的循环,但在 GPU 上做计算。 + + # 将全局点云拆分到各个坐标轴,便于快速 BBox 筛选 + cx, cy, cz = cloud_coords[:, 0], cloud_coords[:, 1], cloud_coords[:, 2] + + # 优化:如果候选边太多,可以分块。这里假设边在 5万以内,点在 10万以内,可以处理。 + + # 这一步是瓶颈,我们尝试用 Python 循环,但只对局部点计算 + # 为了加速,我们可以将点云放入 HashGrid 或者只是简单的 BBox Check。 + + # 让我们用简单的逻辑:对于每条边,找出 BBox 内的点,算距离。 + # 这里的 batch_size 是指一次并行处理多少条边 + + batch_size = 128 # 每次处理 128 条边 + + for i in range(0, num_candidates, batch_size): + end = min(i + batch_size, num_candidates) + + # 当前批次的边数据 + b_p1 = p1s[i:end] # (B, 3) + b_p2 = p2s[i:end] # (B, 3) + b_dirs = edge_dirs[i:end] # (B, 3) + b_lens = edge_lengths[i:end] # (B,) + + # --- 步骤 A: 投影 & 距离检查 --- + # 这是一个 (B, M) 的大矩阵计算,容易 OOM。 + # M (点云数) 可能很大。 + # 解决方法:我们反过来思考。 + # 不计算矩阵,我们只对单个边进行循环?太慢。 + + # 实用优化:只对 bounding box 内的点进行距离计算。 + # 由于 GPU 难以动态索引不规则数据,我们还是逐个边循环比较稳妥, + # 但为了 Python 速度,必须尽可能向量化。 + + # 这里我采用一种折中方案:逐个处理边,但是利用 torch.where 快速定位。 + # 实际上,对于 Python 里的 for loop,几千次是可以接受的。 + + current_edges_indices = cand_indices[i:end] + + for j in range(len(b_p1)): + # 单条边处理 + p1 = b_p1[j] + p2 = b_p2[j] + e_dir = b_dirs[j] + e_len = b_lens[j].item() + + # 1. Bounding Box Filter (快速大幅裁剪) + # 找出这条边 BBox 范围内的所有点 (+ padding) + padding = project_dist_thresh + 2.0 + min_xyz = torch.min(p1, p2) - padding + max_xyz = torch.max(p1, p2) + padding + + # 利用 boolean mask 筛选 + mask_x = (cx >= min_xyz[0]) & (cx <= max_xyz[0]) + mask_y = (cy >= min_xyz[1]) & (cy <= max_xyz[1]) + mask_z = (cz >= min_xyz[2]) & (cz <= max_xyz[2]) + bbox_mask = mask_x & mask_y & mask_z + + subset_coords = cloud_coords[bbox_mask] + subset_dirs = cloud_dirs[bbox_mask] + + if subset_coords.shape[0] == 0: + continue + + # 2. 精确距离计算 (Projection Distance) + # 计算 subset 中每个点到线段 p1-p2 的距离平方 + dist_sq = self._point_line_segment_distance( + subset_coords[:, 0], subset_coords[:, 1], subset_coords[:, 2], + p1[0], p1[1], p1[2], + p2[0], p2[1], p2[2] + ) + + # 3. 距离阈值过滤 (Keep voxels inside the tube) + dist_mask = dist_sq < (project_dist_thresh ** 2) + + # 获取在管子内部的体素 + tube_dirs = subset_dirs[dist_mask] + + if tube_dirs.shape[0] == 0: + continue + + # 4. 方向一致性检查 (Direction Check) + # 计算点积 (cos theta) + # e_dir 是 (3,), tube_dirs 是 (K, 3) + dot_prod = torch.matmul(tube_dirs, e_dir) + + # 这里使用 abs,因为边可能是无向的,或者网络预测可能反向 + # 如果你的网络严格预测流向,可以去掉 abs + dir_sim = torch.abs(dot_prod) + + # 统计方向符合要求的体素数量 + valid_voxel_count = (dir_sim > dir_align_threshold).sum().item() + + # 5. 比值判决 (Ratio Check) + # 量化出的 Voxel 数目 ≈ 边的长度 (e_len) + # 如果 e_len 很小(比如<1),我们设为1防止除以0 + theoretical_count = max(e_len, 1.0) + + ratio = valid_voxel_count / theoretical_count + + if ratio > connect_ratio_threshold: + # 找到了! + global_idx = i + j + edge_tuple = cand_indices[global_idx].cpu().numpy().tolist() + final_edges.append(edge_tuple) + + t_end = time.perf_counter() + print(f"[Projection] Logic finished. Accepted {len(final_edges)} edges. Time={t_end - t_start:.4f}s") + + out_vertices = valid_v_coords.cpu().numpy() / resolution + return out_vertices, final_edges + + def _save_voxel_ply(self, coords: torch.Tensor, labels: torch.Tensor, filename: str): + if coords.numel() == 0: + return + + coords_np = coords.cpu().to(torch.float32).numpy() + labels_np = labels.cpu().to(torch.float32).numpy() + + colors = np.zeros((coords_np.shape[0], 3), dtype=np.uint8) + colors[labels_np == 0] = [255, 0, 0] + colors[labels_np == 1] = [0, 0, 255] + + try: + import trimesh + point_cloud = trimesh.PointCloud(vertices=coords_np, colors=colors) + ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply") + point_cloud.export(ply_path) + except ImportError: + ply_path = os.path.join(self.output_voxel_dir, f"{filename}.ply") + with open(ply_path, 'w') as f: + f.write("ply\n") + f.write("format ascii 1.0\n") + f.write(f"element vertex {coords_np.shape[0]}\n") + f.write("property float x\n") + f.write("property float y\n") + f.write("property float z\n") + f.write("property uchar red\n") + f.write("property uchar green\n") + f.write("property uchar blue\n") + f.write("end_header\n") + for i in range(coords_np.shape[0]): + f.write(f"{coords_np[i,0]} {coords_np[i,1]} {coords_np[i,2]} {colors[i,0]} {colors[i,1]} {colors[i,2]}\n") + + def _load_config(self, config_path=None): + if config_path and os.path.exists(config_path): + with open(config_path) as f: + return yaml.safe_load(f) + + ckpt_dir = os.path.dirname(self.ckpt_path) + possible_configs = [ + os.path.join(ckpt_dir, "config.yaml"), + os.path.join(os.path.dirname(ckpt_dir), "config.yaml") + ] + + for config_file in possible_configs: + if os.path.exists(config_file): + with open(config_file) as f: + print(f"Loaded config from: {config_file}") + return yaml.safe_load(f) + + checkpoint = torch.load(self.ckpt_path, map_location='cpu') + if 'config' in checkpoint: + print("Loaded config from checkpoint") + return checkpoint['config'] + + raise FileNotFoundError("Could not find config_edge.yaml in checkpoint directory or parent, and config not saved in checkpoint.") + + def _init_models(self): + self.voxel_encoder = VoxelFeatureEncoder_active_pointnet( + in_channels=15, + hidden_dim=256, + out_channels=1024, + scatter_type='mean', + n_blocks=5, + resolution=128, + + ).to(self.device) + + self.connection_head = ConnectionHead( + channels=128 * 2, + out_channels=1, + mlp_ratio=4, + ).to(self.device) + + self.vae = VoxelVAE( # abalation: VoxelVAE_1volume_dilation + in_channels=self.config['model']['in_channels'], + latent_dim=self.config['model']['latent_dim'], + encoder_blocks=self.config['model']['encoder_blocks'], + # decoder_blocks=self.config['model']['decoder_blocks'], + decoder_blocks_vtx=self.config['model']['decoder_blocks_vtx'], + decoder_blocks_edge=self.config['model']['decoder_blocks_edge'], + num_heads=8, + num_head_channels=64, + mlp_ratio=4.0, + attn_mode="swin", + window_size=8, + pe_mode="ape", + use_fp16=False, + use_checkpoint=False, + qk_rms_norm=False, + using_subdivide=True, + using_attn=self.config['model']['using_attn'], + attn_first=self.config['model'].get('attn_first', True), + pred_direction=self.config['model'].get('pred_direction', False), + ).to(self.device) + + load_pretrained_woself( + checkpoint_path=self.ckpt_path, + voxel_encoder=self.voxel_encoder, + connection_head=self.connection_head, + vae=self.vae, + ) + # --- 【新增】在这里添加权重检查逻辑 --- + print(f"--- 正在检查权重文件中的 NaN/Inf 值... ---") + has_nan_inf = False + if self._check_weights_for_nan_inf(self.vae, "VoxelVAE"): + has_nan_inf = True + + if self._check_weights_for_nan_inf(self.voxel_encoder, "Vertex Encoder"): + has_nan_inf = True + + if self._check_weights_for_nan_inf(self.connection_head, "Connection Head"): + has_nan_inf = True + + if not has_nan_inf: + print("--- 权重检查通过。未发现 NaN/Inf 值。 ---") + else: + # 如果发现坏值,直接抛出异常,因为评估无法继续 + raise ValueError(f"在检查点 '{self.ckpt_path}' 中发现了 NaN 或 Inf 值。请检查导致训练不稳定的权重文件。") + # --- 检查逻辑结束 --- + + self.vae.eval() + self.voxel_encoder.eval() + self.connection_head.eval() + + def _init_dataset(self): + self.dataset = VoxelVertexDataset_edge( + root_dir=self.dataset_path, + base_resolution=self.config['dataset']['base_resolution'], + min_resolution=self.config['dataset']['min_resolution'], + cache_dir='/home/tiger/yy/src/dataset_cache/test_15c_dora', + # cache_dir=self.config['dataset']['cache_dir'], + renders_dir=self.config['dataset']['renders_dir'], + + # filter_active_voxels=self.config['dataset']['filter_active_voxels'], + filter_active_voxels=False, + cache_filter_path=self.config['dataset']['cache_filter_path'], + + sample_type=self.config['dataset']['sample_type'], + active_voxel_res=128, + pc_sample_number=819200, + + ) + + self.dataloader = DataLoader( + self.dataset, + batch_size=1, + shuffle=False, + collate_fn=partial(collate_fn_pointnet), + num_workers=0, + pin_memory=True, + # prefetch_factor=4, + ) + + def _check_weights_for_nan_inf(self, model: torch.nn.Module, model_name: str) -> bool: + """ + 检查模型的所有参数中是否存在 NaN 或 Inf 值。 + + Args: + model (torch.nn.Module): 要检查的模型。 + model_name (str): 模型的名称,用于打印日志。 + + Returns: + bool: 如果找到 NaN 或 Inf,则返回 True,否则返回 False。 + """ + found_issue = False + for name, param in model.named_parameters(): + if torch.isnan(param.data).any(): + print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 NaN 值!") + found_issue = True + if torch.isinf(param.data).any(): + print(f"[!!!] 严重错误: 在模型 '{model_name}' 的参数 '{name}' 中发现 Inf 值!") + found_issue = True + return found_issue + + + def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0): + """ + 修改后的函数,确保一对一匹配,并优先匹配最近的点对。 + """ + pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0) + gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0) + + pred_total = len(pred_array) + gt_total = len(gt_array) + + if pred_total == 0 or gt_total == 0: + return { + 'recall': 0.0, + 'precision': 0.0, + 'f1': 0.0, + 'matches': 0, + 'pred_count': pred_total, + 'gt_count': gt_total + } + + # 依然在预测点上构建KD-Tree,为每个真实点查找最近的预测点 + tree = cKDTree(pred_array) + dists, pred_idxs = tree.query(gt_array, k=1) + + # --- 核心修改部分 --- + + # 1. 创建一个列表,包含 (距离, 真实点索引, 预测点索引) + # 这样我们就可以按距离对所有可能的匹配进行排序 + possible_matches = [] + for gt_idx, (dist, pred_idx) in enumerate(zip(dists, pred_idxs)): + if dist <= threshold: + possible_matches.append((dist, gt_idx, pred_idx)) + + # 2. 按距离从小到大排序(贪心策略) + possible_matches.sort(key=lambda x: x[0]) + + matches = 0 + # 使用集合来跟踪已经使用过的预测点和真实点,确保一对一匹配 + used_pred_indices = set() + used_gt_indices = set() # 虽然当前逻辑下gt不会重复,但加上更严谨 + + # 3. 遍历排序后的可能匹配,进行一对一分配 + for dist, gt_idx, pred_idx in possible_matches: + # 如果这个预测点和这个真实点都还没有被使用过 + if pred_idx not in used_pred_indices and gt_idx not in used_gt_indices: + matches += 1 + used_pred_indices.add(pred_idx) + used_gt_indices.add(gt_idx) + + # --- 修改结束 --- + + # matches 现在是真正的 True Positives 数量,它绝不会超过 pred_total 或 gt_total + recall = matches / gt_total if gt_total > 0 else 0.0 + precision = matches / pred_total if pred_total > 0 else 0.0 + + # 计算F1时,使用标准的 Precision 和 Recall 定义 + if (precision + recall) == 0: + f1 = 0.0 + else: + f1 = 2 * (precision * recall) / (precision + recall) + + return { + 'recall': recall, + 'precision': precision, + 'f1': f1, + 'matches': matches, + 'pred_count': pred_total, + 'gt_count': gt_total + } + + def _compute_vertex_metrics(self, pred_coords, gt_coords, threshold=1.0): + """ + 一个折衷的顶点指标计算方案。 + 它沿用“为每个真实点寻找最近预测点”的逻辑, + 但通过修正计算方式,确保Precision和F1值不会超过1.0。 + """ + # 假设 pred_coords 和 gt_coords 是 PyTorch 张量 + pred_array = np.unique(pred_coords.round().int().cpu().numpy(), axis=0) + gt_array = np.unique(gt_coords.round().int().cpu().numpy(), axis=0) + + pred_total = len(pred_array) + gt_total = len(gt_array) + + if pred_total == 0 or gt_total == 0: + return { + 'recall': 0.0, + 'precision': 0.0, + 'f1': 0.0, + 'matches': 0, + 'pred_count': pred_total, + 'gt_count': gt_total + } + + # 在预测点上构建KD-Tree,为每个真实点查找最近的预测点 + tree = cKDTree(pred_array) + dists, _ = tree.query(gt_array, k=1) # 我们在这里其实不需要 pred 的索引 + + # 1. 计算从 gt 角度出发的匹配数 (True Positives for Recall) + # 这和您的第一个函数完全一样。 + # 这个值代表了“有多少个真实点被成功找到了”。 + matches_from_gt = np.sum(dists <= threshold) + + # 2. 计算 Recall (召回率) + # 召回率的分母是真实点的总数,所以这里的计算是合理的。 + recall = matches_from_gt / gt_total if gt_total > 0 else 0.0 + + # 3. 计算 Precision (精确率) - ✅ 这是核心修正点 + # 精确率的分母是预测点的总数。 + # 分子(True Positives)不能超过预测点的总数。 + # 因此,我们取 matches_from_gt 和 pred_total 中的较小值。 + # 这解决了 Precision > 1 的问题。 + tp_for_precision = min(matches_from_gt, pred_total) + precision = tp_for_precision / pred_total if pred_total > 0 else 0.0 + + # 4. 使用标准的F1分数公式 + # 您原来的 F1 公式 `2 * matches / (pred + gt)` 是 L1-Score, + # 更常用的是基于 Precision 和 Recall 的调和平均数。 + if (precision + recall) == 0: + f1 = 0.0 + else: + f1 = 2 * (precision * recall) / (precision + recall) + + return { + 'recall': recall, + 'precision': precision, + 'f1': f1, + 'matches': matches_from_gt, # 仍然报告原始的匹配数,便于观察 + 'pred_count': pred_total, + 'gt_count': gt_total + } + + def _compute_chamfer_distance(self, p1: torch.Tensor, p2: torch.Tensor, one_sided: bool = False): + if len(p1) == 0 or len(p2) == 0: + return float('nan') + + dist_p1_p2 = torch.min(torch.cdist(p1, p2), dim=1)[0].mean() + + if one_sided: + return dist_p1_p2.item() + else: + dist_p2_p1 = torch.min(torch.cdist(p2, p1), dim=1)[0].mean() + return (dist_p1_p2 + dist_p2_p1).item() / 2 + + def visualize_latent_space_pca(self, sample_idx: int): + """ + Encodes a sample, performs PCA on its latent features, and saves a + colored PLY file for visualization. + + The position of each point in the PLY file corresponds to the spatial + location in the latent grid. + + The color of each point represents the first three principal components + of its feature vector. + """ + print(f"--- Starting Latent Space PCA Visualization for Sample {sample_idx} ---") + self.vae.eval() + + try: + # 1. Get the latent representation for the sample + latent = self._get_latent_for_sample(sample_idx) + except ValueError as e: + print(f"Error: {e}") + return + + latent_coords = latent.coords.detach().cpu().numpy() + latent_feats = latent.feats.detach().cpu().numpy() + + if latent_feats.shape[0] < 3: + print(f"Warning: Not enough latent points ({latent_feats.shape[0]}) to perform PCA. Skipping.") + return + + print(f"--> Performing PCA on {latent_feats.shape[0]} latent vectors of dimension {latent_feats.shape[1]}...") + + # 2. Perform PCA to reduce feature dimensions to 3 + pca = PCA(n_components=3) + pca_features = pca.fit_transform(latent_feats) + + print(f" Explained variance ratio by 3 components: {pca.explained_variance_ratio_}") + print(f" Total explained variance: {np.sum(pca.explained_variance_ratio_):.4f}") + + # 3. Normalize the PCA components to be used as RGB colors [0, 255] + # We normalize each component independently to maximize color contrast + normalized_colors = np.zeros_like(pca_features) + for i in range(3): + min_val = pca_features[:, i].min() + max_val = pca_features[:, i].max() + if max_val - min_val > 1e-6: + normalized_colors[:, i] = (pca_features[:, i] - min_val) / (max_val - min_val) + else: + normalized_colors[:, i] = 0.5 # Handle case of constant value + + colors_uint8 = (normalized_colors * 255).astype(np.uint8) + + # 4. Prepare spatial coordinates for the point cloud + # latent_coords is (batch_idx, x, y, z), we want the xyz part + spatial_coords = latent_coords[:, 1:] + + # 5. Create and save the colored PLY file + try: + # Create a Trimesh PointCloud object + point_cloud = trimesh.points.PointCloud(vertices=spatial_coords, colors=colors_uint8) + + # Define the output filename + filename = f"sample_{sample_idx}_latent_pca.ply" + ply_path = os.path.join(self.output_voxel_dir, filename) + + # Export the file + point_cloud.export(ply_path) + print(f"--> Successfully saved PCA visualization to: {ply_path}") + + except Exception as e: + print(f"Error during Trimesh export: {e}") + print("Please ensure 'trimesh' is installed correctly.") + + def _get_latent_for_sample(self, sample_idx: int) -> SparseTensor: + """ + Encodes a single sample and returns its latent representation. + """ + print(f"--> Encoding sample {sample_idx} to get its latent vector...") + # Get data for the specified sample + batch_data = self.dataset[sample_idx] + if batch_data is None: + raise ValueError(f"Sample at index {sample_idx} could not be loaded.") + + # Use the collate function to form a batch + batch_data = collate_fn_pointnet([batch_data]) + + with torch.no_grad(): + active_coords = batch_data['active_voxels_128'].to(self.device) + point_cloud = batch_data['point_cloud_128'].to(self.device) + + active_voxel_feats = self.voxel_encoder( + p=point_cloud, + sparse_coords=active_coords, + res=128, + bbox_size=(-0.5, 0.5), + ) + + sparse_input = SparseTensor( + feats=active_voxel_feats, + coords=active_coords.int() + ) + + # 2. Encode to get the latent representation + latent_128, posterior = self.vae.encode(sparse_input, sample_posterior=True,) + print(f" Latent for sample {sample_idx} obtained. Shape: {latent_128.feats.shape}") + return latent_128 + + + + def evaluate(self, num_samples=None, visualize=False, chamfer_threshold=0.9, threshold=1.): + total_samples = len(self.dataset) + eval_samples = min(num_samples or total_samples, total_samples) + sample_indices = random.sample(range(total_samples), eval_samples) if num_samples else range(total_samples) + # sample_indices = range(eval_samples) + + eval_dataset = Subset(self.dataset, sample_indices) + eval_loader = DataLoader( + eval_dataset, + batch_size=1, + shuffle=False, + collate_fn=partial(collate_fn_pointnet), + num_workers=self.config['training']['num_workers'], + pin_memory=True, + ) + + per_sample_metrics = { + 'vertex': {res: [] for res in [128, 256, 512]}, + 'edge': {res: [] for res in [128, 256, 512]}, + 'sample_names': [] + } + avg_metrics = { + 'vertex': {res: defaultdict(list) for res in [128, 256, 512]}, + 'edge': {res: defaultdict(list) for res in [128, 256, 512]}, + } + + self.vae.eval() + + for batch_idx, batch_data in enumerate(tqdm(eval_loader, desc="Evaluating")): + if batch_data is None: + continue + sample_idx = sample_indices[batch_idx] + sample_name = f'sample_{sample_idx}' + per_sample_metrics['sample_names'].append(sample_name) + + batch_save_path = f"/home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope/143500_sample_active_vis_42seed_trellis/gt_data_batch_{batch_idx}.pt" + if not os.path.exists(batch_save_path): + print(f"Warning: Saved batch file not found: {batch_save_path}") + continue + batch_data = torch.load(batch_save_path, map_location=self.device) + + with torch.no_grad(): + # 1. Get input data + combined_voxels_512 = batch_data['combined_voxels_512'].to(self.device) + combined_voxel_labels_512 = batch_data['combined_voxel_labels_512'].to(self.device) + gt_combined_endpoints_512 = batch_data['gt_combined_endpoints_512'].to(self.device) + gt_combined_errors_512 = batch_data['gt_combined_errors_512'].to(self.device) + + edge_mask = (combined_voxel_labels_512 == 1) + + gt_edge_endpoints_512 = gt_combined_endpoints_512[edge_mask].to(self.device) + + gt_edge_voxels_512 = combined_voxels_512[edge_mask].to(self.device) + + p1 = gt_edge_endpoints_512[:, 1:4].float() + p2 = gt_edge_endpoints_512[:, 4:7].float() + + mask = ( (p1[:,0] < p2[:,0]) | + ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) | + ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) ) + + pA = torch.where(mask[:, None], p1, p2) # smaller one + pB = torch.where(mask[:, None], p2, p1) # larger one + + d = pB - pA + dir_gt = F.normalize(d, dim=-1, eps=1e-6) + + gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'].to(self.device).int() + + vtx_128 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=128) + vtx_256 = downsample_voxels(gt_vertex_voxels_512, input_resolution=512, output_resolution=256) + + edge_128 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=128) + edge_256 = downsample_voxels(combined_voxels_512, input_resolution=512, output_resolution=256) + edge_512 = combined_voxels_512 + + + gt_edge_voxels_list = [ + edge_128, + edge_256, + edge_512, + ] + + active_coords = batch_data['active_voxels_128'].to(self.device) + point_cloud = batch_data['point_cloud_128'].to(self.device) + + + active_voxel_feats = self.voxel_encoder( + p=point_cloud, + sparse_coords=active_coords, + res=128, + bbox_size=(-0.5, 0.5), + ) + + sparse_input = SparseTensor( + feats=active_voxel_feats, + coords=active_coords.int() + ) + + latent_128, posterior = self.vae.encode(sparse_input) + + + load_path = f'/home/tiger/yy/src/checkpoints/output_slat_flow_matching_active/40w_128to512_head_rope/143500_sample_active_vis_42seed_trellis/sample_latent_{batch_idx}.pt' + latent_128 = torch.load(load_path, map_location=self.device) + + print('latent_128.feats.mean()', latent_128.feats.mean(), 'latent_128.feats.std()', latent_128.feats.std()) + print('posterior.mean', posterior.mean.mean(), 'posterior.std', posterior.std.mean(), 'posterior.var', posterior.var.mean()) + print('latent_128.coords.shape', latent_128.coords.shape) + + + latent_128 = SparseTensor( + coords=latent_128.coords, + feats=latent_128.feats + 0. * torch.randn_like(latent_128.feats), + ) + + self.output_voxel_dir = os.path.dirname(load_path) + self.output_obj_dir = os.path.dirname(load_path) + + # 7. Decoding with separate vertex and edge processing + decoded_results = self.vae.decode( + latent_128, + gt_vertex_voxels_list=[], + gt_edge_voxels_list=[], + training=False, + + inference_threshold=0.5, + vis_last_layer=False, + ) + + error = 0 # decoded_results[-1]['edge']['predicted_offset_feats'] + + if self.config['model'].get('pred_direction', False): + pred_dir = decoded_results[-1]['edge']['predicted_direction_feats'] + zero_mask = (pred_dir == 0).all(dim=1) # [N],True 表示这一行全为0 + num_zeros = zero_mask.sum().item() + print("Number of zero vectors:", num_zeros) + + pred_edge_coords_3d = decoded_results[-1]['edge']['coords'] + print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape) + print('pred_dir.shape', pred_dir.shape) + if pred_edge_coords_3d.shape[-1] == 4: + pred_edge_coords_3d = pred_edge_coords_3d[:, 1:] + + save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction.ply") + visualize_colored_points_ply(pred_edge_coords_3d, pred_dir, save_pth) + + save_pth = os.path.join(self.output_voxel_dir, f"{sample_name}_direction_gt.ply") + visualize_colored_points_ply((gt_edge_voxels_512[:, 1:]), dir_gt, save_pth) + + + pred_vtx_coords_3d = decoded_results[-1]['vertex']['coords'] + pred_edge_coords_3d = decoded_results[-1]['edge']['coords'] + + + gt_vertex_voxels_512 = batch_data['gt_vertex_voxels_512'][:, 1:].to(self.device) + gt_edge_voxels_512 = batch_data['gt_edge_voxels_512'][:, 1:].to(self.device) + + + # Calculate metrics and save results + matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_vtx_coords_3d, gt_vertex_voxels_512, threshold=threshold,) + print(f"\n----- Resolution {512} vtx -----") + print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}") + print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}") + + self._save_voxel_ply(pred_vtx_coords_3d / 512., torch.zeros(len(pred_vtx_coords_3d)), f"{sample_name}_pred_vtx") + self._save_voxel_ply((pred_edge_coords_3d) / 512, torch.zeros(len(pred_edge_coords_3d)), f"{sample_name}_pred_edge") + + self._save_voxel_ply(gt_vertex_voxels_512 / 512, torch.zeros(len(gt_vertex_voxels_512)), f"{sample_name}_gt_vertex") + self._save_voxel_ply((combined_voxels_512[:, 1:]) / 512., torch.zeros(len(gt_combined_errors_512)), f"{sample_name}_gt_edge") + + + # Calculate vertex-specific metrics + matches, match_rate, pred_total, gt_total = compute_vertex_matching(pred_edge_coords_3d, combined_voxels_512[:, 1:], threshold=threshold,) + print(f"\n----- Resolution {512} edge -----") + print('pred_edge_coords_3d.shape', pred_edge_coords_3d.shape) + print('gt_edge_voxels_512.shape', gt_edge_voxels_512.shape) + + print(f"Pred Vertices: {pred_total} | GT Vertices: {gt_total}") + print(f"Matched Vertices: {matches} | Match Rate: {match_rate:.2%}") + + pred_vertex_coords_np = np.round(pred_vtx_coords_3d.cpu().numpy()).astype(int) + pred_edges = [] + gt_vertex_coords_np = np.round(gt_vertex_voxels_512.cpu().numpy()).astype(int) + if visualize: + if pred_vtx_coords_3d.shape[-1] == 4: + pred_vtx_coords_float = pred_vtx_coords_3d[:, 1:].float() + else: + pred_vtx_coords_float = pred_vtx_coords_3d.float() + + pred_vtx_feats = decoded_results[-1]['vertex']['feats'] + + # ========================================== + # Link Prediction & Mesh Generation + # ========================================== + print("Predicting connectivity...") + + # 1. 预测边 + # 注意:K_neighbors 的设置。如果是物体,64 足够了。 + # 如果点非常稀疏,可能需要更大。 + pred_edges = predict_mesh_connectivity( + connection_head=self.connection_head, # 或者是 self.connection_head,取决于你在哪里定义的 + vtx_feats=pred_vtx_feats, + vtx_coords=pred_vtx_coords_float, + batch_size=4096, + threshold=0.5, + k_neighbors=None, + device=self.device + ) + print(f"Predicted {len(pred_edges)} edges.") + + # 2. 构建三角形 + num_verts = pred_vtx_coords_float.shape[0] + pred_faces = build_triangles_from_edges(pred_edges, num_verts) + print(f"Constructed {len(pred_faces)} triangles.") + + # 3. 保存 OBJ + import trimesh + + # 坐标归一化/还原 (根据你的需求,这里假设你是 0-512 的体素坐标) + # 如果想保存为归一化坐标: + mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0 + + # 如果有 error offset,记得加上! + # 你之前的代码好像没有对 vertex 加 offset,只对 edge 加了 + # 如果 vertex 也有 offset (如 dual contouring),在这里加上 + + # 移动到中心 (可选) + mesh_verts = mesh_verts - 0.5 + + mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces) + + trimesh.repair.fix_normals(mesh) + + output_obj_path = os.path.join(self.output_voxel_dir, f"{sample_name}_recon.obj") + mesh.export(output_obj_path) + print(f"Saved mesh to {output_obj_path}") + + # 保存边线 (用于 Debug) + # 有时候三角形很难形成,只看边也很有用 + edges_path = os.path.join(self.output_voxel_dir, f"{sample_name}_edges.ply") + # self._visualize_vertices(pred_edge_coords_np, gt_edge_coords_np, f"{sample_name}_edge_comparison") + + + # Process results at different resolutions + for i, res in enumerate([128, 256, 512]): + if i >= len(decoded_results): + continue + + gt_key = f'gt_vertex_voxels_{res}' + if gt_key not in batch_data: + continue + if i == 0: + pred_coords_res = decoded_results[i]['vtx_sp'].coords[:, 1:].float() + gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device) + else: + pred_coords_res = decoded_results[i]['vertex']['coords'].float() + gt_coords_res = batch_data[gt_key][:, 1:].float().to(self.device) + + + v_metrics = self._compute_vertex_metrics(pred_coords_res, gt_coords_res, threshold=threshold) + + per_sample_metrics['vertex'][res].append({ + 'recall': v_metrics['recall'], + 'precision': v_metrics['precision'], + 'f1': v_metrics['f1'], + 'num_pred': len(pred_coords_res), + 'num_gt': len(gt_coords_res) + }) + + avg_metrics['vertex'][res]['recall'].append(v_metrics['recall']) + avg_metrics['vertex'][res]['precision'].append(v_metrics['precision']) + avg_metrics['vertex'][res]['f1'].append(v_metrics['f1']) + + gt_edge_key = f'gt_edge_voxels_{res}' + if gt_edge_key not in batch_data: + continue + + if i == 0: + pred_edge_coords_res = decoded_results[i]['edge_sp'].coords[:, 1:].float() + # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device) + idx = i + gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) + elif i == 1: + idx = i + ################################# + # pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() - error / 2. + 0.5 + # # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device) + # gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) - gt_combined_errors_512[:, 1:].to(self.device) + 0.5 + + + pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() + gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) + + # self._save_voxel_ply(gt_edge_voxels_list[idx][:, 1:].float().to(self.device) / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res_wooffset") + # self._save_voxel_ply(decoded_results[i]['edge']['coords'].float() / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res_wooffset") + + else: + idx = i + pred_edge_coords_res = decoded_results[i]['edge']['coords'].float() + # gt_edge_coords_res = batch_data[gt_edge_key][:, 1:].float().to(self.device) + gt_edge_coords_res = gt_edge_voxels_list[idx][:, 1:].float().to(self.device) + + # self._save_voxel_ply(gt_edge_coords_res / (128*2**i), torch.zeros(len(gt_edge_coords_res)), f"{sample_name}_gt_edge_{128*2**i}res") + # self._save_voxel_ply(pred_edge_coords_res / (128*2**i), torch.zeros(len(pred_edge_coords_res)), f"{sample_name}_pred_edge_{128*2**i}res") + + e_metrics = self._compute_vertex_metrics(pred_edge_coords_res, gt_edge_coords_res, threshold=threshold) + + per_sample_metrics['edge'][res].append({ + 'recall': e_metrics['recall'], + 'precision': e_metrics['precision'], + 'f1': e_metrics['f1'], + 'num_pred': len(pred_edge_coords_res), + 'num_gt': len(gt_edge_coords_res) + }) + + avg_metrics['edge'][res]['recall'].append(e_metrics['recall']) + avg_metrics['edge'][res]['precision'].append(e_metrics['precision']) + avg_metrics['edge'][res]['f1'].append(e_metrics['f1']) + + avg_metrics_processed = {} + for category, res_dict in avg_metrics.items(): + avg_metrics_processed[category] = {} + for resolution, metric_dict in res_dict.items(): + avg_metrics_processed[category][resolution] = { + metric_name: np.mean(values) if values else float('nan') + for metric_name, values in metric_dict.items() + } + + result_data = { + 'config': self.config, + 'checkpoint': self.ckpt_path, + 'dataset': self.dataset_path, + 'num_samples': eval_samples, + 'per_sample_metrics': per_sample_metrics, + 'avg_metrics': avg_metrics_processed + } + + results_file_path = os.path.join(self.result_dir, f"evaluation_results_epoch{self.epoch}.yaml") + with open(results_file_path, 'w') as f: + yaml.dump(result_data, f, default_flow_style=False) + + return result_data + + def _generate_line_voxels( + self, + p1: torch.Tensor, + p2: torch.Tensor + ) -> Tuple[ + List[Tuple[int, int, int]], + List[Tuple[torch.Tensor, torch.Tensor]], + List[np.ndarray] + ]: + """ + Improved version using better sampling strategy + """ + p1_np = p1 #.cpu().numpy() + p2_np = p2 #.cpu().numpy() + voxel_dict = OrderedDict() + + # Use proper 3D line voxelization algorithm + def bresenham_3d(p1, p2): + """3D Bresenham's line algorithm""" + x1, y1, z1 = np.round(p1).astype(int) + x2, y2, z2 = np.round(p2).astype(int) + + points = [] + dx = abs(x2 - x1) + dy = abs(y2 - y1) + dz = abs(z2 - z1) + + xs = 1 if x2 > x1 else -1 + ys = 1 if y2 > y1 else -1 + zs = 1 if z2 > z1 else -1 + + # Driving axis is X + if dx >= dy and dx >= dz: + err_1 = 2 * dy - dx + err_2 = 2 * dz - dx + for i in range(dx + 1): + points.append((x1, y1, z1)) + if err_1 > 0: + y1 += ys + err_1 -= 2 * dx + if err_2 > 0: + z1 += zs + err_2 -= 2 * dx + err_1 += 2 * dy + err_2 += 2 * dz + x1 += xs + + # Driving axis is Y + elif dy >= dx and dy >= dz: + err_1 = 2 * dx - dy + err_2 = 2 * dz - dy + for i in range(dy + 1): + points.append((x1, y1, z1)) + if err_1 > 0: + x1 += xs + err_1 -= 2 * dy + if err_2 > 0: + z1 += zs + err_2 -= 2 * dy + err_1 += 2 * dx + err_2 += 2 * dz + y1 += ys + + # Driving axis is Z + else: + err_1 = 2 * dx - dz + err_2 = 2 * dy - dz + for i in range(dz + 1): + points.append((x1, y1, z1)) + if err_1 > 0: + x1 += xs + err_1 -= 2 * dz + if err_2 > 0: + y1 += ys + err_2 -= 2 * dz + err_1 += 2 * dx + err_2 += 2 * dy + z1 += zs + + return points + + # Get all voxels using Bresenham algorithm + voxel_coords = bresenham_3d(p1_np, p2_np) + + # Add all voxels to dictionary + for coord in voxel_coords: + voxel_dict[tuple(coord)] = (p1, p2) + + voxel_coords = list(voxel_dict.keys()) + endpoint_pairs = list(voxel_dict.values()) + + # --- compute error vectors --- + error_vectors = [] + diff = p2_np - p1_np + d_norm_sq = np.dot(diff, diff) + + for v in voxel_coords: + v_center = np.array(v, dtype=float) + 0.5 + if d_norm_sq == 0: # degenerate line + closest = p1_np + else: + t = np.dot(v_center - p1_np, diff) / d_norm_sq + t = np.clip(t, 0.0, 1.0) + closest = p1_np + t * diff + error_vectors.append(v_center - closest) + + return voxel_coords, endpoint_pairs, error_vectors + + +# 使用示例 +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir): + set_seed(42) + tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path) + result_data = tester.evaluate(num_samples=NUM_SAMPLES, visualize=VISUALIZE, chamfer_threshold=CHAMFER_EDGE_THRESHOLD, threshold=THRESHOLD) + + # 生成文件名 + epoch_str = os.path.basename(ckpt_path).split('_')[1].split('.')[0] + dataset_name = os.path.basename(os.path.normpath(dataset_path)) + + # 保存简版报告(TXT) + summary_path = os.path.join(eval_dir, f"epoch{epoch_str}_{dataset_name}_summary_threshold{THRESHOLD}_one2one.txt") + with open(summary_path, 'w') as f: + # 头部信息 + f.write(f"Checkpoint: {os.path.basename(ckpt_path)}\n") + f.write(f"Dataset: {dataset_name}\n") + f.write(f"Evaluation Samples: {result_data['num_samples']}\n\n") + + # 平均指标 + f.write("=== Average Metrics ===\n") + for category, data in result_data['avg_metrics'].items(): + if isinstance(data, dict): # 处理多分辨率情况 + f.write(f"\n{category.upper()}:\n") + for res, metrics in data.items(): + f.write(f" Resolution {res}:\n") + for k, v in metrics.items(): + # 确保值是数字类型后再格式化 + if isinstance(v, (int, float)): + f.write(f" {str(k).ljust(15)}: {v:.4f}\n") + else: + f.write(f" {str(k).ljust(15)}: {str(v)}\n") + else: # 处理非多分辨率情况 + f.write(f"\n{category.upper()}:\n") + for k, v in data.items(): + if isinstance(v, (int, float)): + f.write(f" {str(k).ljust(15)}: {v:.4f}\n") + else: + f.write(f" {str(k).ljust(15)}: {str(v)}\n") + + # 样本级详细统计 + f.write("\n\n=== Detailed Per-Sample Metrics ===\n") + for name, vertex_metrics, edge_metrics in zip( + result_data['per_sample_metrics']['sample_names'], + zip(*[result_data['per_sample_metrics']['vertex'][res] for res in [128, 256, 512]]), + zip(*[result_data['per_sample_metrics']['edge'][res] for res in [128, 256, 512]]) + ): + # 样本标题 + f.write(f"\n◆ Sample: {name}\n") + + # 顶点指标 + f.write(f"[Vertex Prediction]\n") + f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n") + for res, metrics in zip([128, 256, 512], vertex_metrics): + f.write(f" {str(res).ljust(10)} " + f"{metrics['recall']:.4f} " + f"{metrics['precision']:.4f} " + f"{metrics['f1']:.4f} " + f"{metrics['num_pred']}/{metrics['num_gt']}\n") + + # Edge指标 + f.write(f"[Edge Prediction]\n") + f.write(f" {'Resolution'.ljust(10)} {'Recall'.ljust(8)} {'Precision'.ljust(8)} {'F1'.ljust(8)} {'Pred/Gt'.ljust(10)}\n") + for res, metrics in zip([128, 256, 512], edge_metrics): + f.write(f" {str(res).ljust(10)} " + f"{metrics['recall']:.4f} " + f"{metrics['precision']:.4f} " + f"{metrics['f1']:.4f} " + f"{metrics['num_pred']}/{metrics['num_gt']}\n") + + f.write("-"*60 + "\n") + + print(f"Saved summary to: {summary_path}") + return result_data + + +if __name__ == '__main__': + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + evaluate_all_checkpoints = False # 设置为 True 启用范围过滤 + EPOCH_START = 1 + EPOCH_END = 12 + CHAMFER_EDGE_THRESHOLD=0.5 + NUM_SAMPLES=50 + VISUALIZE=True + THRESHOLD=1.5 + VISUAL_FIELD=False + + ckpt_path = '/home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch1_batch12000_loss0.1140.pt' + dataset_path = '/home/tiger/yy/src/trellis_clean_mesh/mesh_data' + + if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000': + RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond' + else: + RENDERS_DIR = '' + + + ckpt_dir = os.path.dirname(ckpt_path) + eval_dir = os.path.join(ckpt_dir, "evaluate") + os.makedirs(eval_dir, exist_ok=True) + + if False: + for i in range(NUM_SAMPLES): + print("--- Starting Latent Space PCA Visualization ---") + tester = Tester(ckpt_path=ckpt_path, dataset_path=dataset_path) + tester.visualize_latent_space_pca(sample_idx=i) + print("--- PCA Visualization Finished ---") + + if not evaluate_all_checkpoints: + evaluate_checkpoint(ckpt_path, dataset_path, eval_dir) + else: + pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.endswith('.pt')]) + + filtered_pt_files = [] + for f in pt_files: + try: + parts = f.split('_') + epoch_str = parts[1].replace('epoch', '') + epoch = int(epoch_str) + if EPOCH_START <= epoch <= EPOCH_END: + filtered_pt_files.append(f) + except Exception as e: + print(f"Warning: Could not parse epoch from {f}: {e}") + continue + + for pt_file in filtered_pt_files: + full_ckpt_path = os.path.join(ckpt_dir, pt_file) + evaluate_checkpoint(full_ckpt_path, dataset_path, eval_dir) \ No newline at end of file diff --git a/test_slat_vae_128to512_pointnet_vae_head.py b/test_slat_vae_128to512_pointnet_vae_head.py index 81fcb9ae87aa7f7f609e28ddbde79b2189ac47a5..49d064638660059b1f108c59aa06a9b1d074d739 100644 --- a/test_slat_vae_128to512_pointnet_vae_head.py +++ b/test_slat_vae_128to512_pointnet_vae_head.py @@ -744,7 +744,7 @@ class Tester: root_dir=self.dataset_path, base_resolution=self.config['dataset']['base_resolution'], min_resolution=self.config['dataset']['min_resolution'], - cache_dir='/gemini/user/private/zhaotianhao/dataset_cache/test_15c_dora', + cache_dir='/home/tiger/yy/src/dataset_cache', # cache_dir=self.config['dataset']['cache_dir'], renders_dir=self.config['dataset']['renders_dir'], @@ -1262,15 +1262,12 @@ class Tester: # 如果想保存为归一化坐标: mesh_verts = pred_vtx_coords_float.cpu().numpy() / 512.0 - # 如果有 error offset,记得加上! - # 你之前的代码好像没有对 vertex 加 offset,只对 edge 加了 - # 如果 vertex 也有 offset (如 dual contouring),在这里加上 - # 移动到中心 (可选) mesh_verts = mesh_verts - 0.5 mesh = trimesh.Trimesh(vertices=mesh_verts, faces=pred_faces) + trimesh.repair.fix_normals(mesh) # 过滤孤立点 (可选) # mesh.remove_unreferenced_vertices() @@ -1581,22 +1578,23 @@ def evaluate_checkpoint(ckpt_path, dataset_path, eval_dir): if __name__ == '__main__': with torch.cuda.amp.autocast(dtype=torch.bfloat16): evaluate_all_checkpoints = True # 设置为 True 启用范围过滤 - EPOCH_START = 0 + EPOCH_START = 1 EPOCH_END = 12 CHAMFER_EDGE_THRESHOLD=0.5 - NUM_SAMPLES=20 + NUM_SAMPLES=50 VISUALIZE=True THRESHOLD=1.5 VISUAL_FIELD=False - ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch10433_loss1.2657.pt' - ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch2000_loss0.3315.pt' - - dataset_path = '/gemini/user/private/zhaotianhao/data/MERGED_DATASET_count_200_2000_100000/test' - dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized' - # dataset_path = '/gemini/user/private/zhaotianhao/data/trellis500k_compress_glb' - dataset_path = '/gemini/user/private/zhaotianhao/data/unique_files_glb_under6000face_2degree_30ratio_0.01' + # ckpt_path = '/gemini/user/private/zhaotianhao/checkpoints/vae/train_9w_200_2000face/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small/checkpoint_epoch0_batch10433_loss1.2657.pt' + ckpt_path = '/home/tiger/yy/src/checkpoints/vae/unique_files_glb_under6000face_2degree_30ratio_0.01/shapenet_bs2_128to512_wolabel_dir_sorted_dora_small_lowlr/checkpoint_epoch0_batch6000_loss0.1150.pt' + dataset_path = '/home/tiger/yy/src/unique_files_glb_under6000face_2degree_30ratio_0.01' + # dataset_path = '/gemini/user/private/zhaotianhao/data/why_filter_unquantized' + # # dataset_path = '/gemini/user/private/zhaotianhao/data/trellis500k_compress_glb' + # dataset_path = '/gemini/user/private/zhaotianhao/data/unique_files_glb_under6000face_2degree_30ratio_0.01' + dataset_path = '/home/tiger/yy/src/trellis_clean_mesh/mesh_data' + if dataset_path == '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh/objaverse_200_2000': RENDERS_DIR = '/HOME/paratera_xy/pxy1054/HDD_POOL/Trisf/data/mesh_render_img/objaverse_200_2000/renders_cond' else: diff --git a/train_slat_flow_128to512_pointnet_head.py b/train_slat_flow_128to512_pointnet_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1d47668fc519019b2121cd091a36cce70d5e107b --- /dev/null +++ b/train_slat_flow_128to512_pointnet_head.py @@ -0,0 +1,507 @@ +import os +# os.environ['ATTN_BACKEND'] = 'xformers' # xformers is generally compatible with DDP +# os.environ["OMP_NUM_THREADS"] = "1" +# os.environ["MKL_NUM_THREADS"] = "1" +import torch +import numpy as np +import yaml +from torch.utils.data import DataLoader, DistributedSampler +from functools import partial +import torch.nn.functional as F +from torch.optim import AdamW +from torch.amp import GradScaler, autocast +from typing import * +from transformers import CLIPTextModel, AutoTokenizer, CLIPTextConfig +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +# --- Updated Imports based on VAE script --- +from dataset_triposf_head import VoxelVertexDataset_edge, collate_fn_pointnet +from triposf.models.triposf_vae.VoxelFeatureVAE_edge_woself_128to1024_decoder_head import VoxelVAE +from vertex_encoder import VoxelFeatureEncoder_active_pointnet, ConnectionHead +from triposf.modules.sparse.basic import SparseTensor + +from trellis.models.structured_latent_flow import SLatFlowModel +from trellis.trainers.flow_matching.sparse_flow_matching_alone import SparseFlowMatchingTrainer +from safetensors.torch import load_file +import torch.multiprocessing as mp +from PIL import Image +import torch.nn as nn + +from triposf.modules.utils import DiagonalGaussianDistribution +import torchvision.transforms as transforms +import re +import contextlib + +# --- Distributed Setup Functions --- +def setup_distributed(backend="nccl"): + """Initializes the distributed environment.""" + if not dist.is_initialized(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + + torch.cuda.set_device(local_rank) + dist.init_process_group(backend=backend) + + return int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]) + +def cleanup_distributed(): + dist.destroy_process_group() + +# --- Modified Trainer Class --- +class SLatFlowMatchingTrainer(SparseFlowMatchingTrainer): + def __init__(self, *args, rank: int, local_rank: int, world_size: int, **kwargs): + super().__init__(*args, **kwargs) + self.cfg = kwargs.pop('cfg', None) + if self.cfg is None: + raise ValueError("Configuration dictionary 'cfg' must be provided.") + + # --- Distributed-related attributes --- + self.rank = rank + self.local_rank = local_rank + self.world_size = world_size + self.device = torch.device(f"cuda:{self.local_rank}") + self.is_master = (self.rank == 0) + self.gradient_accumulation_steps = 8 + + self.i_save = self.cfg['training']['save_every'] + self.save_dir = self.cfg['training']['output_dir'] + + self.resolution = 128 + self.condition_type = 'image' + self.is_cond = False + self.img_res = 518 + + if self.is_master: + os.makedirs(self.save_dir, exist_ok=True) + print(f"Checkpoints and logs will be saved to: {self.save_dir}") + + # Initialize components and set up for DDP + self._init_components( + clip_model_path=self.cfg['training'].get('clip_model_path', None), + dinov2_model_path=self.cfg['training'].get('dinov2_model_path', None), + vae_path=self.cfg['training']['vae_path'], + ) + + self._setup_ddp() + + self.denoiser_checkpoint_path = self.cfg['training'].get('denoiser_checkpoint_path', None) + + trainable_params = list(self.denoiser.parameters()) + self.optimizer = AdamW(trainable_params, lr=self.cfg['training'].get('lr', 0.0001), weight_decay=0.0) + + self.scaler = GradScaler() + + if self.is_master: + print("Using Automatic Mixed Precision (AMP) with GradScaler.") + + def _init_components(self, + clip_model_path=None, + dinov2_model_path=None, + vae_path=None, + ): + """ + Initializes VAE, VoxelEncoder (PointNet), and condition models. + """ + # Use the Dataset from the VAE script + self.dataset = VoxelVertexDataset_edge( + root_dir=self.cfg['dataset']['path'], + base_resolution=self.cfg['dataset']['base_resolution'], + min_resolution=self.cfg['dataset']['min_resolution'], + cache_dir=self.cfg['dataset']['cache_dir'], + renders_dir=self.cfg['dataset']['renders_dir'], + + filter_active_voxels=self.cfg['dataset']['filter_active_voxels'], + cache_filter_path=self.cfg['dataset']['cache_filter_path'], + + active_voxel_res=128, + pc_sample_number=819200, + + sample_type=self.cfg['dataset']['sample_type'], + + ) + + self.sampler = DistributedSampler( + self.dataset, + num_replicas=self.world_size, + rank=self.rank, + shuffle=True + ) + + # Use collate_fn_pointnet + self.dataloader = DataLoader( + self.dataset, + batch_size=self.cfg['training']['batch_size'], + shuffle=False, + collate_fn=partial(collate_fn_pointnet,), + num_workers=self.cfg['training']['num_workers'], + pin_memory=True, + sampler=self.sampler, + prefetch_factor=4, + persistent_workers=True, + drop_last=True, + ) + + def load_file_func(path, device='cpu'): + return torch.load(path, map_location=device) + + def _load_and_broadcast(model, load_fn=None, path=None, strict=True): + if self.is_master: + try: + state = load_fn(path) if load_fn else model.state_dict() + except Exception as e: + raise RuntimeError(f"Failed to load weights from {path}: {e}") + else: + state = None + + dist.barrier() + state_b = [state] if self.is_master else [None] + dist.broadcast_object_list(state_b, src=0) + + try: + # Handle potential key mismatches (e.g. 'module.' prefix) + model.load_state_dict(state_b[0], strict=strict) + except Exception as e: + if self.is_master: print(f"Strict loading failed for {model.__class__.__name__}, trying non-strict: {e}") + model.load_state_dict(state_b[0], strict=False) + + # ------------------------- Voxel Encoder (PointNet) ------------------------- + # Matching the VAE script configuration + self.voxel_encoder = VoxelFeatureEncoder_active_pointnet( + in_channels=15, + hidden_dim=256, + out_channels=1024, + scatter_type='mean', + n_blocks=5, + resolution=128, + add_label=False, + ).to(self.device) + + # ------------------------- VAE ------------------------- + self.vae = VoxelVAE( + in_channels=self.cfg['model']['in_channels'], + latent_dim=self.cfg['model']['latent_dim'], + encoder_blocks=self.cfg['model']['encoder_blocks'], + decoder_blocks_vtx=self.cfg['model']['decoder_blocks_vtx'], + decoder_blocks_edge=self.cfg['model']['decoder_blocks_edge'], + num_heads=8, + num_head_channels=64, + mlp_ratio=4.0, + attn_mode="swin", + window_size=8, + pe_mode="ape", + use_fp16=False, + use_checkpoint=True, + qk_rms_norm=False, + using_subdivide=True, + using_attn=self.cfg['model']['using_attn'], + attn_first=self.cfg['model'].get('attn_first', True), + pred_direction=self.cfg['model'].get('pred_direction', False), + ).to(self.device) + + + # ------------------------- Conditioning ------------------------- + if self.condition_type == 'text': + self.tokenizer = AutoTokenizer.from_pretrained(clip_model_path) + if self.is_master: + self.condition_model = CLIPTextModel.from_pretrained(clip_model_path) + else: + config = CLIPTextConfig.from_pretrained(clip_model_path) + self.condition_model = CLIPTextModel(config) + _load_and_broadcast(self.condition_model) + + elif self.condition_type == 'image': + if self.is_master: + print("Initializing for IMAGE conditioning (DINOv2).") + + # Update paths as per your environment + local_repo_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2-main" + weights_path = "/home/tiger/yy/src/gemini/user/private/zhaotianhao/dinov2_resources/dinov2_vitl14_reg4_pretrain.pth" + + dinov2_model = torch.hub.load( + repo_or_dir=local_repo_path, + model='dinov2_vitl14_reg', + source='local', + pretrained=False + ) + self.condition_model = dinov2_model + + _load_and_broadcast(self.condition_model, load_fn=torch.load, path=weights_path) + + self.image_cond_model_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + else: + raise ValueError(f"Unsupported condition type: {self.condition_type}") + + self.condition_model.to(self.device).eval() + for p in self.condition_model.parameters(): p.requires_grad = False + + # ------------------------- Load VAE/Encoder Weights ------------------------- + # Load weights corresponding to the logic in VAE script's `load_pretrained_woself` + # Assuming checkpoint contains 'vae' and 'voxel_encoder' keys + _load_and_broadcast(self.vae, + load_fn=lambda p: load_file_func(p)['vae'], + path=vae_path) + + _load_and_broadcast(self.voxel_encoder, + load_fn=lambda p: load_file_func(p)['voxel_encoder'], + path=vae_path) + + self.vae.eval() + self.voxel_encoder.eval() + for p in self.vae.parameters(): p.requires_grad = False + for p in self.voxel_encoder.parameters(): p.requires_grad = False + + def _load_denoiser(self): + """Loads a checkpoint for the denoiser.""" + path = self.denoiser_checkpoint_path + if not path or not os.path.isfile(path): + if self.is_master: + print("No valid checkpoint path provided for denoiser. Starting from scratch.") + return + + if self.is_master: + print(f"Loading denoiser checkpoint from: {path}") + checkpoint = torch.load(path, map_location=self.device) + else: + checkpoint = None + + dist.barrier() + dist_list = [checkpoint] if self.is_master else [None] + dist.broadcast_object_list(dist_list, src=0) + checkpoint = dist_list[0] + + try: + self.denoiser.module.load_state_dict(checkpoint['denoiser']) + if self.is_master: print("Denoiser weights loaded successfully.") + except Exception as e: + if self.is_master: print(f"[ERROR] Failed to load denoiser state_dict: {e}") + + if 'step' in checkpoint and self.is_master: + print(f"Checkpoint from step {checkpoint['step']}.") + + dist.barrier() + + def _setup_ddp(self): + """Sets up DDP and DataLoaders.""" + self.denoiser = self.denoiser.to(self.device) + self.denoiser = DDP(self.denoiser, device_ids=[self.local_rank]) + + for param in self.denoiser.parameters(): + param.requires_grad = True + + @torch.no_grad() + def encode_image(self, images) -> torch.Tensor: + if isinstance(images, torch.Tensor): + batch_tensor = images.to(self.device) + elif isinstance(images, list): + assert all(isinstance(i, Image.Image) for i in images), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in images] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + batch_tensor = torch.stack(image).to(self.device) + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + if batch_tensor.shape[-2:] != (518, 518): + batch_tensor = F.interpolate(batch_tensor, (518, 518), mode='bicubic', align_corners=False) + + features = self.condition_model(batch_tensor, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + def process_batch(self, batch): + preprocessed_images = batch['image'] + cond_ = self.encode_image(preprocessed_images) + return cond_ + + def train(self, num_epochs=1000): + # 1. 无条件生成的准备工作 (和之前一样) + if self.is_cond == False: + if self.condition_type == 'text': + txt = [''] + encoding = self.tokenizer(txt, max_length=77, padding='max_length', truncation=True, return_tensors='pt') + tokens = encoding['input_ids'].to(self.device) + with torch.no_grad(): + cond_ = self.condition_model(input_ids=tokens).last_hidden_state + else: + blank_img = Image.fromarray(np.zeros((self.img_res, self.img_res, 3), dtype=np.uint8)) + with torch.no_grad(): + dummy_cond = self.encode_image([blank_img]) + cond_ = torch.zeros_like(dummy_cond) + if self.is_master: print(f"Generated unconditional image prompt with shape: {cond_.shape}") + + self._load_denoiser() + self.denoiser.train() + + # 获取全局步数 + global_step = 0 + if self.denoiser_checkpoint_path: + match = re.search(r'step(\d+)', self.denoiser_checkpoint_path) + if match: + global_step = int(match.group(1)) + + accum_steps = self.gradient_accumulation_steps + if self.is_master: + print(f"Training with Gradient Accumulation Steps: {accum_steps}") + + # 确保循环开始前梯度清零 + self.optimizer.zero_grad(set_to_none=True) + + for epoch in range(num_epochs): + self.dataloader.sampler.set_epoch(epoch) + epoch_losses = [] + + # 遍历数据 + for i, batch in enumerate(self.dataloader): + + # --- A. 数据准备 --- + if self.is_cond and self.condition_type == 'image': + cond_ = self.process_batch(batch) + + point_cloud = batch['point_cloud_128'].to(self.device) + active_coords = batch['active_voxels_128'].to(self.device) + + batch_size = int(active_coords[:, 0].max().item() + 1) + if cond_.shape[0] != batch_size: + cond_ = cond_.expand(batch_size, -1, -1).contiguous().to(self.device) + else: + cond_ = cond_.to(self.device) + + # --- B. 前向传播 & Loss计算 --- + with autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.no_grad(): + active_voxel_feats = self.voxel_encoder( + p=point_cloud, + sparse_coords=active_coords, + res=128, + bbox_size=(-0.5, 0.5), + ) + sparse_input = SparseTensor( + feats=active_voxel_feats, + coords=active_coords.int() + ) + latent_128, posterior = self.vae.encode(sparse_input) + + terms, _ = self.training_losses(x_0=latent_128, cond=cond_) + loss = terms['loss'] + + # [重点] Loss 除以累积步数 + loss = loss / accum_steps + + # --- C. 反向传播 --- + # 注意:这里没有 no_sync,每次都会同步梯度 + self.scaler.scale(loss).backward() + + # 记录还原后的 Loss 用于显示 + current_real_loss = loss.item() * accum_steps + epoch_losses.append(current_real_loss) + + # --- D. 梯度累积判断与更新 --- + if (i + 1) % accum_steps == 0: + # 如果需要 clip_grad_norm,可以在这里加: + # self.scaler.unscale_(self.optimizer) + # torch.nn.utils.clip_grad_norm_(self.denoiser.parameters(), max_norm=1.0) + + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad(set_to_none=True) + + global_step += 1 + + # --- Logging (只在更新步进行) --- + if self.is_master: + if global_step % 10 == 0: + print(f"Epoch {epoch+1} Step {global_step}: " + f"Batch_Loss = {current_real_loss:.4f}, " + f"Epoch_Mean = {np.mean(epoch_losses):.4f}") + + if global_step % self.i_save == 0: + checkpoint = { + 'denoiser': self.denoiser.module.state_dict(), + 'step': global_step + } + loss_str = f"{np.mean(epoch_losses):.6f}".replace('.', '_') + save_path = os.path.join(self.save_dir, f"checkpoint_step{global_step}_loss{loss_str}.pt") + torch.save(checkpoint, save_path) + print(f"Saved checkpoint to {save_path}") + + # --- E. 处理 Epoch 结束时的残留 Batch (Leftovers) --- + # 如果 dataloader 长度不能被 accum_steps 整除,且 drop_last=False, + # 这里需要把最后累积的一点梯度更新掉。 + # (如果你的 dataloader 设置了 drop_last=True 且总数够整除,这里不会触发,但写上比较保险) + if (i + 1) % accum_steps != 0: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad(set_to_none=True) + # 注意:这里通常不增加 global_step 或者看你习惯, + # 因为这是一个“不完整”的 step,通常梯度也是不对等的(因为除数还是accum_steps) + # 所以很多实现为了稳定,直接设置 drop_last=True 避开这种情况。 + + if self.is_master: + avg_loss = np.mean(epoch_losses) if epoch_losses else 0 + log_path = os.path.join(self.save_dir, "loss_log.txt") + with open(log_path, "a") as f: + f.write(f"Epoch {epoch+1}, Step {global_step}, AvgLoss {avg_loss:.6f}\n") + + dist.barrier() + + if self.is_master: + print("Training complete.") + +def main(): + # if mp.get_start_method(allow_none=True) != 'spawn': + # mp.set_start_method('spawn', force=True) + + # if mp.get_start_method(allow_none=True) != 'forkserver': + # mp.set_start_method('forkserver', force=True) + + rank, local_rank, world_size = setup_distributed() + torch.manual_seed(42+rank) + np.random.seed(42+rank) + + # Path to your config + config_path = "/home/tiger/yy/src/Michelangelo-master/config_slat_flow_128to512_pointnet_head.yaml" + with open(config_path) as f: + cfg = yaml.safe_load(f) + + # Initialize Flow Model (on CPU first) + diffusion_model = SLatFlowModel( + resolution=cfg['flow']['resolution'], + in_channels=cfg['flow']['in_channels'], + out_channels=cfg['flow']['out_channels'], + model_channels=cfg['flow']['model_channels'], + cond_channels=cfg['flow']['cond_channels'], + num_blocks=cfg['flow']['num_blocks'], + num_heads=cfg['flow']['num_heads'], + mlp_ratio=cfg['flow']['mlp_ratio'], + patch_size=cfg['flow']['patch_size'], + num_io_res_blocks=cfg['flow']['num_io_res_blocks'], + io_block_channels=cfg['flow']['io_block_channels'], + pe_mode=cfg['flow']['pe_mode'], + qk_rms_norm=cfg['flow']['qk_rms_norm'], + qk_rms_norm_cross=cfg['flow']['qk_rms_norm_cross'], + use_fp16=cfg['flow'].get('use_fp16', False), + ) + + torch.manual_seed(42 + rank) + np.random.seed(42 + rank) + + trainer = SLatFlowMatchingTrainer( + denoiser=diffusion_model, + t_schedule=cfg['t_schedule'], + sigma_min=cfg['sigma_min'], + cfg=cfg, + rank=rank, + local_rank=local_rank, + world_size=world_size, + ) + + trainer.train() + cleanup_distributed() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/trellis/__init__.py b/trellis/__init__.py index 20d240afc9c26a21aee76954628b3d4ef9a1ccbd..7c8c28a48dc1a11c026cabe6dc0d824578c3fc02 100644 --- a/trellis/__init__.py +++ b/trellis/__init__.py @@ -2,5 +2,5 @@ from . import models from . import modules from . import pipelines from . import renderers -from . import representations +# from . import representations from . import utils diff --git a/trellis/__pycache__/__init__.cpython-310.pyc b/trellis/__pycache__/__init__.cpython-310.pyc index 2671fb157268089c31bda0acf9e95c7d54bc122d..9c9e0798374fb69c601c43d9f898dfac3751203e 100644 Binary files a/trellis/__pycache__/__init__.cpython-310.pyc and b/trellis/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/__init__.cpython-310.pyc b/trellis/models/__pycache__/__init__.cpython-310.pyc index e8fd5b707000f2a793d67faa54531d2747ffd954..eeda81a75d9cbf96ab0aebfc189d77a61a5bdf85 100644 Binary files a/trellis/models/__pycache__/__init__.cpython-310.pyc and b/trellis/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc b/trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc index 63086d24b2e2755e9f1ea1d5a214fedddfae02df..f98e26f2feea4a754c75e5130f4e7e301768f535 100644 Binary files a/trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc and b/trellis/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc b/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc index f3b382ae5af1fb031931d93eb8c88f68acd75be1..b7a957e0958ac90bece7248afa45a05961e30781 100644 Binary files a/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc and b/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc b/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc index b1e860963ebcb85584600cbf19d61040fd5ea089..8d346cb3b1020be6ab8d6df16250bb7394356b32 100644 Binary files a/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc and b/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc differ diff --git a/trellis/modules/__pycache__/norm.cpython-310.pyc b/trellis/modules/__pycache__/norm.cpython-310.pyc index 5b47d7fff68ace70b9618304a5d99a0088761dd1..77de6822cc643c4e5ac785972746239534b88149 100644 Binary files a/trellis/modules/__pycache__/norm.cpython-310.pyc and b/trellis/modules/__pycache__/norm.cpython-310.pyc differ diff --git a/trellis/modules/__pycache__/spatial.cpython-310.pyc b/trellis/modules/__pycache__/spatial.cpython-310.pyc index 64061a60666b71583f50625f0a0223d35c18c56b..4f769f87d5d97d908a64211003d253d4c205c74e 100644 Binary files a/trellis/modules/__pycache__/spatial.cpython-310.pyc and b/trellis/modules/__pycache__/spatial.cpython-310.pyc differ diff --git a/trellis/modules/__pycache__/utils.cpython-310.pyc b/trellis/modules/__pycache__/utils.cpython-310.pyc index c9f930f40c16550686e91e1ca1e7194075728dbe..f1cd0d37da7850e0e03c842bb7c6eaf09896eca6 100644 Binary files a/trellis/modules/__pycache__/utils.cpython-310.pyc and b/trellis/modules/__pycache__/utils.cpython-310.pyc differ diff --git a/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc b/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc index c906a2c524225f5fbd00d90b92ab10d9906a6ff5..1f1ad4dca3fe6a3d9f2370f17f5645c1473a1fb8 100644 Binary files a/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc b/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc index 39d504ca477c066dec0bf90ef0575e23d5523c9e..9537737ebe7a7f7fe594651491a82226193393fc 100644 Binary files a/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc and b/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/trellis/modules/attention/__pycache__/modules.cpython-310.pyc b/trellis/modules/attention/__pycache__/modules.cpython-310.pyc index c0f3419d8aad79e85b25bd0c56f040ccdf9d0162..06820124ccb44a763b18843328e323e447e21f21 100644 Binary files a/trellis/modules/attention/__pycache__/modules.cpython-310.pyc and b/trellis/modules/attention/__pycache__/modules.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc index 46b7f00e69263afefc5face633a3302cce6050cb..0bf6e85ddb3d8a634f2829b87c9b8be99991251e 100644 Binary files a/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc b/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc index 1b996db0b31acb3fe18716ea40f924efa19852e2..92b8bc8e200d273c28b20b277360f35241a72a58 100644 Binary files a/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc b/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc index 9bbf2021a312e3f8056916be079f8672bb7bdb3c..d82513d5962db223b0e2cbb620bc1de735cbe08b 100644 Binary files a/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc b/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc index 536b6a2136fdebd531ff5dcfdcbcbcea13c28c56..6920228ff25022044e0801af80c99bd38748652b 100644 Binary files a/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc b/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc index 0e00ea6c66abc443d03ac1560dac02900f2043a2..02d999b6a598ba54c953c701f2c4e77d59c5f544 100644 Binary files a/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc index d394f9673da968979cf4fde18f6e0c16298af64c..70dbd96819c3b4c66d5764489cc8569467833d6f 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc index 560ca19039f056d44a666fa86f4391e345942c33..3a3c01cd10058a99be282ad57a5874fc0656f3a9 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc index ebe406dd3ce08adb8545c73f7e96e90f20848d3d..8f5dfcc203847719114f1d0a04cf17aef22aacfc 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc index b0724dcb15eed440cc826801d0fe7f07d8627062..af33065118d72caa9fb8d3c0b134d8be66bd5b9f 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc index deae8fc1c4984c90b3afa365261b3ecba5a6cd25..f5ee03b10b9c0a6b03c7670ccfec6378bdf6b8bc 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc differ diff --git a/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc index 9d45541e9c4bf00936b3caffab75ca2c1fe71230..ba7ec79b9fcf1275e875dc1fbd507c1ff4166a08 100644 Binary files a/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc b/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc index 4b5f6211b1fc26ffda91f3892e7c9f5da4543def..0004fa121143301d936ace6cb6782fec5c80dac1 100644 Binary files a/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc and b/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc index e5b1ecb38a0f1116b510f17b914e96f700f726d0..547993470d9262d10e3ca644d53ee71722f16bd3 100644 Binary files a/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc b/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc index c4bb4743c889348fc5749c637bb614809ae1d568..6e8aa403ef0546ba529c8a39d249f8ae7decd27c 100644 Binary files a/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc b/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc index a5617ea473b6e6c67ea9f9d5bbc46318e5e0c8f6..2ae33271927ba63382d47105d84357dd1925e71f 100644 Binary files a/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc b/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc index 8cce1133a3e1de4f056605983db02ff80da54039..05f42d78118131febf77a3e2455f6a78f877ef79 100644 Binary files a/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc b/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc index fa9729c2ab37f47234aeb334cad337f60980be70..d19c48d9314f638f5fa73dbec4f4fdd834c9f374 100644 Binary files a/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc b/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc index af241087a9c1e85f8ca812e94191137708880c52..2e33d9d57c5e2fc1373eec6543580b38d9f0065f 100644 Binary files a/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/trellis/pipelines/__pycache__/__init__.cpython-310.pyc b/trellis/pipelines/__pycache__/__init__.cpython-310.pyc index ad4fa90621c6c716aff8b0e5c68b0f39151575c1..aae3130d0608d059f461d358e3f12bca4392393f 100644 Binary files a/trellis/pipelines/__pycache__/__init__.cpython-310.pyc and b/trellis/pipelines/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/pipelines/__pycache__/base.cpython-310.pyc b/trellis/pipelines/__pycache__/base.cpython-310.pyc index 062ad7672db6948a42a89600035aa02e778d70d0..e0e21881b0552a90eb9c8bd4a2021e6226453309 100644 Binary files a/trellis/pipelines/__pycache__/base.cpython-310.pyc and b/trellis/pipelines/__pycache__/base.cpython-310.pyc differ diff --git a/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc b/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc index 93ed0c8daf40ebee551a01b19436e87d89d37976..812a7c5107d4dfab0236dcd554671e55e1817a4e 100644 Binary files a/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc and b/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc differ diff --git a/trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc b/trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc index 8b52fa32e26b1140c83e4500b70da25f544c4758..acc1e6c0f1b596bab786c06d1aedb61e308d79c7 100644 Binary files a/trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc and b/trellis/pipelines/__pycache__/trellis_text_to_3d.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc index ef7c213b0e296e5f3457802f7dab7ec9b9f4ba73..46c0679d38d4c2ffe79b621b097016d02afda2fe 100644 Binary files a/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc index 2cd98da6b762e18e4730e32c051d8b886ce6c065..c6649b85f086328fe9baa0037face532215b4a89 100644 Binary files a/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc index 8f6e13ee1c1cd552f276156ad572689f8d9c0b8c..21d17ee6e6ef2aca0033c939e1119437815c85d8 100644 Binary files a/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc index cab31d5bf449fc4132002705f9e31abcdf91c241..971efbcb19ec9c0b8721ef1fe8cb6f8d5f09bf54 100644 Binary files a/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc index 5947e6360f785f5532413b355a6fb742d2747952..9f0f4dd6ebca41ac568c340032a273ffe2e765fc 100644 Binary files a/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc differ diff --git a/trellis/renderers/__pycache__/__init__.cpython-310.pyc b/trellis/renderers/__pycache__/__init__.cpython-310.pyc index 7a04cb915f46cac14136acec869d207cacb82e57..37761b027557291c8e8f05284c50fa0e4a36777d 100644 Binary files a/trellis/renderers/__pycache__/__init__.cpython-310.pyc and b/trellis/renderers/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/__pycache__/__init__.cpython-310.pyc b/trellis/representations/__pycache__/__init__.cpython-310.pyc index e9e4d808e3439a58c148136a4552c7a53636293c..99a19f8ac53a5470c5065d6c1b25f45dfbe5baec 100644 Binary files a/trellis/representations/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc b/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc index 013556b055b13f4bc56da627fb5f132e16714ca4..3218bdc93b4f16df3738b2afa9756cd371bae059 100644 Binary files a/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc b/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc index 6bf40021068cdaa8ea189419ec8d6cd1c177aa87..84c0ebce10134c3b8e88984875b80ef5857c7ebb 100644 Binary files a/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc b/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc index 38045088c1f860e8461a17659385ae44306e7c1d..e1bc8ba97aaf11abd5206e8661bcd0544ea6f618 100644 Binary files a/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc b/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc index 0b5176792ead613bf1ba9b62bc5aeaaf8fc27162..22b2bc7191b8e0ddb2bf83325dafa89ebb13b552 100644 Binary files a/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc b/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc index 20e9ca5d471326939524cc16acbad7181f4e95c5..0b883c4fefed819d165a8a847cd581fc5aad991d 100644 Binary files a/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc and b/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc b/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc index ea6744b48c343907db223a18a5e5dce5ff081e1d..8d25bcaa38671d6c7e0160544649318e9a9fa487 100644 Binary files a/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc and b/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc differ diff --git a/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc b/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc index 431707e551705ad6345d8b83059f38c8d1f4d80d..978622888a1576c4beb906eb259e126b1f620f0f 100644 Binary files a/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc b/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc index 1eb7e8ca93a86043f63fdb8d1aa639d73fa81886..ad2b5337581acaefa6db811a9354148a92798aaa 100644 Binary files a/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc and b/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc differ diff --git a/trellis/representations/radiance_field/__pycache__/__init__.cpython-310.pyc b/trellis/representations/radiance_field/__pycache__/__init__.cpython-310.pyc index 906a7796483c7c93e54e569dd406dfe6b505a257..78592df48ddea1efdc8644e63cbd84e86774140b 100644 Binary files a/trellis/representations/radiance_field/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/radiance_field/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/radiance_field/__pycache__/strivec.cpython-310.pyc b/trellis/representations/radiance_field/__pycache__/strivec.cpython-310.pyc index f652d69a75d3147c082ca670583115bae80b36dc..f5b2f138ee0a186ab8625df6f04ba6860e6af706 100644 Binary files a/trellis/representations/radiance_field/__pycache__/strivec.cpython-310.pyc and b/trellis/representations/radiance_field/__pycache__/strivec.cpython-310.pyc differ diff --git a/trellis/trainers/__pycache__/__init__.cpython-310.pyc b/trellis/trainers/__pycache__/__init__.cpython-310.pyc index 4ec9b1e628c502e339e441a31e0abd2dcc728eff..3d66e8738c790734fc267669acd0fc7e0400afbf 100644 Binary files a/trellis/trainers/__pycache__/__init__.cpython-310.pyc and b/trellis/trainers/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/trainers/__pycache__/base.cpython-310.pyc b/trellis/trainers/__pycache__/base.cpython-310.pyc index 51961b766bdc8e4cbb64e98c492afca522c81180..a13e271047338dd2c0c2ba96812666b84a11867b 100644 Binary files a/trellis/trainers/__pycache__/base.cpython-310.pyc and b/trellis/trainers/__pycache__/base.cpython-310.pyc differ diff --git a/trellis/trainers/__pycache__/basic.cpython-310.pyc b/trellis/trainers/__pycache__/basic.cpython-310.pyc index 929c822fafddb73d30b2309b542f083fd8bc6ce9..f861ce1d5fbde582fa2cdec7666da796fc01335b 100644 Binary files a/trellis/trainers/__pycache__/basic.cpython-310.pyc and b/trellis/trainers/__pycache__/basic.cpython-310.pyc differ diff --git a/trellis/trainers/__pycache__/utils.cpython-310.pyc b/trellis/trainers/__pycache__/utils.cpython-310.pyc index 7b6d776d5c822cf3590a1a4682a788cdae73a06f..4cad99003efaf97a01f1c62c16cefdbe3371774f 100644 Binary files a/trellis/trainers/__pycache__/utils.cpython-310.pyc and b/trellis/trainers/__pycache__/utils.cpython-310.pyc differ diff --git a/trellis/trainers/flow_matching/__pycache__/flow_matching_alone.cpython-310.pyc b/trellis/trainers/flow_matching/__pycache__/flow_matching_alone.cpython-310.pyc index d5a90712ab875a20b8eb5897735ec120bfc87c86..a0d602004e7bde0c3ee9608248c4bec61fd49288 100644 Binary files a/trellis/trainers/flow_matching/__pycache__/flow_matching_alone.cpython-310.pyc and b/trellis/trainers/flow_matching/__pycache__/flow_matching_alone.cpython-310.pyc differ diff --git a/trellis/trainers/flow_matching/__pycache__/sparse_flow_matching_alone.cpython-310.pyc b/trellis/trainers/flow_matching/__pycache__/sparse_flow_matching_alone.cpython-310.pyc index 1e0bc457178685d79132b84bf685f0af083fd623..39082f7274d529fde666cc25a96deb5efd870d2c 100644 Binary files a/trellis/trainers/flow_matching/__pycache__/sparse_flow_matching_alone.cpython-310.pyc and b/trellis/trainers/flow_matching/__pycache__/sparse_flow_matching_alone.cpython-310.pyc differ diff --git a/trellis/trainers/flow_matching/mixins/__pycache__/classifier_free_guidance.cpython-310.pyc b/trellis/trainers/flow_matching/mixins/__pycache__/classifier_free_guidance.cpython-310.pyc index 78fa07dd78bb664d62c0aba4f7cc73b45641ec8d..415b3cf93cf8438b9ad12f0173d0ee30a1a3ddb6 100644 Binary files a/trellis/trainers/flow_matching/mixins/__pycache__/classifier_free_guidance.cpython-310.pyc and b/trellis/trainers/flow_matching/mixins/__pycache__/classifier_free_guidance.cpython-310.pyc differ diff --git a/trellis/trainers/flow_matching/mixins/__pycache__/image_conditioned.cpython-310.pyc b/trellis/trainers/flow_matching/mixins/__pycache__/image_conditioned.cpython-310.pyc index 2c2064da34ca52bbe9e21e91f105f5bd90837b10..22596d11585af9bdb2365ab6c11a0a497aa03637 100644 Binary files a/trellis/trainers/flow_matching/mixins/__pycache__/image_conditioned.cpython-310.pyc and b/trellis/trainers/flow_matching/mixins/__pycache__/image_conditioned.cpython-310.pyc differ diff --git a/trellis/trainers/flow_matching/mixins/__pycache__/text_conditioned.cpython-310.pyc b/trellis/trainers/flow_matching/mixins/__pycache__/text_conditioned.cpython-310.pyc index 1d728229c60771e8a80208b33cd72ed782ae9a73..a74a0c1284e0ef64d3a04112b482b9fd23f1984c 100644 Binary files a/trellis/trainers/flow_matching/mixins/__pycache__/text_conditioned.cpython-310.pyc and b/trellis/trainers/flow_matching/mixins/__pycache__/text_conditioned.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/__init__.cpython-310.pyc b/trellis/utils/__pycache__/__init__.cpython-310.pyc index 2f1c490ebaf18514464a248bbfacb8006c8b4a16..28151a274bc83bd93236d75b7f6779c53645672d 100644 Binary files a/trellis/utils/__pycache__/__init__.cpython-310.pyc and b/trellis/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/data_utils.cpython-310.pyc b/trellis/utils/__pycache__/data_utils.cpython-310.pyc index 13578645297db97feec71e5cd64a14d92c851ea4..e731acf3e76291c1d1516ee66d291de0e61b06cd 100644 Binary files a/trellis/utils/__pycache__/data_utils.cpython-310.pyc and b/trellis/utils/__pycache__/data_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/dist_utils.cpython-310.pyc b/trellis/utils/__pycache__/dist_utils.cpython-310.pyc index db677a285deea2ca9c6e8491c2b2dbd4180ed466..defa0adb4a517e6581a9c07eb533ea35adddb5fd 100644 Binary files a/trellis/utils/__pycache__/dist_utils.cpython-310.pyc and b/trellis/utils/__pycache__/dist_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/elastic_utils.cpython-310.pyc b/trellis/utils/__pycache__/elastic_utils.cpython-310.pyc index 9906a0c5eb1ad7fb178dfc307b736a49e134bc77..3a918c8283061c6b30969b0f711954086944aa63 100644 Binary files a/trellis/utils/__pycache__/elastic_utils.cpython-310.pyc and b/trellis/utils/__pycache__/elastic_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/general_utils.cpython-310.pyc b/trellis/utils/__pycache__/general_utils.cpython-310.pyc index b1b9d1d951c0c35b0c9c27075c2d47cebaeb7bbe..b64d229f32bc49eec2d23c4b837d1b49a436a474 100644 Binary files a/trellis/utils/__pycache__/general_utils.cpython-310.pyc and b/trellis/utils/__pycache__/general_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/grad_clip_utils.cpython-310.pyc b/trellis/utils/__pycache__/grad_clip_utils.cpython-310.pyc index 0516913aef4506b05b5d9193f7d54c774d034db1..78256ad029315c2aeb81275e14830a8a3b6f1b78 100644 Binary files a/trellis/utils/__pycache__/grad_clip_utils.cpython-310.pyc and b/trellis/utils/__pycache__/grad_clip_utils.cpython-310.pyc differ