TEAMS / lib /datasets /make_dataset.py
Richard-ZZZZZ's picture
Upload folder using huggingface_hub
e168a4d verified
from .transforms import make_transforms
from . import samplers
from .dataset_catalog import DatasetCatalog
import torch
import torch.utils.data
import imp
import os
from .collate_batch import make_collator
torch.multiprocessing.set_sharing_strategy('file_system')
def _dataset_factory(data_source, task):
module = '.'.join(['lib.datasets', data_source, task])
path = os.path.join('lib/datasets', data_source, task+'.py')
dataset = imp.load_source(module, path).Dataset # lib/datasets/sbd/snake.py
return dataset
# 动态加载类, 函数返回 Dataset 类,而不是类的实例,用的时候要创建实例化对象。
def make_dataset(cfg, dataset_name, transforms, is_train=True):
args = DatasetCatalog.get(dataset_name)
''' :args 长这个样子: 'SbdTrain': {
'id': 'sbd',
'data_root': 'data/sbd/img',
'ann_file': 'data/sbd/annotations/sbd_train_instance.json',
'split': 'train'
},'''
data_source = args['id']
dataset = _dataset_factory(data_source, cfg.task)
del args['id'] # 删除args字典中的元素(id为键)
# args['cfg'] = cfg
# args['transforms'] = transforms
# args['is_train'] = is_train
dataset = dataset(**args) # 创建 Dataset 类的实例化对象
return dataset
def make_data_sampler(dataset, shuffle):
if shuffle:
sampler = torch.utils.data.sampler.RandomSampler(dataset)
else:
sampler = torch.utils.data.sampler.SequentialSampler(dataset)
return sampler
def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter):
batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size, drop_last)
if max_iter != -1:
batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, max_iter)
return batch_sampler
def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1):
if is_train:
batch_size = cfg.train.batch_size
shuffle = True
drop_last = False
else:
batch_size = cfg.test.batch_size
shuffle = True if is_distributed else False
drop_last = False
dataset_name = cfg.train.dataset if is_train else cfg.test.dataset
transforms = make_transforms(cfg, is_train)
dataset = make_dataset(cfg, dataset_name, transforms, is_train)
sampler = make_data_sampler(dataset, shuffle)
batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter)
num_workers = cfg.train.num_workers
collator = make_collator(cfg)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=collator
)
return data_loader