Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import h5py | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| import yaml | |
class BlackHoleShadowDataset(Dataset):
    """Map-style dataset of black-hole shadow images stored in an HDF5 file.

    Each item is ``(image, target, mass)``:

    * ``image`` -- the stored image as a ``float32`` tensor, tiled with
      ``repeat(3, 1, 1)`` (presumably single-channel -> 3 channels; assumes the
      stored image is (H, W) or (1, H, W) -- TODO confirm), then passed through
      ``transform`` if one is given.
    * ``target`` -- the ``rs_meters`` value, optionally log-scaled, standardized
      with ``target_mean`` / ``target_std``, as a ``float32`` scalar tensor.
    * ``mass`` -- the ``mass_solar`` value as a ``float64`` scalar tensor.

    The HDF5 file must contain the datasets ``images``, ``rs_meters`` and
    ``mass_solar``. The two scalar arrays are loaded into memory once; images
    are read from disk on demand in ``__getitem__``.

    Args:
        h5_path: Path to the HDF5 file.
        indices: Row indices into the file that form this subset. Defaults to
            every row.
        transform: Optional callable applied to the image tensor.
        log_scale_target: If True, targets are standardized in log space.

    Attributes:
        target_mean, target_std: Standardization statistics computed over this
            subset. They may be overwritten by the caller, e.g. to make
            validation/test subsets reuse the training-set statistics.
    """

    def __init__(self, h5_path, indices=None, transform=None, log_scale_target=True):
        self.h5_path = h5_path
        self.transform = transform
        self.log_scale_target = log_scale_target
        # Only the per-sample scalars are cached in memory; images stay on disk.
        with h5py.File(h5_path, 'r') as hf:
            self.num_samples = hf['images'].shape[0]
            self.rs_values = hf['rs_meters'][:]
            self.mass_values = hf['mass_solar'][:]
        self.indices = np.arange(self.num_samples) if indices is None else indices
        self.target_mean, self.target_std = self._compute_target_stats(
            self.rs_values[self.indices], log_scale_target
        )

    @staticmethod
    def _compute_target_stats(rs, log_scale):
        """Return ``(mean, std)`` used to standardize the ``rs`` targets.

        An empty input yields the identity transform ``(0.0, 1.0)``, and a zero
        standard deviation is replaced by 1.0 so normalization never divides
        by zero.

        Raises:
            ValueError: If ``log_scale`` is True and any ``rs`` value is <= 0,
                since its logarithm would be -inf/NaN and poison every target.
        """
        values = np.asarray(rs, dtype=np.float64)
        if values.size == 0:
            # Nothing to standardize (e.g. an empty split).
            return 0.0, 1.0
        if log_scale:
            if np.any(values <= 0):
                raise ValueError(
                    'rs_meters values must be strictly positive when log_scale_target=True'
                )
            values = np.log(values)
        mean = float(values.mean())
        std = float(values.std())
        if std == 0.0:
            # A single sample or constant targets would otherwise divide by zero.
            std = 1.0
        return mean, std

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        real_idx = self.indices[idx]
        # NOTE(review): the HDF5 file is reopened for every sample. This is simple
        # and avoids sharing one handle across DataLoader worker processes, but
        # it is costly; a lazily opened per-worker handle is the usual optimization.
        with h5py.File(self.h5_path, 'r') as hf:
            image = hf['images'][real_idx]
        image = torch.from_numpy(image).float()
        image = image.repeat(3, 1, 1)
        if self.transform:
            image = self.transform(image)
        rs = self.rs_values[real_idx]
        value = np.log(rs) if self.log_scale_target else rs
        target = torch.tensor((value - self.target_mean) / self.target_std, dtype=torch.float32)
        mass = torch.tensor(self.mass_values[real_idx], dtype=torch.float64)
        return image, target, mass

    def denormalize_target(self, normalized_target):
        """Invert the target standardization applied in ``__getitem__``.

        Args:
            normalized_target: Value(s) in standardized units -- a scalar, a
                NumPy array or a torch tensor.

        Returns:
            The value(s) in the original ``rs_meters`` units. Torch tensors are
            returned as torch tensors on their original device.
        """
        value = normalized_target * self.target_std + self.target_mean
        if not self.log_scale_target:
            return value
        if isinstance(value, torch.Tensor):
            # torch.exp works on any device; np.exp cannot consume CUDA tensors.
            return torch.exp(value)
        return np.exp(value)
def get_augmentation_transform():
    """Build the random augmentation pipeline applied to training images.

    Steps, in order: a random rotation of up to +/-180 degrees, a horizontal
    flip and a vertical flip (each with probability 0.5), and brightness /
    contrast jitter with factors drawn from [0.8, 1.2].

    Returns:
        A ``torchvision.transforms.Compose`` applied to image tensors.
    """
    # Imported locally, so torchvision is only required when this is called.
    from torchvision import transforms

    steps = [
        transforms.RandomRotation(degrees=180),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        # NOTE(review): for float tensors ColorJitter clamps results to [0, 1];
        # confirm the stored images are already scaled to that range.
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    return transforms.Compose(steps)
def get_dataloaders(config_path='configs/config.yaml', seed=42, num_workers=4, pin_memory=True):
    """Build train/val/test ``DataLoader``s from the YAML config at *config_path*.

    Reads ``data.output_path`` (the HDF5 file), ``data.val_split``,
    ``data.test_split`` and ``cnn.batch_size``. The file's rows are permuted
    with a seeded generator and cut into train/val/test subsets. Only the
    training subset gets the augmentation transform. The validation and test
    subsets reuse the training target statistics, so all three share one
    normalization.

    Args:
        config_path: Path to the YAML config file.
        seed: Seed for the permutation that assigns rows to splits. A fixed seed
            keeps the split identical across calls and processes, so a test set
            rebuilt in a later evaluation run does not overlap the training set.
            Pass ``None`` for a fresh, non-reproducible split.
        num_workers: Worker processes for each ``DataLoader``.
        pin_memory: Forwarded to each ``DataLoader``.

    Returns:
        ``(train_loader, val_loader, test_loader, train_dataset)``. The training
        dataset is presumably returned so callers can use its
        ``denormalize_target``.

    Raises:
        ValueError: If a split fraction is outside [0, 1) or the fractions leave
            no training samples.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    h5_path = config['data']['output_path']
    val_split = float(config['data']['val_split'])
    test_split = float(config['data']['test_split'])
    batch_size = int(config['cnn']['batch_size'])
    if not (0.0 <= val_split < 1.0 and 0.0 <= test_split < 1.0):
        raise ValueError(
            f'val_split and test_split must be in [0, 1), got {val_split} and {test_split}'
        )

    with h5py.File(h5_path, 'r') as hf:
        total = hf['images'].shape[0]

    # A local seeded generator makes the split reproducible without depending on,
    # or advancing, the global NumPy RNG state.
    all_indices = np.random.default_rng(seed).permutation(total)
    n_test = int(total * test_split)
    n_val = int(total * val_split)
    n_train = total - n_val - n_test
    if n_train <= 0:
        raise ValueError(
            f'splits leave no training samples (total={total}, val={n_val}, test={n_test})'
        )
    train_indices = all_indices[:n_train]
    val_indices = all_indices[n_train:n_train + n_val]
    test_indices = all_indices[n_train + n_val:]

    train_dataset = BlackHoleShadowDataset(h5_path, train_indices, transform=get_augmentation_transform())
    # Val/test must use the same target scaling as train because they share its statistics.
    log_scale = train_dataset.log_scale_target
    val_dataset = BlackHoleShadowDataset(h5_path, val_indices, log_scale_target=log_scale)
    test_dataset = BlackHoleShadowDataset(h5_path, test_indices, log_scale_target=log_scale)
    for subset in (val_dataset, test_dataset):
        subset.target_mean = train_dataset.target_mean
        subset.target_std = train_dataset.target_std

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, val_loader, test_loader, train_dataset