"""Factory that builds the dataset and DataLoader for a given split."""

from data_provider.data_loader import (
    Dataset_ETT_hour,
    Dataset_ETT_minute,
    Dataset_Custom,
    Dataset_Pred,
)
from torch.utils.data import DataLoader

# Maps ``args.data`` to the dataset class constructed for every split except
# "pred", which always uses Dataset_Pred (see data_provider below).
data_dict = {
    "ETTh1": Dataset_ETT_hour,
    "ETTh2": Dataset_ETT_hour,
    "ETTm1": Dataset_ETT_minute,
    "ETTm2": Dataset_ETT_minute,
    "WTH": Dataset_Custom,
    "ECL": Dataset_Custom,
    "Solar": Dataset_Custom,
    "custom": Dataset_Custom,
}


def data_provider(args, flag):
    """Build the dataset and DataLoader for one split.

    Args:
        args: Experiment configuration object. This function reads ``data``,
            ``inverse``, ``scale``, ``batch_size`` and ``num_workers`` from it,
            and passes the whole object to the dataset constructor.
        flag: Split name. "test" disables shuffling and drops the last partial
            batch. "pred" uses Dataset_Pred with batch size 1, no shuffling,
            and keeps the last partial batch. Any other value (presumably
            "train"/"val") shuffles and drops the last partial batch.

    Returns:
        A ``(data_set, data_loader)`` tuple.

    Raises:
        KeyError: If ``args.data`` is not a key of ``data_dict``. This is
            checked for every flag, including "pred".
        ValueError: If ``args.inverse`` is enabled while ``args.scale`` is not.
    """
    try:
        Data = data_dict[args.data]
    except KeyError:
        raise KeyError(
            f"Unknown dataset {args.data!r}; expected one of {sorted(data_dict)}"
        ) from None

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an invalid configuration through.
    if args.inverse and not args.scale:
        raise ValueError("Can't enable inverse without enabling scale")

    if flag == "test":
        shuffle_flag = False
        # NOTE(review): drop_last=True discards the final partial test batch,
        # so those samples are excluded from evaluation. Confirm this is intended.
        drop_last = True
        batch_size = args.batch_size
    elif flag == "pred":
        shuffle_flag = False
        drop_last = False  # keep every prediction window
        batch_size = 1
        Data = Dataset_Pred
    else:
        shuffle_flag = True
        drop_last = True
        batch_size = args.batch_size

    data_set = Data(args, flag=flag)
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
    )
    return data_set, data_loader