# schwarznet/data/preprocessing.py
# Author: That-Random-Coder
# Commit 851541d: "Deploy lightweight app and pre-trained model"
import numpy as np
import h5py
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import yaml
class BlackHoleShadowDataset(Dataset):
    """Map-style dataset of black-hole shadow images stored in an HDF5 file.

    The file is expected to contain the datasets ``images``, ``rs_meters`` and
    ``mass_solar``.  Each item is a tuple ``(image, target, mass)``:

    * ``image`` -- float32 tensor built from the stored image, tiled three
      times with ``repeat(3, 1, 1)`` (presumably each stored image is a
      single-channel (H, W) array -- confirm against the file writer), then
      passed through ``transform`` if one was given;
    * ``target`` -- the ``rs_meters`` value, optionally log-scaled, then
      standardised with ``target_mean`` / ``target_std`` (float32 scalar);
    * ``mass`` -- the ``mass_solar`` value as a float64 scalar tensor.

    ``rs_meters`` and ``mass_solar`` are read fully into memory at
    construction; images are read one at a time in ``__getitem__``.
    """

    def __init__(self, h5_path, indices=None, transform=None, log_scale_target=True):
        """Load per-sample metadata and compute target statistics.

        Args:
            h5_path: Path to the HDF5 file.
            indices: Row indices of the file that make up this dataset (e.g.
                one train/val/test split). All rows are used when ``None``.
            transform: Optional callable applied to each image tensor.
            log_scale_target: If True, targets are standardised in log space
                and ``denormalize_target`` applies ``exp`` when inverting.
        """
        self.h5_path = h5_path
        self.transform = transform
        self.log_scale_target = log_scale_target
        with h5py.File(h5_path, 'r') as hf:
            self.num_samples = hf['images'].shape[0]
            self.rs_values = hf['rs_meters'][:]
            self.mass_values = hf['mass_solar'][:]
        self.indices = indices if indices is not None else np.arange(self.num_samples)
        # Statistics come from this dataset's own rows; get_dataloaders then
        # overwrites them on the val/test datasets with the training values.
        self.target_mean, self.target_std = self._target_stats(
            self.rs_values[self.indices], log_scale_target
        )

    @staticmethod
    def _target_stats(rs, log_scale):
        """Return ``(mean, std)`` used to standardise targets.

        Args:
            rs: Raw rs values for this dataset's rows.
            log_scale: Whether statistics are taken over ``log(rs)``.

        An empty selection yields ``(0.0, 1.0)`` instead of NaN plus a NumPy
        warning. A zero spread (a single sample, or all values equal) yields
        ``std = 1.0``, so standardised targets stay finite instead of 0/0.
        """
        values = np.log(rs) if log_scale else np.asarray(rs)
        if values.size == 0:
            return 0.0, 1.0
        mean = values.mean()
        std = values.std()
        if std == 0:
            std = 1.0
        return mean, std

    def __len__(self):
        """Return the number of rows selected by ``indices``."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, target, mass)`` for position ``idx`` in this dataset."""
        real_idx = self.indices[idx]
        # The file is reopened for every item rather than kept open on self,
        # presumably so no h5py handle is shared with DataLoader worker
        # processes. It costs one open per sample.
        with h5py.File(self.h5_path, 'r') as hf:
            image = hf['images'][real_idx]
        image = torch.from_numpy(image).float()
        image = image.repeat(3, 1, 1)
        if self.transform is not None:
            image = self.transform(image)

        rs = self.rs_values[real_idx]
        scaled = np.log(rs) if self.log_scale_target else rs
        target = torch.tensor((scaled - self.target_mean) / self.target_std, dtype=torch.float32)
        mass = torch.tensor(self.mass_values[real_idx], dtype=torch.float64)
        return image, target, mass

    def denormalize_target(self, normalized_target):
        """Invert the standardisation applied to targets in ``__getitem__``.

        Works elementwise on scalars and NumPy arrays. It returns values on the
        same scale as the stored ``rs_meters`` data (``exp`` is applied when
        ``log_scale_target`` is set).
        """
        value = normalized_target * self.target_std + self.target_mean
        return np.exp(value) if self.log_scale_target else value
def get_augmentation_transform():
    """Build the training-time image augmentation pipeline.

    The steps run in this order: a random rotation in [-180, 180] degrees,
    then independent horizontal and vertical flips (each with p=0.5), then
    brightness/contrast jitter with factors drawn from [0.8, 1.2].

    Returns:
        A torchvision ``Compose`` wrapping the steps above.
    """
    # Function-level import, presumably so torchvision is only needed when
    # augmentation is actually requested.
    from torchvision import transforms

    pipeline = [
        transforms.RandomRotation(degrees=180),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    return transforms.Compose(pipeline)
def _split_indices(total, val_split, test_split, seed=None):
    """Shuffle ``range(total)`` and cut it into ``(train, val, test)`` index arrays.

    Val and test sizes are ``int(total * split)``; train takes the remainder.
    With ``seed=None`` the global NumPy RNG is used, as before. Otherwise a
    private ``np.random.default_rng(seed)`` makes the split reproducible
    without touching global RNG state.

    Raises:
        ValueError: if a split fraction is negative, or the fractions leave
            no samples for training (which would otherwise make negative
            slicing overlap the train and test splits).
    """
    if val_split < 0 or test_split < 0:
        raise ValueError(
            f"val_split ({val_split}) and test_split ({test_split}) must be non-negative"
        )
    if seed is None:
        order = np.arange(total)
        np.random.shuffle(order)
    else:
        order = np.random.default_rng(seed).permutation(total)
    n_test = int(total * test_split)
    n_val = int(total * val_split)
    n_train = total - n_val - n_test
    if n_train <= 0:
        raise ValueError(
            f"val_split ({val_split}) + test_split ({test_split}) leaves no "
            f"training samples out of {total}"
        )
    return order[:n_train], order[n_train:n_train + n_val], order[n_train + n_val:]


def get_dataloaders(config_path='configs/config.yaml', *, seed=None, num_workers=4):
    """Build train/val/test DataLoaders from a YAML config.

    The config must provide ``data.output_path`` (HDF5 file), ``data.val_split``,
    ``data.test_split`` and ``cnn.batch_size``. Rows are shuffled and split;
    only the training split gets augmentation. Val and test targets are
    standardised with the training split's mean/std.

    Args:
        config_path: Path to the YAML configuration file.
        seed: Optional seed for a reproducible split. ``None`` keeps using the
            global NumPy RNG.
        num_workers: Worker processes per DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader, train_dataset)``. The train
        dataset is returned so callers can use its ``denormalize_target``.

    Raises:
        ValueError: if the split fractions are negative or leave no
            training samples.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    h5_path = config['data']['output_path']
    val_split = float(config['data']['val_split'])
    test_split = float(config['data']['test_split'])
    batch_size = int(config['cnn']['batch_size'])

    with h5py.File(h5_path, 'r') as hf:
        total = hf['images'].shape[0]
    train_indices, val_indices, test_indices = _split_indices(total, val_split, test_split, seed)

    train_dataset = BlackHoleShadowDataset(h5_path, train_indices, transform=get_augmentation_transform())
    val_dataset = BlackHoleShadowDataset(h5_path, val_indices, log_scale_target=True)
    test_dataset = BlackHoleShadowDataset(h5_path, test_indices, log_scale_target=True)
    # Put all splits on the training scale, so no held-out statistics leak
    # into the normalisation.
    for held_out in (val_dataset, test_dataset):
        held_out.target_mean = train_dataset.target_mean
        held_out.target_std = train_dataset.target_std

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader, train_dataset