import logging
import os
import random

import numpy as np
import pyvista as pv
import torch
from torch.utils.data import DataLoader, Dataset, Subset

class SurfacePressureDatasetFullLoad(Dataset):
"""
Dataset class for loading and preprocessing surface pressure data from DrivAerNet++ VTK files.
This dataset handles loading surface meshes with pressure field data,
sampling points, and caching processed data for faster loading.
"""
    def __init__(self, root_dir: str, num_points: int, preprocess=False, cache_dir=None):
        """
        Initializes the SurfacePressureDatasetFullLoad instance.

        Args:
            root_dir: Directory containing the VTK files for the car surface meshes.
            num_points: Fixed number of points to sample from each 3D model.
            preprocess: Whether to preprocess (and cache) meshes whose cache entry is missing.
            cache_dir: Directory where the preprocessed NPZ files are stored.
                Defaults to <root_dir>/processed_data.
        """
        self.root_dir = root_dir
        # Sort for a deterministic file order; os.listdir order is filesystem-dependent.
        self.vtk_files = sorted(
            os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.vtk')
        )
        self.num_points = num_points
        self.preprocess = preprocess
        self.cache_dir = cache_dir if cache_dir else os.path.join(root_dir, "processed_data")
        os.makedirs(self.cache_dir, exist_ok=True)
def __len__(self):
return len(self.vtk_files)
    def _get_cache_path(self, vtk_file_path):
        """Return the corresponding .npz cache path for a given .vtk file."""
        # splitext only touches the extension, unlike replace(), which would also
        # rewrite a '.vtk' substring occurring elsewhere in the file name.
        base_name = os.path.splitext(os.path.basename(vtk_file_path))[0] + '.npz'
        return os.path.join(self.cache_dir, base_name)
def _save_to_cache(self, cache_path, point_cloud, pressures):
"""Save preprocessed point cloud and pressure data into an npz file."""
np.savez_compressed(cache_path, points=point_cloud.points, pressures=pressures)
def _load_from_cache(self, cache_path):
"""Load preprocessed point cloud and pressure data from an npz file."""
data = np.load(cache_path)
point_cloud = pv.PolyData(data['points'])
pressures = data['pressures']
return point_cloud, pressures
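
    # Cache format sketch (hypothetical file name): each NPZ stores the raw point
    # coordinates under 'points' and the flattened pressure field under 'pressures':
    #
    #     data = np.load("processed_data/design_0001.npz")
    #     data['points'].shape     # (N, 3)
    #     data['pressures'].shape  # (N,)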
    def sample_point_cloud_with_pressure(self, point_cloud_tensor, pressures_tensor, n_points=5000):
        """
        Randomly sample n_points from the point cloud together with their pressure values.

        Args:
            point_cloud_tensor: Tensor of point coordinates, shape (N, 3).
            pressures_tensor: Tensor of per-point pressures, shape (N, 1).
            n_points: Number of points to sample.

        Returns:
            A tuple (sampled_points, sampled_pressures). If the cloud has fewer than
            n_points points, all points are returned unchanged, so the output may be
            smaller than n_points.
        """
        if point_cloud_tensor.shape[0] > n_points:
            # Sample without replacement so no point is duplicated.
            indices = np.random.choice(point_cloud_tensor.shape[0], n_points, replace=False)
        else:
            indices = np.arange(point_cloud_tensor.shape[0])
        sampled_points = point_cloud_tensor[indices]
        sampled_pressures = pressures_tensor[indices]
        return sampled_points, sampled_pressures
    def __getitem__(self, idx):
        vtk_file_path = self.vtk_files[idx]
        cache_path = self._get_cache_path(vtk_file_path)

        # Use the cached NPZ file when available; otherwise fall back to the raw VTK.
        if os.path.exists(cache_path):
            logging.info(f"Loading cached data from {cache_path}")
            point_cloud, pressures = self._load_from_cache(cache_path)
        else:
            if self.preprocess:
                logging.info(f"Preprocessing and caching data for {vtk_file_path}")
                try:
                    mesh = pv.read(vtk_file_path)
                except Exception as e:
                    logging.error(f"Failed to load VTK file: {vtk_file_path}. Error: {e}")
                    return None  # Skip the file; callers/collate must handle None samples
                point_cloud = pv.PolyData(mesh.points)
                pressures = mesh.point_data['p'].flatten()
                # Cache the extracted data so the VTK file is only parsed once.
                self._save_to_cache(cache_path, point_cloud, pressures)
            else:
                logging.error(f"Cache file not found for {vtk_file_path} and preprocessing is disabled.")
                return None  # No cache and preprocessing disabled; nothing to load

        point_cloud_tensor = torch.tensor(np.asarray(point_cloud.points), dtype=torch.float32)
        pressures_tensor = torch.tensor(pressures, dtype=torch.float32).unsqueeze(1)
        point_cloud_tensor, pressures_tensor = self.sample_point_cloud_with_pressure(
            point_cloud_tensor, pressures_tensor, self.num_points)
        return {'input_pos': point_cloud_tensor,
                'output_feat': pressures_tensor,
                'output_pos': point_cloud_tensor}
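
# A minimal usage sketch (hypothetical paths; not part of the training pipeline):
#
#     dataset = SurfacePressureDatasetFullLoad(
#         root_dir="/path/to/vtk_files", num_points=5000, preprocess=True)
#     sample = dataset[0]
#     if sample is not None:
#         sample['input_pos']    # (num_points, 3) float32 point coordinates
#         sample['output_feat']  # (num_points, 1) float32 surface pressures
#         sample['output_pos']   # same tensor as 'input_pos'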
def create_subset(dataset, ids_file):
    """
    Create a subset of the dataset based on design IDs from a file.

    Args:
        dataset: The full SurfacePressureDatasetFullLoad dataset.
        ids_file: Path to a text file containing design IDs, one per line.

    Returns:
        A Subset containing only the listed designs, or None if ids_file is missing.
    """
    try:
        with open(ids_file, 'r') as file:
            subset_ids = [line.strip() for line in file if line.strip()]
        # Keep every VTK file whose name contains one of the listed design IDs;
        # enumerate avoids the quadratic cost of list.index() per match.
        subset_indices = [i for i, f in enumerate(dataset.vtk_files)
                          if any(id_ in f for id_ in subset_ids)]
        if not subset_indices:
            logging.error(f"No matching VTK files found for IDs in {ids_file}.")
        return Subset(dataset, subset_indices)
    except FileNotFoundError as e:
        logging.error(f"Error loading subset file {ids_file}: {e}")
        return None
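
# The IDs file is expected to hold one design identifier per line; any VTK file
# whose name contains that identifier lands in the subset. A hypothetical
# train_design_ids.txt might look like:
#
#     DrivAer_F_D_WM_WW_0001
#     DrivAer_F_D_WM_WW_0002
#
#     train_set = create_subset(full_dataset, "subsets/train_design_ids.txt")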
def seed_worker(worker_id):
    """Seed NumPy and Python RNGs per DataLoader worker (PyTorch reproducibility recipe)."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Module-level generator shared by all DataLoaders so shuffling is reproducible.
g = torch.Generator()
g.manual_seed(0)
def get_dataloaders(cfg, dataset_path: str, subset_dir: str, num_points: int, batch_size: int,
                    cache_dir: str = None, num_workers: int = 4, model: str = None) -> tuple:
    """
    Prepare and return the training, validation, and test DataLoader objects.

    Args:
        cfg: Experiment configuration object (currently unused in this function).
        dataset_path: Path to the directory containing VTK files.
        subset_dir: Directory containing the train/val/test split files.
        num_points: Number of points to sample from each mesh.
        batch_size: Batch size for the train and validation dataloaders.
        cache_dir: Directory in which to store preprocessed data.
        num_workers: Number of workers for data loading.
        model: Model type (e.g., 'NeuralCFD'); currently unused.

    Returns:
        A tuple of (train_dataloader, val_dataloader, test_dataloader).
    """
full_dataset = SurfacePressureDatasetFullLoad(
root_dir=dataset_path,
num_points=num_points,
preprocess=True,
cache_dir=cache_dir
)
train_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'train_design_ids.txt'))
val_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'val_design_ids.txt'))
test_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'test_design_ids.txt'))
collate_fn = None
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
drop_last=True, num_workers=num_workers, collate_fn=collate_fn,
worker_init_fn=seed_worker, generator=g
)
val_dataloader = DataLoader(
val_dataset, batch_size=batch_size, shuffle=True,
drop_last=True, num_workers=num_workers, collate_fn=collate_fn,
worker_init_fn=seed_worker, generator=g
)
test_dataloader = DataLoader(
test_dataset, batch_size=1, shuffle=True,
drop_last=False, num_workers=num_workers, collate_fn=collate_fn,
worker_init_fn=seed_worker, generator=g
)
return train_dataloader, val_dataloader, test_dataloader
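
# A minimal call sketch (hypothetical paths; cfg and model may be None since
# they are not used above):
#
#     train_dl, val_dl, test_dl = get_dataloaders(
#         cfg=None, dataset_path="/path/to/vtk_files",
#         subset_dir="/path/to/subsets", num_points=5000, batch_size=8)
#     batch = next(iter(train_dl))
#     batch['input_pos'].shape    # (8, 5000, 3)
#     batch['output_feat'].shape  # (8, 5000, 1)
#
# The fixed batch shapes assume every mesh has at least num_points points;
# otherwise default collation fails on ragged samples.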
# Constants for normalizing the pressure field.
# TODO: these statistics were computed on the full dataset, not the 400-design
# subset, but they are good enough for testing.
PRESSURE_MEAN = -94.5
PRESSURE_STD = 117.25
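
# A sketch of how these constants would typically be applied (assumed usage,
# not wired into this file): normalize targets for training, then invert the
# transform on predictions:
#
#     pressures_norm = (pressures_tensor - PRESSURE_MEAN) / PRESSURE_STD
#     pressures_pred = pressures_norm_pred * PRESSURE_STD + PRESSURE_MEAN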