| | import sys |
| | sys.path.append('rscd') |
| | from typing import Iterable, Optional, Sequence, Union |
| | from torch.utils.data import DataLoader, Dataset, Sampler |
| | from torch.utils.data.dataloader import _collate_fn_t, _worker_init_fn_t |
| | from rscd.datasets.levircd_dataset import * |
| | from rscd.datasets.whucd_dataset import * |
| | from rscd.datasets.dsifn_dataset import * |
| | from rscd.datasets.clcd_dataset import * |
| | from rscd.datasets.sysucd_dataset import * |
| | from rscd.datasets.base_dataset import * |
| |
|
def get_loader(dataset, cfg):
    """Wrap ``dataset`` in a ``DataLoader`` configured from ``cfg``.

    Args:
        dataset: Dataset object handed to ``DataLoader`` unchanged.
        cfg: Object exposing ``batch_size``, ``num_workers``, ``pin_memory``,
            ``shuffle`` and ``drop_last`` attributes.

    Returns:
        The constructed ``DataLoader``.
    """
    option_names = (
        "batch_size",
        "num_workers",
        "pin_memory",
        "shuffle",
        "drop_last",
    )
    # Each loader option is read from the attribute of the same name on cfg.
    loader_options = {name: getattr(cfg, name) for name in option_names}
    return DataLoader(dataset=dataset, **loader_options)
| |
|
| | |
def build_dataloader(cfg, mode='train'):
    """Instantiate the configured dataset for ``mode`` and wrap it in a ``DataLoader``.

    Args:
        cfg: Dataset configuration exposing ``type`` (evaluated to obtain the
            dataset callable), ``data_root`` and the per-mode sections
            ``train_mode``, ``val_mode`` and ``test_mode``. The selected
            section is unpacked as keyword arguments for the dataset
            callable and must also expose a ``loader`` attribute carrying
            ``batch_size``, ``num_workers``, ``pin_memory``, ``shuffle`` and
            ``drop_last``.
        mode: ``'train'`` and ``'val'`` select the matching section; any
            other value selects ``test_mode``. The value itself is passed
            to the dataset callable unchanged.

    Returns:
        A ``DataLoader`` over the newly created dataset.
    """
    dataset_type = cfg.type
    data_root = cfg.data_root

    # NOTE(review): eval() runs whatever expression the config supplies; this
    # is only safe while config files are trusted. Consider an explicit
    # name -> class registry if configs can come from outside the project.
    dataset_cls = eval(dataset_type)

    # Anything that is not 'train' or 'val' falls back to the test section.
    section_by_mode = {'train': 'train_mode', 'val': 'val_mode'}
    mode_cfg = getattr(cfg, section_by_mode.get(mode, 'test_mode'))

    dataset = dataset_cls(data_root, mode, **mode_cfg)
    loader_cfg = mode_cfg.loader

    return DataLoader(
        dataset=dataset,
        batch_size=loader_cfg.batch_size,
        num_workers=loader_cfg.num_workers,
        pin_memory=loader_cfg.pin_memory,
        shuffle=loader_cfg.shuffle,
        drop_last=loader_cfg.drop_last,
    )
| |
|
if __name__ == '__main__':
    # Manual smoke test: load a config, build the training loader and print
    # the shape of the first image tensor for a handful of batches.
    #
    # The config path may be given as the first command-line argument; when
    # it is omitted the developer-machine default below is used.
    default_config_path = "E:/zjuse/2308CD/rschangedetection/configs/BIT.py"
    file_path = sys.argv[1] if len(sys.argv) > 1 else default_config_path
    print(file_path)

    # Imported only when the module is executed as a script.
    from utils.config import Config

    cfg = Config.fromfile(file_path)
    print(cfg)
    train_loader = build_dataloader(cfg.dataset_config)

    # Each batch unpacks into (imgA, imgB, tar). Stop after batch index 10,
    # i.e. at most 11 shapes are printed.
    for batch_idx, (imgA, imgB, tar) in enumerate(train_loader):
        print(imgA.shape)
        if batch_idx >= 10:
            break
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|