# OneForecast / my_utils/data_loader_nng.py
# Uploaded by YuanGao-YG ("Upload 97 files"), commit 912fe5a (verified).
import logging
import glob
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch import Tensor
import h5py
import math
from my_utils.norm_nng import reshape_fields
import torch.nn.functional as F
def get_data_loader(params, files_pattern, files_pattern2, distributed, train):
    """Build a ``GetDataset`` over two HDF5 directories and wrap it in a DataLoader.

    Args:
        params: config object; reads ``batch_size`` and ``num_data_workers``
            here (``GetDataset`` reads further fields).
        files_pattern: directory holding the first set of ``*.h5`` files.
        files_pattern2: directory holding the second set of ``*.h5`` files.
        distributed: when true, a ``DistributedSampler`` is constructed.
        train: selects training mode; the sampler is only attached (and only
            shuffles) when this is true.

    Returns:
        ``(dataloader, dataset, sampler)`` when ``train`` is true (``sampler``
        is ``None`` if not distributed), otherwise ``(dataloader, dataset)``.
    """
    dataset = GetDataset(params, files_pattern, files_pattern2, train)

    # DistributedSampler hands each process/GPU its own portion of the
    # dataset so different processes do not see duplicated samples.
    sampler = None
    if distributed:
        sampler = DistributedSampler(dataset, shuffle=train)

    loader_kwargs = dict(
        batch_size=int(params.batch_size),
        num_workers=params.num_data_workers,
        shuffle=False,  # deliberately off; ordering comes from the sampler, if any
        sampler=sampler if train else None,
        drop_last=True,
        pin_memory=True,  # pinned host memory speeds up tensor copies to GPU memory
    )
    dataloader = DataLoader(dataset, **loader_kwargs)

    if not train:
        return dataloader, dataset
    return dataloader, dataset, sampler
class GetDataset(Dataset):
    """Map-style dataset over two directories of per-year HDF5 files.

    ``*.h5`` files in ``location`` and ``location2`` are sorted by name and
    paired by position (one pair per "year"). Each file must contain a 4-D
    ``fields`` dataset: axis 0 is treated as the sample/time axis, axis 1 is
    indexed by channel, and axes 2/3 are spatially cropped in ``__getitem__``.

    Each sample is ``(inp, inp2, tar)``:
      * ``inp``  -- history window from ``location`` (``params.in_channels``),
      * ``inp2`` -- history window from ``location2`` (``params.in_channels2``),
      * ``tar``  -- one frame from ``location2`` (``params.out_channels2``),
    each passed through ``reshape_fields``.
    """

    def __init__(self, params, location, location2, train):
        """Store configuration and scan both directories for HDF5 files.

        Args:
            params: config object (dt, n_history, channel lists, orography and
                normalization settings, multi_steps_finetune, ...).
            location: directory of the first set of ``*.h5`` files.
            location2: directory of the second set of ``*.h5`` files.
            train: training mode; noise injection is only enabled when true.

        Raises:
            FileNotFoundError: see ``_get_files_stats``.
        """
        self.params = params
        self.location = location
        self.location2 = location2
        self.train = train
        self.orography = params.orography
        self.normalize = params.normalize
        self.dt = params.dt  # stride (in frames) between input frames and to the target
        self.n_history = params.n_history  # number of past frames before the current one
        self.in_channels = np.array(params.in_channels)
        self.out_channels = np.array(params.out_channels)
        self.in_channels2 = np.array(params.in_channels2)
        self.out_channels2 = np.array(params.out_channels2)
        self.atmos_channels = np.array(params.atmos_channels)
        self.n_in_channels = len(self.in_channels)
        self.n_out_channels = len(self.out_channels)
        # Filled lazily (once) by _open_file when orography is enabled.
        self.orography_field = None
        self._get_files_stats()
        self.add_noise = params.add_noise if train else False
        self.fusion_3d_2d = params.fusion_3d_2d

    def _get_files_stats(self):
        """Discover the HDF5 files and derive sample counts and image shape.

        Only the first file of ``location`` is inspected; every year is
        assumed to have the same number of frames (TODO confirm for leap years).

        Raises:
            FileNotFoundError: if ``location`` contains no ``*.h5`` files, or
                ``location2`` contains fewer files than ``location`` (files
                are paired by sorted position).
        """
        self.files_paths = sorted(glob.glob(self.location + "/*.h5"))
        self.files_paths2 = sorted(glob.glob(self.location2 + "/*.h5"))
        self.n_years = len(self.files_paths)
        if self.n_years == 0:
            raise FileNotFoundError(f"No *.h5 files found in {self.location!r}")
        if len(self.files_paths2) < self.n_years:
            raise FileNotFoundError(
                f"{self.location2!r} has {len(self.files_paths2)} *.h5 files but "
                f"{self.location!r} has {self.n_years}; files are paired by position"
            )

        with h5py.File(self.files_paths[0], 'r') as _f:
            logging.info("Getting file stats from {}".format(self.files_paths[0]))
            fields_shape = _f['fields'].shape
        # The last `multi_steps_finetune` frames of each year are never used
        # as the sample index.
        self.n_samples_per_year = fields_shape[0] - self.params.multi_steps_finetune
        # Original image shape (before padding); one row is dropped.
        self.img_shape_x = fields_shape[2] - 1
        self.img_shape_y = fields_shape[3]

        self.n_samples_total = self.n_years * self.n_samples_per_year
        # Per-year h5py datasets, opened lazily (e.g. inside DataLoader workers).
        self.files = [None] * self.n_years
        self.files2 = [None] * self.n_years

        logging.info("Number of samples per year: {}".format(self.n_samples_per_year))
        logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(self.location,
                                                                                                       self.n_samples_total,
                                                                                                       self.img_shape_x,
                                                                                                       self.img_shape_y,
                                                                                                       self.n_in_channels))
        logging.info("Delta t: {} days".format(1 * self.dt))
        logging.info("Including {} days of past history in training at a frequency of {} days".format(
            1 * self.dt * self.n_history, 1 * self.dt))

    def _open_file(self, year_idx):
        """Open the paired ``fields`` datasets for ``year_idx``; load orography once.

        The year files stay open (their datasets are cached in ``self.files``
        / ``self.files2``) so later reads are lazy.
        """
        _file = h5py.File(self.files_paths[year_idx], 'r')
        self.files[year_idx] = _file['fields']
        _file2 = h5py.File(self.files_paths2[year_idx], 'r')
        self.files2[year_idx] = _file2['fields']
        # Previously the orography file was opened here on every call, never
        # closed, and never assigned, so __getitem__ failed with AttributeError.
        if self.orography and self.orography_field is None:
            self.orography_field = self._load_orography()

    def _load_orography(self):
        """Read the normalized orography field fully into memory and close the file.

        Raises:
            ValueError: if ``params.normalization`` is not 'zscore' or 'maxmin'.
        """
        normalization = self.params.normalization
        if normalization == 'zscore':
            path = self.params.orography_norm_zscore_path
        elif normalization == 'maxmin':
            path = self.params.orography_norm_maxmin_path
        else:
            raise ValueError(
                f"orography requires normalization 'zscore' or 'maxmin', got {normalization!r}"
            )
        with h5py.File(path, 'r') as orog_file:
            # NOTE(review): dataset key 'orog' follows the FourCastNet
            # convention -- TODO confirm against the orography files.
            return orog_file['orog'][...]

    def __len__(self):
        """Total number of samples across all years."""
        return self.n_samples_total

    def __getitem__(self, global_idx):
        """Return ``(inp, inp2, tar)`` for flat sample index ``global_idx``.

        Raises:
            NotImplementedError: if ``params.multi_steps_finetune != 1``; only
                single-step targets are built by this loader.
        """
        if self.params.multi_steps_finetune != 1:
            # Without this the method produced no sample for multi-step
            # fine-tuning, which only failed later inside the DataLoader.
            raise NotImplementedError(
                "GetDataset only supports multi_steps_finetune == 1, got "
                f"{self.params.multi_steps_finetune!r}"
            )

        year_idx, local_idx = divmod(int(global_idx), self.n_samples_per_year)

        if self.files[year_idx] is None:
            self._open_file(year_idx)

        history_span = self.dt * self.n_history
        # Not enough past frames at the start of the year: shift forward so the
        # history window stays inside the file.
        if local_idx < history_span:
            local_idx += history_span

        # Near the end of the year predict the current frame, otherwise the frame
        # `dt` steps ahead. (Computed before the clamp below, as originally.)
        step = 0 if local_idx >= self.n_samples_per_year - self.dt else self.dt

        orog = None
        if self.orography:
            orog = self.orography_field
            if np.shape(orog)[0] == 721:
                orog = orog[0:720]  # trim a 721-row field to 720 rows

        # Hard-coded upper bound on the frame index; presumably the last frame of
        # a 1460-frame year (TODO confirm).
        local_idx = min(local_idx, 1459)
        # Reading past the end of this year's file would raise; fall back to
        # predicting the current frame instead.
        if local_idx + step >= self.files2[year_idx].shape[0]:
            step = 0

        # History window: n_history + 1 frames ending at local_idx, stride dt.
        window = slice(local_idx - history_span, local_idx + 1, self.dt)
        # Both inputs are cropped to fixed 120 x 240 spatial windows.
        inp = reshape_fields(
            np.nan_to_num(self.files[year_idx][window, self.in_channels, :120, :240], nan=0),
            'inp',
            self.params,
            self.train,
            self.normalize,
            orog,
            self.add_noise,
        )
        inp2 = reshape_fields(
            np.nan_to_num(self.files2[year_idx][window, self.in_channels2, 216:336, 1176:1416], nan=0),
            'inp2',
            self.params,
            self.train,
            self.normalize,
            orog,
            self.add_noise,
        )
        tar = reshape_fields(
            np.nan_to_num(self.files2[year_idx][local_idx + step, self.out_channels2, 216:336, 1176:1416], nan=0),
            'tar',
            self.params,
            self.train,
            self.normalize,
            orog,
        )
        return inp, inp2, tar