| from data_provider.data_loader import ( | |
| Dataset_ETT_hour, | |
| Dataset_ETT_minute, | |
| Dataset_Custom, | |
| Dataset_Pred, | |
| ) | |
| from torch.utils.data import DataLoader | |
# Registry from the ``args.data`` name to the Dataset class that loads it.
# Datasets sharing a loader are grouped; insertion order matches the listing.
data_dict = {
    **dict.fromkeys(("ETTh1", "ETTh2"), Dataset_ETT_hour),
    **dict.fromkeys(("ETTm1", "ETTm2"), Dataset_ETT_minute),
    **dict.fromkeys(("WTH", "ECL", "Solar", "custom"), Dataset_Custom),
}
def data_provider(args, flag):
    """Build the dataset and its DataLoader for one split.

    Args:
        args: Experiment configuration object. Reads ``data``, ``inverse``,
            ``scale``, ``batch_size`` and ``num_workers``, and optionally
            ``test_drop_last``. The whole object is also forwarded to the
            dataset constructor.
        flag: Split name. ``"test"`` and ``"pred"`` get dedicated loader
            settings; any other value is treated as a training-style split
            (shuffled, final partial batch dropped).

    Returns:
        A ``(data_set, data_loader)`` tuple.

    Raises:
        ValueError: If ``args.inverse`` is enabled while ``args.scale`` is not.
        KeyError: If ``args.data`` is not a key of ``data_dict`` (checked for
            every flag except ``"pred"``, which always uses ``Dataset_Pred``).
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if args.inverse and not args.scale:
        raise ValueError("Can't enable inverse without enabling scale")

    if flag == "pred":
        # Prediction: dedicated dataset class, one sample per batch, in order.
        Data = Dataset_Pred
        shuffle_flag = False
        drop_last = False
        batch_size = 1
    else:
        try:
            Data = data_dict[args.data]
        except KeyError:
            raise KeyError(
                f"Unknown dataset {args.data!r}; expected one of {sorted(data_dict)}"
            ) from None
        batch_size = args.batch_size
        if flag == "test":
            shuffle_flag = False
            # Defaults to the historical behavior of dropping the final partial
            # batch. That excludes up to ``batch_size - 1`` test samples from
            # evaluation; set ``args.test_drop_last = False`` to keep them all.
            drop_last = getattr(args, "test_drop_last", True)
        else:
            shuffle_flag = True
            drop_last = True

    data_set = Data(args, flag=flag)
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
    )
    return data_set, data_loader