|
|
from dataclasses import dataclass |
|
|
import h5py |
|
|
import torch |
|
|
|
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import random |
|
|
|
|
|
|
|
|
import math |
|
|
|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from time import time |
|
|
from datetime import datetime |
|
|
import concurrent.futures |
|
|
import psutil |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import socket |
|
|
|
|
|
class Dataset4h5(Dataset):
    """Torch Dataset backed by an HDF5 file of simulated lightcone images.

    Loads ``num_image`` images (2D or 3D) plus their generating parameters,
    optionally rescales images to [-1, 1] and params to [0, 1], applies random
    flip/transpose augmentation, and randomly zeroes whole parameter vectors
    with probability ``drop_prob`` (conditioning dropout).
    """

    def __init__(
        self,
        dir_name,
        num_image=10,
        field='brightness_temp',
        idx='range',
        num_redshift=512,
        HII_DIM=64,
        rescale=True,
        drop_prob=0,
        dim=2,
        transform=True,
        ranges_dict=None,
        num_workers=1,
        startat=0,
        str_len=120,
    ):
        """
        Args:
            dir_name: path to the HDF5 file.
            num_image: number of images to load (per rank when idx='range').
            field: name of the HDF5 dataset holding the images.
            idx: 'random' (sorted random subset), 'range' (rank-dependent
                contiguous slice), or an explicit sequence of increasing indices.
            num_redshift: number of slices kept along the last (redshift) axis.
            HII_DIM: spatial size kept along the other axes.
            rescale: if True, rescale images/params using ``ranges_dict``
                (which is then required).
            drop_prob: probability of zeroing a sample's parameter vector.
            dim: 2 or 3, spatial dimensionality of the stored images.
            transform: if True, apply random flip/transpose augmentation.
            ranges_dict: {'images': ..., 'params': ...}; per-channel
                (min, max) pairs consumed by ``rescale``.
            num_workers: >1 reads the file concurrently with a process pool.
            startat: first redshift slice to read.
            str_len: width used to center the progress log lines.

        Raises:
            ValueError: if ``rescale`` is True but ``ranges_dict`` is None.
        """
        super().__init__()

        # Fail fast: previously this surfaced only after the (expensive) load
        # as a cryptic TypeError when subscripting None at rescale time.
        if rescale and ranges_dict is None:
            raise ValueError(
                "rescale=True requires ranges_dict with 'images' and 'params' entries"
            )

        self.dir_name = dir_name
        self.num_image = num_image
        self.idx = idx
        self.field = field
        self.num_redshift = num_redshift
        self.HII_DIM = HII_DIM
        self.drop_prob = drop_prob
        self.dim = dim
        self.transform = transform
        self.num_workers = num_workers
        self.startat = startat
        self.str_len = str_len
        # Default rank for the log lines: load_h5 only computes the real
        # global rank in the idx == 'range' branch; without this default the
        # other branches crashed with AttributeError when printing progress.
        self.global_rank = 0

        self.load_h5()

        if rescale:
            rescale_start = time()
            self.images = self.rescale(self.images, ranges=ranges_dict['images'], to=[-1,1])
            self.params = self.rescale(self.params, ranges=ranges_dict['params'], to=[0,1])
            rescale_end = time()

            print(f"images & params rescaled to [{self.images.min()}, {self.images.max()}] & [{self.params.min()}, {self.params.max()}] after {rescale_end-rescale_start:.3f}s")

        self.len = len(self.params)
        self.images = torch.from_numpy(self.images)

        # Conditioning dropout: with probability drop_prob a sample's whole
        # parameter vector is zeroed (classifier-free-guidance style).
        cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()
        self.params = torch.from_numpy(self.params*cond_filter)

    def load_h5(self):
        """Resolve which indices to read, then load images and params.

        Fills ``self.images`` (float32, shape (N, 1, HII_DIM, num_redshift)
        for dim == 2, or (N, 1, HII_DIM, HII_DIM, num_redshift) for dim == 3)
        and ``self.params`` (float32, shape (N, n_params)), using a process
        pool when num_workers > 1, then applies the flip/transpose
        augmentation if requested.
        """
        with h5py.File(self.dir_name, 'r') as f:
            print(f"dataset content: {f.keys()}")
            max_num_image = len(f['brightness_temp'])
            field_shape = f['brightness_temp'].shape[1:]

            self.params_keys = list(f['params']['keys'])
            print(f"{max_num_image} {f['brightness_temp'].dtype} images of shape {field_shape} can be loaded with params.keys {self.params_keys}")

        if self.idx == "random":
            # Sorted because h5py fancy indexing requires increasing indices.
            self.idx = np.sort(random.sample(range(max_num_image), self.num_image))
            print(f"loading {self.num_image} images randomly with idx = {self.idx[:5]}...{self.idx[-5:]}")

        elif self.idx == "range":
            # Each global rank loads its own contiguous, non-overlapping slice.
            rank = torch.cuda.current_device()
            local_world_size = torch.cuda.device_count()
            self.global_rank = rank + local_world_size * int(os.environ["SLURM_NODEID"])
            self.idx = range(
                self.global_rank*self.num_image, (self.global_rank+1)*self.num_image
            )
            print(f"loading {len(self.idx)} images with idx = {self.idx}")
        else:
            # Caller supplied explicit indices.
            print(f"loading {len(self.idx)} images with idx = {self.idx}")

        # NOTE(review): buffers are sized by num_image; if explicit idx with a
        # different length is passed, len(idx) should match num_image.
        self.params = np.empty((self.num_image, len(self.params_keys)), dtype=np.float32)
        if self.dim == 2:
            self.images = np.empty((self.num_image, 1, self.HII_DIM, self.num_redshift), dtype=np.float32)
        elif self.dim == 3:
            self.images = np.empty((self.num_image, 1, self.HII_DIM, self.HII_DIM, self.num_redshift), dtype=np.float32)

        concurrent_init_start = time()
        if self.num_workers == 1:
            print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.global_rank}, loading by {self.num_workers} workers, {datetime.now().strftime('%d-%H:%M:%S.%f')}".center(self.str_len, '-'))
            # Single-worker path replaces the preallocated buffers outright.
            self.images, self.params = self.read_data_chunk(self.dir_name, self.idx, torch.cuda.current_device(), concurrent_init_start, concurrent_init_start)
            self.params = self.params.astype(self.images.dtype)
            concurrent_start = time()
            print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.global_rank}, images {self.images.shape} & params {self.params.shape} loaded after {concurrent_start-concurrent_init_start:.3f}s, {datetime.now().strftime('%d-%H:%M:%S.%f')}".center(self.str_len, '-'))
        else:
            with concurrent.futures.ProcessPoolExecutor(max_workers=self.num_workers) as executor:
                concurrent_init_end = time()
                print(f" {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.global_rank}, concurrently loading by {self.num_workers}/{len(os.sched_getaffinity(0))} workers, initialized after {concurrent_init_end-concurrent_init_start:.3f}s ".center(self.str_len, '-'))
                futures = [None] * self.num_workers
                for i, idx in enumerate(np.array_split(self.idx, self.num_workers)):
                    executor_start = time()
                    futures[i] = executor.submit(self.read_data_chunk, self.dir_name, idx, torch.cuda.current_device(), concurrent_init_end, executor_start)

                concurrent_start = time()
                start_idx = 0
                # Chunks are written back in completion order; images and
                # params from the same future stay aligned row-for-row.
                for future in concurrent.futures.as_completed(futures):
                    images, params = future.result()
                    batch_size = params.shape[0]
                    self.images[start_idx:start_idx+batch_size] = images
                    self.params[start_idx:start_idx+batch_size] = params
                    start_idx += batch_size
                concurrent_end = time()
                print(f" {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.global_rank}, {start_idx} images {self.images.shape} & params {self.params.shape} loaded after {concurrent_start-concurrent_init_start:.3f}/{concurrent_end-concurrent_start:.3f}s ".center(self.str_len, '-'))

        transform_start = time()
        if self.transform:
            self.images = self.flip_rotate(self.images)

        transform_end = time()
        print(f"images transformed after {transform_end-transform_start:.3f}s")

    def read_data_chunk(self, f, idx, device, concurrent_init_end, executor_start):
        """Read one chunk of images and params from the HDF5 file.

        Runs either in this process (num_workers == 1) or in a pool worker.

        Args:
            f: path to the HDF5 file. (The original ignored this parameter
               and re-opened self.dir_name; both call sites pass
               self.dir_name, so honoring it is behavior-identical.)
            idx: increasing indices to read.
            device: CUDA device to bind in this (possibly worker) process.
            concurrent_init_end, executor_start: timestamps for the log line.

        Returns:
            (images, params) numpy arrays; images carries a singleton channel
            axis inserted at position 1.
        """
        set_device = time()
        torch.cuda.set_device(device)
        open_h5py = time()
        with h5py.File(f, 'r') as h5f:
            images_start = time()
            if self.dim == 2:
                # [:,None] inserts the channel axis.
                images = h5f[self.field][idx, 0, :self.HII_DIM, self.startat:self.startat+self.num_redshift][:,None]
            elif self.dim == 3:
                images = h5f[self.field][idx, :self.HII_DIM, :self.HII_DIM, self.startat:self.startat+self.num_redshift][:,None]
            images_end = time()
            pid = os.getpid()
            cpu_num = psutil.Process(pid).cpu_num()

            param_start = time()
            params = h5f['params']['values'][idx]
            param_end = time()
        print(f"cuda:{torch.cuda.current_device()}/{self.global_rank}, CPU:{cpu_num}, images {images.shape} & params {params.shape} loaded after {executor_start-concurrent_init_end:.3f}/{set_device-executor_start:.3f}/{open_h5py-set_device:.3f}/{images_start-open_h5py:.3f}s + {images_end-images_start:.3f}s & {param_end-param_start:.3f}s")

        return images, params

    def flip_rotate(self, img):
        """Augment in place: flip a random half of the images along the first
        spatial axis; for 3D cubes additionally flip a random half along the
        second spatial axis and transpose x/y on another random half.

        The last (redshift) axis is never touched.
        """
        x_flip_idx = random.sample(range(len(img)), len(img)//2)
        img[x_flip_idx] = img[x_flip_idx, :, ::-1, :]

        # ndim-2 strips the batch and channel axes: 3 means 3D cubes.
        if img.ndim-2 == 3:
            y_flip_idx = random.sample(range(len(img)), len(img)//2)
            xy_flip_idx = random.sample(range(len(img)), len(img)//2)

            img[y_flip_idx] = img[y_flip_idx, :, :, ::-1, :]
            img[xy_flip_idx] = img[xy_flip_idx, :, :, :, :].transpose(0,1,3,2,4)
        return img

    def rescale(self, value, ranges, to: list):
        """Affinely map channel i of ``value`` from ranges[i] to the interval ``to``.

        value[:, i] is first normalized in place to [0, 1] via
        (x - lo_i) / (hi_i - lo_i), then the whole array is stretched to
        [to[0], to[1]] and returned.
        """
        for i in range(np.shape(value)[1]):
            value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])

        value = value * (to[1]-to[0]) + to[0]
        return value

    def __getitem__(self, index):
        # One sample: (image, parameter vector).
        return self.images[index], self.params[index]

    def __len__(self):
        return self.len
|
|
|