# generate_npz.py
# Uploaded to the Hugging Face Hub via huggingface_hub (commit 7340df2).
import sys
import os
import yaml
import torch
import os
from torch.utils.data import DataLoader
from functools import partial
# Assuming your custom modules are in the same directory or in the Python path
# from dataset import VoxelVertexDataset_edge, collate_fn_edge
from dataset_triposf import VoxelVertexDataset_edge, collate_fn_pointnet
def inspect_batch(batch, batch_idx, device):
    """Print a summary of a single collated batch.

    Args:
        batch: the collated batch — typically a dict of tensors — or ``None``
            when the collate function dropped every sample in the batch.
        batch_idx: zero-based index of the batch within the epoch.
        device: torch device chosen by the caller; currently unused by the
            inspection itself, kept for interface stability.

    Returns:
        None. All output goes to stdout.
    """
    print(f"\n{'='*20} Inspecting Batch {batch_idx} {'='*20}")
    # Collate functions may return None for unusable samples — skip cleanly
    # instead of crashing on attribute access below.
    if batch is None:
        print("Batch is None. Skipping.")
        return
    # Dict-like batches: list the available keys so the batch layout is
    # visible at a glance.
    if hasattr(batch, 'keys'):
        print("Batch contains the following keys:")
        for key in batch.keys():
            print(f"  - {key}")
    print(f"{'='*58}")
def main():
    """Load the dataset, wrap it in a DataLoader, and inspect a few batches.

    Command-line arguments:
        --num_batches: number of batches to inspect before stopping
            (default 3).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Process and inspect data from the VoxelVertexDataset.")
    parser.add_argument('--num_batches', type=int, default=3,
                        help='Number of batches to inspect.')
    args = parser.parse_args()

    # Prefer the GPU when one is available; the handle is forwarded to
    # inspect_batch for interface consistency.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # NOTE(review): paths below are hard-coded for one specific machine —
    # consider promoting them to CLI arguments or a config file.
    print("Initializing dataset...")
    dataset = VoxelVertexDataset_edge(
        root_dir='/root/mesh_split_200complex/mesh_split_200complex_train',
        base_resolution=512,
        min_resolution=64,
        cache_dir='/root/Trisf/dataset_cache/objaverse_200_2000_filtered_final_8354files_512to512',
        renders_dir=None,
        filter_active_voxels=False,
        cache_filter_path='',
        active_voxel_res=512,
        sample_type='dora',
    )
    print(f"Dataset initialized with {len(dataset)} samples.")

    # Plain DataLoader — no DistributedSampler needed for local inspection.
    print("Initializing DataLoader...")
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,  # deterministic order keeps inspection runs reproducible
        collate_fn=partial(collate_fn_pointnet),
        num_workers=24,
        pin_memory=True,
    )

    # Inspect only the requested number of batches rather than the whole
    # epoch (previously the loop ignored --num_batches and ran to exhaustion).
    print(f"\nStarting data inspection loop for {args.num_batches} batches...")
    for i, batch in enumerate(dataloader):
        if i >= args.num_batches:
            break
        inspect_batch(batch, i, device)
    print("\nData inspection complete.")
# Script entry point: run the inspection loop only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()