# OneForecast / my_utils / data_loader.py
# Repository page header (not part of the module): YuanGao-YG's picture,
# "Upload 97 files", commit 912fe5a (verified).
import logging
import glob
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch import Tensor
import h5py
import math
from my_utils.norm import reshape_fields
def get_data_loader(params, files_pattern, distributed, train):
    """Create a ``GetDataset`` and wrap it in a ``DataLoader``.

    Args:
        params: Configuration object; ``params.num_data_workers`` is read
            here and the whole object is forwarded to ``GetDataset``.
        files_pattern: Location passed to ``GetDataset`` as its data directory.
        distributed: When true, a ``DistributedSampler`` is created for the
            dataset (shuffling only if ``train`` is true).
        train: Selects training behaviour. The sampler is only handed to the
            ``DataLoader`` when this is true.

    Returns:
        ``(dataloader, dataset, sampler)`` when ``train`` is true, otherwise
        ``(dataloader, dataset)``. ``sampler`` is ``None`` if ``distributed``
        is false.
    """
    dataset = GetDataset(params, files_pattern, train)

    if distributed:
        sampler = DistributedSampler(dataset, shuffle=train)
    else:
        sampler = None

    # The loader itself never shuffles; any shuffling comes from the sampler,
    # which is attached for training only.
    loader_options = dict(
        batch_size=1,
        num_workers=params.num_data_workers,
        shuffle=False,
        sampler=sampler if train else None,
        drop_last=True,
        pin_memory=True,
    )
    dataloader = DataLoader(dataset, **loader_options)

    if not train:
        return dataloader, dataset
    return dataloader, dataset, sampler
class GetDataset(Dataset):
    """Map-style dataset serving (input, target) pairs from HDF5 files.

    Every ``*.h5`` file directly under ``location`` counts as one "year" and
    must hold a ``fields`` dataset, which is indexed here as
    ``[time, channel, x, y]``. Files are opened lazily, on the first access
    to a sample of that year, and are kept open afterwards.
    """

    # Every read is cropped to the first 120 x 240 points of the last two axes.
    _CROP_X = 120
    _CROP_Y = 240
    # Local indices that are moved back by one before reading.
    # NOTE(review): presumably protects the tail of 1464-sample files from
    # out-of-range target reads -- confirm against the data layout.
    _SHIFTED_BACK = (1463, 1464)

    def __init__(self, params, location, train):
        """Store the configuration and scan ``location`` for data files.

        Args:
            params: Configuration object. Attributes read by this class:
                ``normalize``, ``dt``, ``n_history``, ``in_channels``,
                ``out_channels``, ``atmos_channels``, ``add_noise`` and
                ``multi_steps_finetune``.
            location: Directory containing the ``*.h5`` files.
            train: Flag forwarded to ``reshape_fields``.

        Raises:
            FileNotFoundError: If ``location`` contains no ``*.h5`` file.
        """
        self.params = params
        self.location = location
        self.train = train
        self.normalize = params.normalize
        self.dt = params.dt
        self.n_history = params.n_history
        self.in_channels = np.array(params.in_channels)
        self.out_channels = np.array(params.out_channels)
        self.atmos_channels = np.array(params.atmos_channels)
        self.n_in_channels = len(self.in_channels)
        self.n_out_channels = len(self.out_channels)
        self.add_noise = params.add_noise
        self._get_files_stats()

    def _get_files_stats(self):
        """Discover the data files and derive sample counts and image shape.

        The per-year sample count and the image shape are taken from the
        first file only (in sorted order).

        Raises:
            FileNotFoundError: If no ``*.h5`` file exists under ``location``.
        """
        self.files_paths = sorted(glob.glob(self.location + "/*.h5"))
        if not self.files_paths:
            # Fail with the offending location instead of an opaque IndexError.
            raise FileNotFoundError(
                "No .h5 files found in {}".format(self.location))
        self.n_years = len(self.files_paths)
        print('------------', self.files_paths)

        first_path = self.files_paths[0]
        with h5py.File(first_path, 'r') as _f:
            logging.info("Getting file stats from {}".format(first_path))
            fields_shape = _f['fields'].shape

        # The last `multi_steps_finetune` time steps are not usable as sample
        # starts, so they are excluded from the per-year count.
        self.n_samples_per_year = fields_shape[0] - self.params.multi_steps_finetune
        # original image shape (before padding)
        self.img_shape_x = fields_shape[2] - 1  # just get rid of one of the pixels
        self.img_shape_y = fields_shape[3]

        self.n_samples_total = self.n_years * self.n_samples_per_year
        # One lazily-filled slot per file; see `_open_file`.
        self.files = [None for _ in range(self.n_years)]

        logging.info("Number of samples per year: {}".format(self.n_samples_per_year))
        logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(
            self.location,
            self.n_samples_total,
            self.img_shape_x,
            self.img_shape_y,
            self.n_in_channels))
        logging.info("Delta t: {} days".format(1 * self.dt))
        logging.info("Including {} days of past history in training at a frequency of {} days".format(
            1 * self.dt * self.n_history, 1 * self.dt))

    def _open_file(self, year_idx):
        """Open the file for ``year_idx`` and cache its ``fields`` dataset.

        The file handle is intentionally left open so later reads can reuse it.
        """
        _file = h5py.File(self.files_paths[year_idx], 'r')
        self.files[year_idx] = _file['fields']

    def __len__(self):
        """Return the total number of samples across all files."""
        return self.n_samples_total

    def __getitem__(self, global_idx):
        """Return the ``(inp, tar)`` pair for a global sample index.

        Args:
            global_idx: Sample index in ``[0, len(self))``. Negative indices
                count from the end.

        Returns:
            Tuple ``(inp, tar)``, each being whatever ``reshape_fields``
            returns for the raw input / target arrays.

        Raises:
            ValueError: If ``params.multi_steps_finetune`` is smaller than 1.
        """
        n_steps = self.params.multi_steps_finetune
        if n_steps < 1:
            # Neither the single-step nor the multi-step read applies; fail
            # clearly instead of hitting an unbound local at the return.
            raise ValueError(
                "params.multi_steps_finetune must be >= 1, got {}".format(n_steps))

        # Integer divmod keeps the year and the in-year position consistent
        # with each other, including for negative indices.
        year_idx, local_idx = divmod(int(global_idx), self.n_samples_per_year)

        if self.files[year_idx] is None:
            self._open_file(year_idx)
        fields = self.files[year_idx]

        # Samples too close to the start of the file have no full history
        # window, so they are moved forward by the window length.
        history_span = self.dt * self.n_history
        if local_idx < history_span:
            local_idx += history_span

        # Within `dt` of the end of the year there is no later step to
        # predict; the target is then read at the input's own time index.
        step = 0 if local_idx >= self.n_samples_per_year - self.dt else self.dt

        orog = None

        # Applied after `step` is chosen, matching the original ordering.
        if local_idx in self._SHIFTED_BACK:
            local_idx -= 1

        inp_raw = fields[(local_idx - history_span):(local_idx + 1):self.dt,
                         self.in_channels, :self._CROP_X, :self._CROP_Y]
        inp = reshape_fields(
            np.nan_to_num(inp_raw, nan=0),
            'inp',
            self.params,
            self.train,
            self.normalize,
            orog,
            self.add_noise
        )

        tar_start = local_idx + step
        if n_steps == 1:
            # Single step: scalar time index, so the time axis is dropped.
            tar_raw = fields[tar_start,
                             self.out_channels, :self._CROP_X, :self._CROP_Y]
        else:
            # Multi-step: `n_steps` consecutive time steps, time axis kept.
            tar_raw = fields[tar_start:tar_start + n_steps,
                             self.out_channels, :self._CROP_X, :self._CROP_Y]
        tar = reshape_fields(
            np.nan_to_num(tar_raw, nan=0),
            'tar',
            self.params,
            self.train,
            self.normalize,
            orog
        )

        return inp, tar