# https://github.com/Mohamedelrefaie/DrivAerNet/blob/main/RegDGCNN_SurfaceFields/data_loader.py
# data_loader.py
"""
@author: Mohamed Elrefaie, mohamed.elrefaie@mit.edu
Data loading utilities for the DrivAerNet++ dataset.
This module provides functionality for loading and preprocessing point cloud data
with pressure field information from the DrivAerNet++ dataset.
"""
import os
import logging

import numpy as np
import pyvista as pv
import torch
from torch.utils.data import DataLoader, Dataset, Subset, default_collate


class SurfacePressureDataset(Dataset):
"""
Dataset class for loading and preprocessing surface pressure data from DrivAerNet++ VTK files.
This dataset handles loading surface meshes with pressure field data,
sampling points, and caching processed data for faster loading.
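
    Example (illustrative; the path below is a placeholder):
        dataset = SurfacePressureDataset(root_dir="path/to/vtk_dir", num_points=5000,
                                         preprocess=True)
        sample = dataset[0]  # dict with 'input_pos', 'output_feat', 'output_pos'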
"""
def __init__(self, root_dir: str, num_points: int, preprocess=False, cache_dir=None):
"""
Initializes the SurfacePressureDataset instance.
Args:
root_dir: Directory containing the VTK files for the car surface meshes.
num_points: Fixed number of points to sample from each 3D model.
            preprocess: Whether to preprocess and cache meshes when no cached file exists.
cache_dir: Directory where the preprocessed files (NPZ) are stored.
"""
self.root_dir = root_dir
        # Sorted for a deterministic file ordering across runs.
        self.vtk_files = sorted(os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.vtk'))
self.num_points = num_points
self.preprocess = preprocess
self.cache_dir = cache_dir if cache_dir else os.path.join(root_dir, "processed_data")
        os.makedirs(self.cache_dir, exist_ok=True)
def __len__(self):
return len(self.vtk_files)
def _get_cache_path(self, vtk_file_path):
"""Get the corresponding .npz file path for a given .vtk file."""
        base_name = os.path.splitext(os.path.basename(vtk_file_path))[0] + '.npz'
return os.path.join(self.cache_dir, base_name)
def _save_to_cache(self, cache_path, point_cloud, pressures):
"""Save preprocessed point cloud and pressure data into an npz file."""
np.savez_compressed(cache_path, points=point_cloud.points, pressures=pressures)
def _load_from_cache(self, cache_path):
"""Load preprocessed point cloud and pressure data from an npz file."""
data = np.load(cache_path)
point_cloud = pv.PolyData(data['points'])
pressures = data['pressures']
return point_cloud, pressures
def sample_point_cloud_with_pressure(self, mesh, n_points=5000):
"""
Sample n_points from the surface mesh and get corresponding pressure values.
Args:
mesh: PyVista mesh object with pressure data stored in point_data.
n_points: Number of points to sample.
Returns:
A tuple containing the sampled point cloud and corresponding pressures.
"""
        logging.debug(f"Sampling from mesh with {mesh.n_points} points.")
if mesh.n_points > n_points:
indices = np.random.choice(mesh.n_points, n_points, replace=False)
else:
indices = np.arange(mesh.n_points)
logging.info(f"Mesh has only {mesh.n_points} points. Using all available points.")
sampled_points = mesh.points[indices]
sampled_pressures = mesh.point_data['p'][indices] # Assuming pressure data is stored under key 'p'
sampled_pressures = sampled_pressures.flatten()
return pv.PolyData(sampled_points), sampled_pressures
def __getitem__(self, idx):
vtk_file_path = self.vtk_files[idx]
cache_path = self._get_cache_path(vtk_file_path)
# Check if the data is already cached
if os.path.exists(cache_path):
logging.info(f"Loading cached data from {cache_path}")
point_cloud, pressures = self._load_from_cache(cache_path)
else:
if self.preprocess:
logging.info(f"Preprocessing and caching data for {vtk_file_path}")
try:
mesh = pv.read(vtk_file_path)
except Exception as e:
logging.error(f"Failed to load VTK file: {vtk_file_path}. Error: {e}")
                    return None  # Skip this sample; pair with a None-aware collate_fn such as skip_none_collate below
point_cloud, pressures = self.sample_point_cloud_with_pressure(mesh, self.num_points)
# Cache the sampled data to a new file
self._save_to_cache(cache_path, point_cloud, pressures)
else:
logging.error(f"Cache file not found for {vtk_file_path} and preprocessing is disabled.")
                return None  # No cache entry and preprocessing disabled; sample unavailable
point_cloud_np = np.array(point_cloud.points)
point_cloud_tensor = torch.tensor(point_cloud_np, dtype=torch.float32)
pressures_tensor = torch.tensor(pressures, dtype=torch.float32).unsqueeze(1)
data = {'input_pos': point_cloud_tensor, 'output_feat': pressures_tensor, 'output_pos': point_cloud_tensor}
return data
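
# Minimal sketch (not part of the original module): a collate function that drops
# samples for which __getitem__ returned None (failed VTK reads or missing cache
# entries). It could be passed as `collate_fn` to the DataLoaders built below;
# `default_collate` is PyTorch's standard batching helper.
def skip_none_collate(batch):
    """Filter out failed (None) samples before batching."""
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None  # Entire batch failed; the caller should skip it.
    return default_collate(batch)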
def create_subset(dataset, ids_file):
"""
Create a subset of the dataset based on design IDs from a file.
Args:
dataset: The full dataset
ids_file: Path to a file containing design IDs, one per line
Returns:
A Subset of the dataset containing only the specified designs
"""
    try:
        with open(ids_file, 'r') as file:
            subset_ids = [line.strip() for line in file if line.strip()]
        # A design ID matches any VTK file whose name contains it as a substring.
        subset_files = [f for f in dataset.vtk_files if any(id_ in f for id_ in subset_ids)]
        subset_indices = [dataset.vtk_files.index(f) for f in subset_files]
        if not subset_indices:
            logging.error(f"No matching VTK files found for IDs in {ids_file}.")
        return Subset(dataset, subset_indices)
    except FileNotFoundError as e:
        logging.error(f"Error loading subset file {ids_file}: {e}")
        return None
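
# Illustrative usage (hypothetical paths, for reference only). The IDs file is
# expected to contain one design ID per line; an ID matches any VTK file name
# containing it as a substring:
#
#     full_dataset = SurfacePressureDataset("path/to/vtk_dir", num_points=5000)
#     train_set = create_subset(full_dataset, "splits/train_design_ids.txt")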
def calculate_normalization_constants(dataloader):
"""
Calculate normalization constants for both pressure values and coordinate ranges
across the entire training dataset.
Args:
dataloader: Training DataLoader
Returns:
tuple: (pressure_mean, pressure_std, coord_ranges)
where coord_ranges = {'min_x', 'max_x', 'min_y', 'max_y', 'min_z', 'max_z'}
"""
all_pressures = []
# Initialize coordinate extremes
max_x = float('-inf')
max_y = float('-inf')
max_z = float('-inf')
min_x = float('inf')
min_y = float('inf')
min_z = float('inf')
print("Calculating normalization constants...")
for batch_idx, batch in enumerate(dataloader):
# Process pressure values
output_feat = batch['output_feat']
pressures = output_feat.flatten().numpy()
all_pressures.extend(pressures)
# Process coordinate ranges
input_pos = batch['input_pos']
# Convert tensor to numpy for coordinate calculations
input_pos_np = input_pos.numpy()
max_x = max(max_x, np.max(input_pos_np[:,:,0]))
max_y = max(max_y, np.max(input_pos_np[:,:,1]))
max_z = max(max_z, np.max(input_pos_np[:,:,2]))
min_x = min(min_x, np.min(input_pos_np[:,:,0]))
min_y = min(min_y, np.min(input_pos_np[:,:,1]))
min_z = min(min_z, np.min(input_pos_np[:,:,2]))
# Convert to numpy array for efficient computation
all_pressures = np.array(all_pressures)
# Calculate pressure statistics
pressure_mean = np.mean(all_pressures)
pressure_std = np.std(all_pressures)
# Store coordinate ranges
coord_ranges = {
'min_x': min_x, 'max_x': max_x,
'min_y': min_y, 'max_y': max_y,
'min_z': min_z, 'max_z': max_z
}
# Print comprehensive statistics
print(f"\nPressure statistics from {len(all_pressures)} data points:")
print(f"Mean: {pressure_mean:.6f}")
print(f"Std: {pressure_std:.6f}")
print(f"Min: {np.min(all_pressures):.6f}")
print(f"Max: {np.max(all_pressures):.6f}")
print(f"\nCoordinate ranges:")
print(f"X: [{min_x:.6f}, {max_x:.6f}]")
print(f"Y: [{min_y:.6f}, {max_y:.6f}]")
print(f"Z: [{min_z:.6f}, {max_z:.6f}]")
return pressure_mean, pressure_std, coord_ranges
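
# Hypothetical helper (not part of the original module) showing how the returned
# statistics might be applied: z-score normalization of pressures and min-max
# scaling of coordinates into [0, 1], following the conventions of
# calculate_normalization_constants above.
def normalize_batch(batch, pressure_mean, pressure_std, coord_ranges, eps=1e-8):
    """Return a new batch dict with normalized pressures and coordinates."""
    pos = batch['input_pos']  # shape (B, N, 3)
    mins = torch.tensor([coord_ranges['min_x'], coord_ranges['min_y'], coord_ranges['min_z']],
                        dtype=pos.dtype)
    maxs = torch.tensor([coord_ranges['max_x'], coord_ranges['max_y'], coord_ranges['max_z']],
                        dtype=pos.dtype)
    pos_norm = (pos - mins) / (maxs - mins + eps)  # broadcasts over (B, N, 3)
    feat_norm = (batch['output_feat'] - pressure_mean) / (pressure_std + eps)
    return {'input_pos': pos_norm, 'output_feat': feat_norm, 'output_pos': pos_norm}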
def get_dataloaders(cfg, dataset_path: str, subset_dir: str, num_points: int, batch_size: int,
                    cache_dir: str = None, num_workers: int = 4, model: str = None) -> tuple:
    """
    Prepare and return the training, validation, and test DataLoader objects.
    Args:
        cfg: Experiment configuration object (currently unused here)
        dataset_path: Path to the directory containing VTK files
        subset_dir: Directory containing train/val/test split files
        num_points: Number of points to sample from each mesh
        batch_size: Batch size for dataloaders
        cache_dir: Directory to store processed data
        num_workers: Number of workers for data loading
        model: Model name (currently unused here)
    Returns:
        A tuple of (train_dataloader, val_dataloader, test_dataloader)
    """
full_dataset = SurfacePressureDataset(
root_dir=dataset_path,
num_points=num_points,
preprocess=True,
cache_dir=cache_dir
)
train_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'train_design_ids.txt'))
val_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'val_design_ids.txt'))
test_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'test_design_ids.txt'))
    # Default collation; `skip_none_collate` (sketched above) could be swapped in
    # to tolerate samples where __getitem__ returns None.
    collate_fn = None
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
drop_last=True, num_workers=num_workers, collate_fn=collate_fn
)
    # Note: validation mirrors the training settings (shuffle + drop_last);
    # set both to False for deterministic evaluation.
    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=True,
        drop_last=True, num_workers=num_workers, collate_fn=collate_fn
    )
test_dataloader = DataLoader(
test_dataset, batch_size=1, shuffle=True,
drop_last=False, num_workers=num_workers, collate_fn=collate_fn
)
    # To (re)compute the normalization constants below from the training split,
    # uncomment the following line:
    # pressure_mean, pressure_std, coord_ranges = calculate_normalization_constants(train_dataloader)
    return train_dataloader, val_dataloader, test_dataloader


# Constants for normalization. These were computed over the full dataset rather
# than the 400-design subset, but are close enough for testing.
PRESSURE_MEAN = -94.5
PRESSURE_STD = 117.25
# For the 400-design subset:
# PRESSURE_MEAN = -93.573677
# PRESSURE_STD = 114.631371
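
if __name__ == "__main__":
    # Minimal smoke-test sketch (hypothetical paths; adjust to a local copy of
    # the dataset). Builds the dataloaders and prints the shapes of one batch.
    logging.basicConfig(level=logging.INFO)
    train_loader, val_loader, test_loader = get_dataloaders(
        cfg=None,
        dataset_path="path/to/vtk_dir",
        subset_dir="path/to/subset_dir",
        num_points=5000,
        batch_size=2,
        num_workers=0,
    )
    batch = next(iter(train_loader))
    print("input_pos:", tuple(batch['input_pos'].shape))      # (B, num_points, 3)
    print("output_feat:", tuple(batch['output_feat'].shape))  # (B, num_points, 1)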