import numpy as np
import h5py
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import yaml


class BlackHoleShadowDataset(Dataset):
    """Map-style dataset serving images and normalized targets from an HDF5 file.

    The file must contain the datasets ``images``, ``rs_meters`` and
    ``mass_solar``, all indexed along their first axis by sample.
    NOTE(review): the dataset names suggest a radius in meters and a mass in
    solar masses -- confirm against the code that writes the file.

    Args:
        h5_path: Path of the HDF5 file to read.
        indices: Optional sequence of sample indices this dataset exposes.
            When omitted, every sample in the file is used.
        transform: Optional callable applied to the 3-channel image tensor.
        log_scale_target: If true, targets are the z-scored natural log of
            ``rs_meters``; otherwise the z-scored raw values.

    Attributes set here and read by callers: ``target_mean`` and
    ``target_std`` (statistics over the selected indices; they may be
    overwritten from outside to share one normalization across splits).
    """

    def __init__(self, h5_path, indices=None, transform=None, log_scale_target=True):
        self.h5_path = h5_path
        self.transform = transform
        self.log_scale_target = log_scale_target

        # The per-sample scalars are loaded into memory once; images are left
        # on disk and read one at a time in __getitem__.
        with h5py.File(h5_path, 'r') as hf:
            self.num_samples = hf['images'].shape[0]
            self.rs_values = hf['rs_meters'][:]
            self.mass_values = hf['mass_solar'][:]

        if indices is not None:
            self.indices = indices
        else:
            self.indices = np.arange(self.num_samples)

        self.target_mean, self.target_std = self._compute_target_stats()

    def _compute_target_stats(self):
        """Return ``(mean, std)`` of the (optionally log-scaled) selected targets.

        Falls back to mean 0.0 / std 1.0 for an empty selection and to
        std 1.0 when the spread is zero or not finite, so normalization in
        ``__getitem__`` never divides by zero.
        """
        # len() rather than truthiness: indices is usually a NumPy array.
        if len(self.indices) == 0:
            return 0.0, 1.0

        values = self.rs_values[self.indices]
        if self.log_scale_target:
            values = np.log(values)

        mean = values.mean()
        std = values.std()
        if not (np.isfinite(std) and std > 0):
            std = 1.0
        return mean, std

    def __len__(self):
        """Return the number of samples exposed by this dataset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, target, mass)`` for position ``idx``.

        Returns:
            image: float32 tensor, tiled three times along a new leading
                axis and then passed through ``transform`` if one is set.
            target: float32 scalar tensor with the normalized target.
            mass: float64 scalar tensor taken from ``mass_solar``.
        """
        real_idx = self.indices[idx]

        # The file is opened per access, so no h5py handle is stored on the
        # dataset object or shared between DataLoader worker processes.
        with h5py.File(self.h5_path, 'r') as hf:
            image = hf['images'][real_idx]

        image = torch.from_numpy(image).float()
        # Tile to three channels.
        # NOTE(review): assumes each stored image is (H, W) or (1, H, W) -- confirm.
        image = image.repeat(3, 1, 1)

        if self.transform is not None:
            image = self.transform(image)

        rs = self.rs_values[real_idx]
        raw = np.log(rs) if self.log_scale_target else rs
        target = (raw - self.target_mean) / self.target_std
        target = torch.tensor(target, dtype=torch.float32)

        mass = torch.tensor(self.mass_values[real_idx], dtype=torch.float64)
        return image, target, mass

    def denormalize_target(self, normalized_target):
        """Invert the target normalization applied in ``__getitem__``.

        Accepts Python/NumPy scalars, NumPy arrays, or torch tensors. Tensors
        are exponentiated with ``torch.exp`` so that CUDA tensors and tensors
        that require grad work as well; everything else goes through NumPy.

        Returns:
            The value(s) on the original ``rs_meters`` scale.
        """
        restored = normalized_target * self.target_std + self.target_mean
        if not self.log_scale_target:
            return restored
        if torch.is_tensor(restored):
            return torch.exp(restored)
        return np.exp(restored)


def get_augmentation_transform():
    """Build the random augmentation pipeline used for training images.

    Composes a random rotation of up to 180 degrees, horizontal and vertical
    flips (each with probability 0.5) and a brightness/contrast jitter of 0.2.
    """
    # Imported here so torchvision is only required when augmentation is used.
    import torchvision.transforms as T

    return T.Compose([
        T.RandomRotation(degrees=180),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.ColorJitter(brightness=0.2, contrast=0.2),
    ])


def get_dataloaders(config_path='configs/config.yaml', seed=None, num_workers=4):
    """Create train/validation/test loaders from a YAML configuration file.

    The config must provide ``data.output_path`` (the HDF5 file),
    ``data.val_split``, ``data.test_split`` and ``cnn.batch_size``.

    Args:
        config_path: Path of the YAML configuration file.
        seed: Optional seed for the shuffle that defines the split. ``None``
            keeps the previous behaviour of shuffling with NumPy's global
            random state.
        num_workers: Worker process count passed to every ``DataLoader``.

    Returns:
        ``(train_loader, val_loader, test_loader, train_dataset)``. The
        validation and test datasets reuse the training set's target
        mean/std, so all three splits share one normalization.

    Raises:
        ValueError: If the configured splits leave no training samples.
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    h5_path = config['data']['output_path']
    val_split = float(config['data']['val_split'])
    test_split = float(config['data']['test_split'])
    batch_size = int(config['cnn']['batch_size'])

    with h5py.File(h5_path, 'r') as hf:
        total = hf['images'].shape[0]

    n_test = int(total * test_split)
    n_val = int(total * val_split)
    n_train = total - n_val - n_test
    if n_train <= 0:
        raise ValueError(
            f"val_split={val_split} and test_split={test_split} leave no "
            f"training samples out of {total}"
        )

    all_indices = np.arange(total)
    if seed is None:
        np.random.shuffle(all_indices)
    else:
        # A dedicated generator makes the split reproducible without
        # touching the global NumPy random state.
        np.random.default_rng(seed).shuffle(all_indices)

    train_indices = all_indices[:n_train]
    val_indices = all_indices[n_train:n_train + n_val]
    test_indices = all_indices[n_train + n_val:]

    train_dataset = BlackHoleShadowDataset(
        h5_path, train_indices, transform=get_augmentation_transform()
    )
    val_dataset = BlackHoleShadowDataset(h5_path, val_indices, log_scale_target=True)
    test_dataset = BlackHoleShadowDataset(h5_path, test_indices, log_scale_target=True)

    # Held-out splits must be normalized with the training statistics.
    for held_out in (val_dataset, test_dataset):
        held_out.target_mean = train_dataset.target_mean
        held_out.target_std = train_dataset.target_std

    def _make_loader(dataset, shuffle):
        # One place for the settings shared by all three loaders.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader, train_dataset