import torch.nn.functional as F

from data import dutils
from nowcasting.hko_iterator import HKOIterator

# Upper bound on the batch size used for the evaluation loaders built by
# GET_TrainLoader (the original code capped it at 8 inline).
_MAX_EVAL_BATCH_SIZE = 8


def _sevir_config(seq_len, start_date, end_date):
    """Return keyword arguments for a VIL-only ``dutils.SEVIRDataIterator``.

    Args:
        seq_len: Number of frames per sample; also used as ``raw_seq_len``.
        start_date: Value forwarded as ``start_date`` (``None`` = unbounded,
            presumably — semantics live in ``dutils``).
        end_date: Value forwarded as ``end_date``.
    """
    return {
        'data_types': ['vil'],
        'layout': 'NTCHW',
        'seq_len': seq_len,
        'raw_seq_len': seq_len,
        'end_date': end_date,
        'start_date': start_date,
    }


def _undefined_dataset_error(meta):
    """Build the error raised when ``meta['dataset']`` is not a supported name."""
    return ValueError(f'Undefined dataset config name: {meta["dataset"]}')


def GET_TrainLoader(meta, param, batch_size, in_len, out_len):
    """Build the (train, test) loader pair for the dataset named in ``meta``.

    Args:
        meta: Mapping whose ``'dataset'`` entry selects the dataset:
            ``'SEVIR'``, any name starting with ``'HKO'``, or ``'meteonet'``.
        param: Dataset-specific parameters. For HKO, ``param['pd_path']`` is
            the test pickle path; the train path is derived from it by
            replacing ``'test'`` with ``'train'``. For meteonet, every entry
            is forwarded as keyword arguments to ``dutils.load_meteonet``.
            Unused for SEVIR.
        batch_size: Training batch size. The SEVIR test loader and the
            meteonet validation loader use ``min(batch_size, 8)``.
        in_len: Number of input frames (SEVIR/HKO only).
        out_len: Number of target frames (SEVIR/HKO only).

    Returns:
        A ``(train_loader, test_loader)`` tuple.

    Raises:
        ValueError: If ``meta['dataset']`` is not a supported dataset name.
    """
    dataset = meta['dataset']
    eval_batch_size = min(batch_size, _MAX_EVAL_BATCH_SIZE)

    if dataset == 'SEVIR':
        total_seq_len = in_len + out_len
        # Train on everything before the split date, test on everything after.
        train_config = _sevir_config(
            total_seq_len,
            start_date=None,
            end_date=dutils.SEVIR_TRAIN_TEST_SPLIT_DATE,
        )
        test_config = _sevir_config(
            total_seq_len,
            start_date=dutils.SEVIR_TRAIN_TEST_SPLIT_DATE,
            end_date=None,
        )
        train_loader = dutils.SEVIRDataIterator(**train_config, batch_size=batch_size)
        test_loader = dutils.SEVIRDataIterator(**test_config, batch_size=eval_batch_size)
        return train_loader, test_loader

    if dataset.startswith('HKO'):
        total_seq_len = in_len + out_len
        pkl_path = param['pd_path']
        # NOTE(review): str.replace swaps *every* 'test' substring in the path,
        # not only the split name — confirm the path contains it exactly once.
        train_loader = HKOIterator(
            pd_path=pkl_path.replace('test', 'train'),
            sample_mode="random",
            seq_len=total_seq_len,
            stride=1,
        )
        # Sequential test sampling advances by in_len frames per step.
        test_loader = HKOIterator(
            pd_path=pkl_path,
            sample_mode="sequent",
            seq_len=total_seq_len,
            stride=in_len,
        )
        return train_loader, test_loader

    if dataset == 'meteonet':
        train_loader, test_loader = dutils.load_meteonet(
            batch_size=batch_size,
            val_batch_size=eval_batch_size,
            train=True,
            **param,
        )
        return train_loader, test_loader

    # The original referenced an undefined name here and raised NameError.
    raise _undefined_dataset_error(meta)


def GET_TestLoader(meta, param, batch_size):
    """Build the evaluation loader for the dataset named in ``meta``.

    Args:
        meta: Mapping whose ``'dataset'`` entry selects the dataset:
            ``'SEVIR'``, any name starting with ``'HKO'``, or ``'meteonet'``.
        param: Keyword arguments forwarded to ``dutils.SEVIRDataIterator``
            (SEVIR), ``HKOIterator`` (HKO) or ``dutils.load_meteonet``
            (meteonet).
        batch_size: Batch size for SEVIR and meteonet. It is not passed to
            ``HKOIterator``.

    Returns:
        For SEVIR, a ``SEVIRDataIterator``. For HKO, an ``HKOIterator``. For
        meteonet, ``iter()`` of the validation loader returned by
        ``dutils.load_meteonet`` (built with ``val_batch_size=8``).

    Raises:
        ValueError: If ``meta['dataset']`` is not a supported dataset name.
    """
    dataset = meta['dataset']

    if dataset == 'SEVIR':
        return dutils.SEVIRDataIterator(**param, batch_size=batch_size)

    if dataset.startswith('HKO'):
        return HKOIterator(**param)

    if dataset == 'meteonet':
        _, test_iter = dutils.load_meteonet(
            batch_size=batch_size, val_batch_size=8, train=False, **param
        )
        return iter(test_iter)

    # The original referenced an undefined name here and raised NameError.
    raise _undefined_dataset_error(meta)