|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
from torch.utils.data import ConcatDataset |
|
|
|
|
|
from .base import BaseDataset |
|
|
|
|
|
|
|
|
def build_train(config, device): |
|
|
data_list = [] |
|
|
total_images = 0 |
|
|
for dataset in config.training_data: |
|
|
dataset_name = dataset.upper() |
|
|
config.n_train = np.Inf |
|
|
if type(dataset) is list: |
|
|
dataset_name, n_train = dataset |
|
|
config.n_train = n_train |
|
|
|
|
|
dataset = BaseDataset(name=dataset_name, config=config, device=device, isEval=False) |
|
|
data_list.append(dataset) |
|
|
total_images += dataset.total_images |
|
|
|
|
|
return ConcatDataset(data_list), total_images |
|
|
|
|
|
|
|
|
def build_val(config, device): |
|
|
data_list = [] |
|
|
total_images = 0 |
|
|
for dataset in config.eval_data: |
|
|
dataset_name = dataset.upper() |
|
|
config.n_train = np.Inf |
|
|
if type(dataset) is list: |
|
|
dataset_name, n_train = dataset |
|
|
config.n_train = n_train |
|
|
|
|
|
dataset = BaseDataset(name=dataset_name, config=config, device=device, isEval=True) |
|
|
data_list.append(dataset) |
|
|
total_images += dataset.total_images |
|
|
|
|
|
return ConcatDataset(data_list), total_images |
|
|
|