import logging
import os
import random

import numpy as np
import pyvista as pv
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, Subset, default_collate


class SurfacePressureDatasetFullLoad(Dataset):
    """
    Dataset class for loading and preprocessing surface pressure data from
    DrivAerNet++ VTK files.

    This dataset handles loading surface meshes with pressure field data,
    sampling points, and caching processed data (as compressed NPZ files)
    for faster loading.
    """

    def __init__(self, root_dir: str, num_points: int, preprocess: bool = False,
                 cache_dir: str = None):
        """
        Initializes the SurfacePressureDatasetFullLoad instance.

        Args:
            root_dir: Directory containing the VTK files for the car surface meshes.
            num_points: Fixed number of points to sample from each 3D model.
            preprocess: Flag to indicate if preprocessing should occur or not.
            cache_dir: Directory where the preprocessed files (NPZ) are stored.
                Defaults to "<root_dir>/processed_data".
        """
        self.root_dir = root_dir
        self.vtk_files = [os.path.join(root_dir, f)
                          for f in os.listdir(root_dir) if f.endswith('.vtk')]
        self.num_points = num_points
        self.preprocess = preprocess
        self.cache_dir = cache_dir if cache_dir else os.path.join(root_dir, "processed_data")
        # exist_ok avoids a race if several workers create the directory at once.
        os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self) -> int:
        return len(self.vtk_files)

    def _get_cache_path(self, vtk_file_path: str) -> str:
        """Get the corresponding .npz file path for a given .vtk file."""
        base_name = os.path.basename(vtk_file_path).replace('.vtk', '.npz')
        return os.path.join(self.cache_dir, base_name)

    def _save_to_cache(self, cache_path: str, point_cloud, pressures) -> None:
        """Save preprocessed point cloud and pressure data into an npz file."""
        np.savez_compressed(cache_path, points=point_cloud.points, pressures=pressures)

    def _load_from_cache(self, cache_path: str):
        """Load preprocessed point cloud and pressure data from an npz file."""
        data = np.load(cache_path)
        point_cloud = pv.PolyData(data['points'])
        pressures = data['pressures']
        return point_cloud, pressures

    def sample_point_cloud_with_pressure(self, point_cloud_tensor, pressures_tensor,
                                         n_points: int = 5000):
        """
        Sample up to n_points from the point cloud and the matching pressure values.

        Args:
            point_cloud_tensor: Tensor of point coordinates, shape (N, 3).
            pressures_tensor: Tensor of per-point pressures, shape (N, 1).
            n_points: Number of points to sample.

        Returns:
            A tuple (sampled_points, sampled_pressures). If the cloud has fewer
            than n_points points, all points are returned unchanged.
        """
        num_available = point_cloud_tensor.shape[0]
        if num_available > n_points:
            # Uniform sampling without replacement keeps point/pressure pairing intact.
            indices = np.random.choice(num_available, n_points, replace=False)
        else:
            indices = np.arange(num_available)
        return point_cloud_tensor[indices], pressures_tensor[indices]

    def __getitem__(self, idx: int):
        """
        Load (from cache if available) and sample one mesh.

        Returns:
            A dict with keys 'input_pos', 'output_feat', 'output_pos', or the
            tuple (None, None) when the file cannot be loaded or the cache is
            missing with preprocessing disabled.
            NOTE(review): the (None, None) failure value is inconsistent with the
            dict success value and will break the default collate function —
            callers must filter such samples; confirm against the training loop.
        """
        vtk_file_path = self.vtk_files[idx]
        cache_path = self._get_cache_path(vtk_file_path)

        if os.path.exists(cache_path):
            logging.info(f"Loading cached data from {cache_path}")
            point_cloud, pressures = self._load_from_cache(cache_path)
        elif self.preprocess:
            logging.info(f"Preprocessing and caching data for {vtk_file_path}")
            try:
                mesh = pv.read(vtk_file_path)
            except Exception as e:
                logging.error(f"Failed to load VTK file: {vtk_file_path}. Error: {e}")
                return None, None  # Skip the file and return None
            point_cloud = pv.PolyData(mesh.points)
            # 'p' is the pressure field stored in the mesh point data.
            pressures = mesh.point_data['p'].flatten()
            # Cache the extracted data so subsequent epochs hit the fast path.
            self._save_to_cache(cache_path, point_cloud, pressures)
        else:
            logging.error(f"Cache file not found for {vtk_file_path} and preprocessing is disabled.")
            return None, None  # Return None if preprocessing is disabled and cache doesn't exist

        point_cloud_tensor = torch.tensor(np.array(point_cloud.points), dtype=torch.float32)
        pressures_tensor = torch.tensor(pressures, dtype=torch.float32).unsqueeze(1)
        point_cloud_tensor, pressures_tensor = self.sample_point_cloud_with_pressure(
            point_cloud_tensor, pressures_tensor, self.num_points)

        return {'input_pos': point_cloud_tensor,
                'output_feat': pressures_tensor,
                'output_pos': point_cloud_tensor}


def create_subset(dataset, ids_file):
    """
    Create a subset of the dataset based on design IDs from a file.

    Args:
        dataset: The full dataset (must expose a `vtk_files` list).
        ids_file: Path to a file containing design IDs, one per line.

    Returns:
        A Subset of the dataset containing only the specified designs, or
        None when the IDs file does not exist. An empty Subset is returned
        (after logging an error) when no file matches any ID.
    """
    try:
        with open(ids_file, 'r') as file:
            subset_ids = [id_.strip() for id_ in file.readlines()]
        # Single O(n*m) pass via enumerate instead of the O(n^2)
        # build-then-index pattern; result order is unchanged.
        subset_indices = [i for i, f in enumerate(dataset.vtk_files)
                          if any(id_ in f for id_ in subset_ids)]
        if not subset_indices:
            logging.error(f"No matching VTK files found for IDs in {ids_file}.")
        return Subset(dataset, subset_indices)
    except FileNotFoundError as e:
        logging.error(f"Error loading subset file {ids_file}: {e}")
        return None


def seed_worker(worker_id):
    """Seed numpy and the stdlib RNG in each DataLoader worker for reproducibility.

    Follows the PyTorch reproducibility recipe: derive the per-worker seed from
    torch.initial_seed() so every worker gets a distinct, deterministic stream.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    # Also seed the stdlib RNG (part of the standard recipe; `random` was
    # imported but previously never seeded).
    random.seed(worker_seed)


# Module-level generator shared by all DataLoaders below; seeding it once
# makes shuffling deterministic across runs.
g = torch.Generator()
g.manual_seed(0)


def get_dataloaders(cfg, dataset_path: str, subset_dir: str, num_points: int,
                    batch_size: int, cache_dir: str = None, num_workers: int = 4,
                    model: str = None) -> tuple:
    """
    Prepare and return the training, validation, and test DataLoader objects.

    Args:
        cfg: Configuration object (currently unused; kept for interface compatibility).
        dataset_path: Path to the directory containing VTK files.
        subset_dir: Directory containing train/val/test split files.
        num_points: Number of points to sample from each mesh.
        batch_size: Batch size for train/val dataloaders (test uses batch size 1).
        cache_dir: Directory to store processed data.
        num_workers: Number of workers for data loading.
        model: Model type (e.g., 'NeuralCFD'; currently unused).

    Returns:
        A tuple of (train_dataloader, val_dataloader, test_dataloader).
    """
    full_dataset = SurfacePressureDatasetFullLoad(
        root_dir=dataset_path,
        num_points=num_points,
        preprocess=True,
        cache_dir=cache_dir
    )

    train_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'train_design_ids.txt'))
    val_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'val_design_ids.txt'))
    test_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'test_design_ids.txt'))

    collate_fn = None

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        worker_init_fn=seed_worker,
        generator=g
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        worker_init_fn=seed_worker,
        generator=g
    )
    # Test loader evaluates one design at a time and keeps every sample.
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        worker_init_fn=seed_worker,
        generator=g
    )
    return train_dataloader, val_dataloader, test_dataloader


# Constants for normalization
# TODO: these are for full dataset, not the 400 subset, but are good enough for testing
PRESSURE_MEAN = -94.5
PRESSURE_STD = 117.25