# jaredhwang's picture
# Migrate benchmark from https://github.com/kitamoto-lab/benchmarks/
# 02cdcbc
from torch.utils.data import DataLoader
# Directory (relative to the current working directory) that holds the ID
# list files, keyed by the ``type_save`` argument of ``load``.
_SAVE_DIRS = {
    'standard': 'save/',
    'same_size': 'save_same/',
}

# File-name prefixes to read for each ``type`` argument of ``load``. Every
# prefix contributes ``<prefix>_train.txt`` and ``<prefix>_val.txt``; when
# several prefixes are listed their ID lists are concatenated in this order.
_SPLIT_PREFIXES = {
    0: ('old',),
    1: ('recent',),
    2: ('now',),
    3: ('now', 'recent'),
}


def _read_ids(path):
    """Return the lines of the text file at ``path`` with newlines removed."""
    with open(path, 'r') as file:
        return [line.replace('\n', '') for line in file]


def load(type, dataset, batch_size, num_workers, type_save='standard'):
    """Build the train and test ``DataLoader`` objects for a saved split.

    The IDs of each split are read, one per line, from text files stored in
    ``save/`` (``type_save='standard'``) or ``save_same/``
    (``type_save='same_size'``).

    Args:
        type: Which split files to use (the name shadows the builtin but is
            kept for backward compatibility):
            0 -> ``old_*``, 1 -> ``recent_*``, 2 -> ``now_*``,
            3 -> ``now_*`` followed by ``recent_*``.
        dataset: Object exposing ``images_from_sequences(ids)``; whatever it
            returns is handed to ``DataLoader`` as the dataset.
        batch_size: Forwarded to both ``DataLoader`` instances.
        num_workers: Forwarded to both ``DataLoader`` instances.
        type_save: ``'standard'`` or ``'same_size'``; selects the directory
            the ID files are read from.

    Returns:
        A ``(train, test)`` tuple of ``DataLoader`` objects. The train loader
        shuffles; the test loader does not.

    Raises:
        ValueError: If ``type`` or ``type_save`` is not one of the supported
            values.
        OSError: If one of the ID files cannot be opened.
    """
    # Validate up front so a bad argument fails with a clear message instead
    # of an UnboundLocalError further down.
    if type_save not in _SAVE_DIRS:
        raise ValueError(
            f"unknown type_save {type_save!r}; "
            f"expected one of {sorted(_SAVE_DIRS)}"
        )
    if type not in _SPLIT_PREFIXES:
        raise ValueError(
            f"unknown type {type!r}; "
            f"expected one of {sorted(_SPLIT_PREFIXES)}"
        )

    file_dir = _SAVE_DIRS[type_save]

    train_id, test_id = [], []
    for prefix in _SPLIT_PREFIXES[type]:
        train_id += _read_ids(f'{file_dir}{prefix}_train.txt')
        test_id += _read_ids(f'{file_dir}{prefix}_val.txt')

    train = DataLoader(
        dataset.images_from_sequences(train_id),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    test = DataLoader(
        dataset.images_from_sequences(test_id),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    return train, test