"""Dataset and DataLoader construction for per-design ``.vtp`` meshes carrying a ``pressure`` array."""

import os

import numpy as np
import pyvista as pv
import torch
from torch.utils.data import Dataset, DataLoader


class Data_loader(Dataset):
    """Map-style dataset with one item per design mesh.

    Modes:
      * ``'train'`` / ``'val'``: each item is ``cfg.num_points`` points drawn
        without replacement from the mesh.
      * ``'test'``: each item is every point of the mesh in shuffled order,
        plus the unscaled coordinates under ``"physical_coordinates"``.

    Sampling for item ``idx`` is seeded with ``epoch_seed + idx``. Draws are
    therefore reproducible for a given seed and change when ``set_epoch`` is
    called with a new epoch.
    """

    _VALID_MODES = ("train", "val", "test")

    def __init__(self, cfg, split, epoch_seed=None, mode='train'):
        """Load one mesh per design id in ``split``.

        Args:
            cfg: config object. This class reads ``data_dir``, ``data_folder``,
                ``num_points``, ``normalization``, ``press_mean``, ``press_std``,
                ``pos_embed_sincos``, ``input_pos_mins`` and ``input_pos_maxs``.
            split: list of int design ids, e.g. ``[0, 1, 2, 3, 4]``. Design ``i``
                is read from ``<data_dir>/<data_folder>_i/<data_folder>_i.vtp``.
            epoch_seed: base seed for point sampling. ``None`` is treated as 0,
                so val/test sampling is deterministic without ``set_epoch``.
            mode: ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            ValueError: if ``mode`` is not a supported mode (checked before any
                file is read).
            FileNotFoundError: if a design's ``.vtp`` file does not exist.
        """
        if mode not in self._VALID_MODES:
            raise ValueError(f"Unknown mode: {mode}")
        self.data_dir = cfg.data_dir
        self.split = split
        self.num_points = cfg.num_points
        self.epoch_seed = epoch_seed
        self.mode = mode
        self.cfg = cfg
        self.meshes = []
        self.mesh_names = []
        for idx in split:
            folder = f"{cfg.data_folder}_{idx}"
            vtp_file = os.path.join(self.data_dir, folder, f"{folder}.vtp")
            if not os.path.exists(vtp_file):
                raise FileNotFoundError(f"{vtp_file} not found.")
            self.meshes.append(pv.read(vtp_file))
            self.mesh_names.append(folder)
        # NOTE(review): these two attributes are not read anywhere in this
        # file. They are presumably kept for a chunked-validation scheme used
        # elsewhere; confirm before removing.
        self.val_indices = None
        self.val_chunk_ptr = 0

    def set_epoch(self, epoch):
        """Use ``epoch`` as the sampling seed and reset the validation state."""
        self.epoch_seed = epoch
        self.val_indices = None
        self.val_chunk_ptr = 0

    def __len__(self):
        """Return the number of designs; every mode yields one item per mesh."""
        if self.mode not in self._VALID_MODES:
            raise ValueError(f"Unknown mode: {self.mode}")
        return len(self.meshes)

    def _sample_seed(self, idx):
        """Return the RNG seed for item ``idx`` (``epoch_seed`` of ``None`` counts as 0)."""
        base = 0 if self.epoch_seed is None else self.epoch_seed
        return base + idx

    def _to_model_inputs(self, points, pressure):
        """Convert raw coordinates and pressures to float32 model tensors.

        The pressure is standardized when ``cfg.normalization == "std_norm"``.
        When ``cfg.pos_embed_sincos`` is set, coordinates are min-max scaled.
        """
        pos = torch.tensor(points, dtype=torch.float32)
        # Trailing singleton dim: one scalar feature per point.
        target = torch.tensor(pressure, dtype=torch.float32).unsqueeze(-1)
        if self.cfg.normalization == "std_norm":
            target = (target - self.cfg.press_mean) / self.cfg.press_std
        if self.cfg.pos_embed_sincos:
            # Explicit float32 so float64 config arrays do not upcast ``pos``.
            input_pos_mins = torch.tensor(self.cfg.input_pos_mins, dtype=torch.float32)
            input_pos_maxs = torch.tensor(self.cfg.input_pos_maxs, dtype=torch.float32)
            # Maps coordinates lying within [mins, maxs] onto [0, 1000].
            pos = 1000 * (pos - input_pos_mins) / (input_pos_maxs - input_pos_mins)
        return pos, target

    def __getitem__(self, idx):
        """Return one sample for design ``idx``.

        Returns:
            dict with these keys:
              * ``"input_pos"``: float32 tensor of the selected coordinates.
              * ``"output_feat"``: float32 pressure tensor, unsqueezed on the
                last dim.
              * ``"data_id"``: the design's folder name.
              * ``"physical_coordinates"`` (test mode only): the same points'
                raw coordinates, before any scaling.

        Raises:
            ValueError: in train/val mode, if the mesh has fewer than
                ``num_points`` points (sampling is without replacement).
        """
        if self.mode not in self._VALID_MODES:
            raise ValueError(f"Unknown mode: {self.mode}")
        mesh = self.meshes[idx]
        n_pts = mesh.points.shape[0]
        rng = np.random.default_rng(self._sample_seed(idx))
        if self.mode == 'test':
            # Every point of the mesh, in shuffled order.
            indices = rng.permutation(n_pts)
        else:
            # Fixed-size random subset of the mesh points.
            indices = rng.choice(n_pts, self.num_points, replace=False)
        points = mesh.points[indices]
        pos, target = self._to_model_inputs(points, mesh["pressure"][indices])
        sample = {"input_pos": pos, "output_feat": target, "data_id": self.mesh_names[idx]}
        if self.mode == 'test':
            sample["physical_coordinates"] = points
        return sample


# Held-out design ids.
# The order (396-399, 4, 40-49, 5, 50-59, ...) looks like a lexicographic sort
# of id strings. It is kept verbatim because it fixes the item order of the
# val/test loaders.
_TEST_SPLIT = (
    396, 397, 398, 399,
    4, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
    5, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
    6, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
    7, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
    8, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
    9, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99,
)
# Design ids range over 0..399.
_NUM_DESIGNS = 400


def get_dataloaders(cfg):
    """Build the train, val and test DataLoaders (all with ``batch_size=1``).

    Designs in ``_TEST_SPLIT`` form the held-out set; all other ids in
    ``range(400)`` are used for training. Only the train loader shuffles.

    NOTE: the split used to be read from ``train/val/test_design_ids.txt``
    under ``cfg.splits_file``; it is now hard-coded here.

    Returns:
        tuple ``(train_loader, val_loader, test_loader)``.
    """
    test_split = list(_TEST_SPLIT)
    train_split = sorted(set(range(_NUM_DESIGNS)) - set(test_split))
    print("Indices from 0 to 399 not in train_split:", test_split)

    train_dataset = Data_loader(cfg, train_split, mode='train')
    # NOTE(review): validation reuses the test designs, so val metrics are not
    # independent of test metrics, and those meshes are loaded twice.
    # Confirm this is intended.
    val_dataset = Data_loader(cfg, test_split, mode='val')
    test_dataset = Data_loader(cfg, test_split, mode='test')

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    return train_loader, val_loader, test_loader