udbhav
Recreate Trame_app branch with clean history
67fb03c
# https://github.com/Mohamedelrefaie/DrivAerNet/blob/main/RegDGCNN_SurfaceFields/data_loader.py
# data_loader.py
"""
@author: Mohamed Elrefaie, mohamed.elrefaie@mit.edu
Data loading utilities for the DrivAerNet++ dataset.
This module provides functionality for loading and preprocessing point cloud data
with pressure field information from the DrivAerNet++ dataset.
"""
import os
import numpy as np
import torch
from torch.utils.data import Dataset, Subset, DataLoader
import torch.distributed as dist
import pyvista as pv
import logging
from torch.utils.data import default_collate
class SurfacePressureDataset(Dataset):
    """
    Dataset class for loading and preprocessing surface pressure data from DrivAerNet++ VTK files.

    This dataset handles loading surface meshes with pressure field data,
    sampling points, and caching processed data (as compressed .npz files)
    for faster loading.

    Each successfully loaded item is a dict with keys:
        'input_pos':   float32 tensor of sampled point coordinates
                       (presumably (num_sampled, 3) -- pyvista point arrays; TODO confirm).
        'output_feat': float32 tensor of pressures with shape (num_sampled, 1).
        'output_pos':  the same tensor object as 'input_pos'.
    """

    def __init__(self, root_dir: str, num_points: int, preprocess=False, cache_dir=None):
        """
        Initializes the SurfacePressureDataset instance.

        Args:
            root_dir: Directory containing the VTK files for the car surface meshes.
            num_points: Fixed number of points to sample from each 3D model.
            preprocess: If True, meshes without a cache file are read, sampled and
                cached on first access; if False, such items fail to load.
            cache_dir: Directory where the preprocessed files (NPZ) are stored.
                Defaults to ``<root_dir>/processed_data``; created if missing.
        """
        self.root_dir = root_dir
        # os.listdir order is arbitrary; sort so index -> file mapping is reproducible.
        self.vtk_files = sorted(
            os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.vtk')
        )
        self.num_points = num_points
        self.preprocess = preprocess
        self.cache_dir = cache_dir if cache_dir else os.path.join(root_dir, "processed_data")
        # exist_ok avoids a race when several processes create the directory concurrently.
        os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        """Return the number of VTK files found in ``root_dir``."""
        return len(self.vtk_files)

    def _get_cache_path(self, vtk_file_path):
        """Get the corresponding .npz file path (inside ``cache_dir``) for a given .vtk file."""
        # splitext only swaps the final extension; str.replace would also rewrite
        # any '.vtk' occurring earlier in the file name.
        stem = os.path.splitext(os.path.basename(vtk_file_path))[0]
        return os.path.join(self.cache_dir, stem + '.npz')

    def _save_to_cache(self, cache_path, point_cloud, pressures):
        """Save preprocessed point cloud coordinates and pressure data into an npz file."""
        np.savez_compressed(cache_path, points=point_cloud.points, pressures=pressures)

    def _load_from_cache(self, cache_path):
        """
        Load preprocessed point cloud and pressure data from an npz file.

        Returns:
            (pv.PolyData built from the cached points, cached pressures array)
        """
        # np.load on an .npz returns an NpzFile that holds an open file handle;
        # the context manager closes it once the arrays have been read.
        with np.load(cache_path) as data:
            points = data['points']
            pressures = data['pressures']
        return pv.PolyData(points), pressures

    def sample_point_cloud_with_pressure(self, mesh, n_points=5000):
        """
        Sample n_points from the surface mesh and get corresponding pressure values.

        Points are drawn uniformly at random without replacement. If the mesh has
        n_points or fewer points, all of them are used, so the result may contain
        fewer than n_points samples.

        Args:
            mesh: PyVista mesh object with pressure data stored in point_data['p'].
            n_points: Number of points to sample.

        Returns:
            A tuple of (pv.PolyData of the sampled points, 1-D array of their pressures).
        """
        logging.debug(f"mesh.n_points {mesh.n_points}")
        if mesh.n_points > n_points:
            indices = np.random.choice(mesh.n_points, n_points, replace=False)
        else:
            indices = np.arange(mesh.n_points)
            logging.info(f"Mesh has only {mesh.n_points} points. Using all available points.")

        sampled_points = mesh.points[indices]
        sampled_pressures = mesh.point_data['p'][indices]  # Assuming pressure data is stored under key 'p'
        sampled_pressures = sampled_pressures.flatten()

        return pv.PolyData(sampled_points), sampled_pressures

    def __getitem__(self, idx):
        """
        Load (from cache, or by preprocessing) the sample at ``idx``.

        Returns:
            A dict with 'input_pos', 'output_feat' and 'output_pos' tensors, or
            ``(None, None)`` if the VTK file cannot be read or no cache exists
            while preprocessing is disabled.

        NOTE(review): the ``(None, None)`` failure value is not a dict, so the
        default DataLoader collation will most likely fail on it; consider
        raising or filtering failed items instead.
        NOTE(review): the cache file name ignores ``num_points``, so a cache
        written with a different ``num_points`` is silently reused.
        """
        vtk_file_path = self.vtk_files[idx]
        cache_path = self._get_cache_path(vtk_file_path)

        # Check if the data is already cached
        if os.path.exists(cache_path):
            logging.info(f"Loading cached data from {cache_path}")
            point_cloud, pressures = self._load_from_cache(cache_path)
        else:
            if self.preprocess:
                logging.info(f"Preprocessing and caching data for {vtk_file_path}")
                try:
                    mesh = pv.read(vtk_file_path)
                except Exception as e:
                    logging.error(f"Failed to load VTK file: {vtk_file_path}. Error: {e}")
                    return None, None  # Skip the file and return None

                point_cloud, pressures = self.sample_point_cloud_with_pressure(mesh, self.num_points)

                # Cache the sampled data to a new file
                self._save_to_cache(cache_path, point_cloud, pressures)
            else:
                logging.error(f"Cache file not found for {vtk_file_path} and preprocessing is disabled.")
                return None, None  # Return None if preprocessing is disabled and cache doesn't exist

        point_cloud_np = np.array(point_cloud.points)
        point_cloud_tensor = torch.tensor(point_cloud_np, dtype=torch.float32)
        # unsqueeze(1) turns the 1-D pressure vector into a (num_sampled, 1) feature column.
        pressures_tensor = torch.tensor(pressures, dtype=torch.float32).unsqueeze(1)

        data = {'input_pos': point_cloud_tensor, 'output_feat': pressures_tensor, 'output_pos': point_cloud_tensor}
        return data
def create_subset(dataset, ids_file):
    """
    Create a subset of the dataset based on design IDs from a file.

    A file is selected when any design ID is a substring of its file name
    (the directory part of the path is ignored). Blank lines in the IDs
    file are skipped.

    Args:
        dataset: The full dataset; must expose a ``vtk_files`` list of paths.
        ids_file: Path to a file containing design IDs, one per line.

    Returns:
        A Subset of the dataset containing only the specified designs (possibly
        empty, in which case an error is logged), or None if ``ids_file`` does
        not exist.
    """
    try:
        with open(ids_file, 'r') as file:
            # An empty ID is a substring of every path and would otherwise pull
            # the whole dataset into the subset, so blank lines are dropped.
            subset_ids = [line.strip() for line in file if line.strip()]
    except FileNotFoundError as e:
        logging.error(f"Error loading subset file {ids_file}: {e}")
        return None

    # enumerate yields indices directly (no O(n) list.index per match), and
    # matching on the basename keeps the root directory from causing false hits.
    subset_indices = [
        idx for idx, path in enumerate(dataset.vtk_files)
        if any(id_ in os.path.basename(path) for id_ in subset_ids)
    ]
    if not subset_indices:
        logging.error(f"No matching VTK files found for IDs in {ids_file}.")
    return Subset(dataset, subset_indices)
def calculate_normalization_constants(dataloader):
    """
    Calculate normalization constants for both pressure values and coordinate ranges
    across the entire training dataset.

    Args:
        dataloader: Iterable of batches (e.g. the training DataLoader); each batch
            is a dict with tensor entries 'output_feat' (pressures) and
            'input_pos' (point coordinates; the first three entries of the last
            axis are taken as x, y, z).

    Returns:
        tuple: (pressure_mean, pressure_std, coord_ranges)
        where coord_ranges = {'min_x', 'max_x', 'min_y', 'max_y', 'min_z', 'max_z'}

    Raises:
        ValueError: if the dataloader yields no pressure or coordinate values.
    """
    pressure_chunks = []
    coord_min = None  # running per-axis minimum over all points seen so far
    coord_max = None  # running per-axis maximum over all points seen so far

    print("Calculating normalization constants...")
    for batch in dataloader:
        # detach().cpu() so batches that were already moved to a GPU also work.
        pressure_chunks.append(batch['output_feat'].detach().cpu().numpy().ravel())

        input_pos_np = batch['input_pos'].detach().cpu().numpy()
        # Collapse leading (batch, point, ...) axes: one row per point, x/y/z columns.
        positions = input_pos_np.reshape(-1, input_pos_np.shape[-1])[:, :3]
        if positions.size == 0:
            continue
        batch_min = positions.min(axis=0)
        batch_max = positions.max(axis=0)
        coord_min = batch_min if coord_min is None else np.minimum(coord_min, batch_min)
        coord_max = batch_max if coord_max is None else np.maximum(coord_max, batch_max)

    # Concatenating arrays once is far cheaper than extending a Python list per value.
    all_pressures = np.concatenate(pressure_chunks) if pressure_chunks else np.empty(0)
    if all_pressures.size == 0 or coord_min is None:
        raise ValueError("Cannot compute normalization constants: the dataloader yielded no data.")

    # Calculate pressure statistics
    pressure_mean = np.mean(all_pressures)
    pressure_std = np.std(all_pressures)

    min_x, min_y, min_z = coord_min
    max_x, max_y, max_z = coord_max
    coord_ranges = {
        'min_x': min_x, 'max_x': max_x,
        'min_y': min_y, 'max_y': max_y,
        'min_z': min_z, 'max_z': max_z
    }

    # Print comprehensive statistics
    print(f"\nPressure statistics from {len(all_pressures)} data points:")
    print(f"Mean: {pressure_mean:.6f}")
    print(f"Std: {pressure_std:.6f}")
    print(f"Min: {np.min(all_pressures):.6f}")
    print(f"Max: {np.max(all_pressures):.6f}")
    print(f"\nCoordinate ranges:")
    print(f"X: [{min_x:.6f}, {max_x:.6f}]")
    print(f"Y: [{min_y:.6f}, {max_y:.6f}]")
    print(f"Z: [{min_z:.6f}, {max_z:.6f}]")

    return pressure_mean, pressure_std, coord_ranges
def get_dataloaders(cfg, dataset_path: str, subset_dir: str, num_points: int, batch_size: int,
                    cache_dir: str = None, num_workers: int = 4, model: str = None) -> tuple:
    """
    Prepare and return the training, validation, and test DataLoader objects.

    The training loader shuffles and drops the last incomplete batch. The
    validation and test loaders keep a fixed order and include every sample,
    so evaluation metrics cover the whole split and are reproducible.

    Args:
        cfg: Unused; kept for interface compatibility with callers.
        dataset_path: Path to the directory containing VTK files
        subset_dir: Directory containing train/val/test split files
            ('train_design_ids.txt', 'val_design_ids.txt', 'test_design_ids.txt')
        num_points: Number of points to sample from each mesh
        batch_size: Batch size for the train and validation dataloaders
            (the test dataloader always uses batch size 1)
        cache_dir: Directory to store processed data
        num_workers: Number of workers for data loading
        model: Unused; kept for interface compatibility with callers.

    Returns:
        A tuple of (train_dataloader, val_dataloader, test_dataloader)

    Raises:
        FileNotFoundError: if one of the split files is missing.
    """
    full_dataset = SurfacePressureDataset(
        root_dir=dataset_path,
        num_points=num_points,
        preprocess=True,
        cache_dir=cache_dir
    )

    subsets = {}
    for split in ('train', 'val', 'test'):
        ids_file = os.path.join(subset_dir, f'{split}_design_ids.txt')
        subset = create_subset(full_dataset, ids_file)
        # create_subset signals a missing file with None, which DataLoader would
        # otherwise reject with an unrelated TypeError.
        if subset is None:
            raise FileNotFoundError(f"Split file not found: {ids_file}")
        subsets[split] = subset

    train_dataloader = DataLoader(
        subsets['train'], batch_size=batch_size, shuffle=True,
        drop_last=True, num_workers=num_workers
    )
    # Evaluation loaders: no shuffling and no dropped samples.
    val_dataloader = DataLoader(
        subsets['val'], batch_size=batch_size, shuffle=False,
        drop_last=False, num_workers=num_workers
    )
    test_dataloader = DataLoader(
        subsets['test'], batch_size=1, shuffle=False,
        drop_last=False, num_workers=num_workers
    )
    # To recompute PRESSURE_MEAN / PRESSURE_STD, call
    # calculate_normalization_constants(train_dataloader).
    return train_dataloader, val_dataloader, test_dataloader
# Pressure normalization constants (mean, standard deviation).
# Computed over the full dataset rather than the 400-design subset;
# accurate enough for testing on the subset.
PRESSURE_MEAN, PRESSURE_STD = -94.5, 117.25
# Reference values for the 400-design subset:
#   PRESSURE_MEAN = -93.573677
#   PRESSURE_STD  = 114.631371