"""Build datasets / dataloaders from a Config, consistent across train & test."""
from __future__ import annotations

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from .unified_dataset import UnifiedSegDataset
from .transforms import build_transform
from ..engine.distributed import is_dist


def build_dataset(cfg, split: str) -> UnifiedSegDataset:
    """Construct the ``UnifiedSegDataset`` for ``split`` and attach its transform.

    Args:
        cfg: Config object. Reads ``data_root``, ``dataset``, ``protocol``,
            ``in_channels``, ``num_classes``, ``synth_train_dir``,
            ``img_size``, ``aug`` and ``normalize``.
        split: Split name. Only ``"train"`` is special-cased: it enables the
            synthetic-data directory and training-mode transforms. Every
            other value is treated as an evaluation split.

    Returns:
        The dataset with ``transform`` already set.
    """
    train = split == "train"
    # Synthetic samples are only passed in for the training split.
    synth = cfg.synth_train_dir if train else ""
    # Construct without a transform first so the in_channels/num_classes
    # auto-detection runs; the transform is then built from the detected
    # ``ds.in_channels`` rather than the raw config value.
    ds = UnifiedSegDataset(
        data_root=cfg.data_root,
        dataset=cfg.dataset,
        protocol=cfg.protocol,
        split=split,
        transform=None,
        in_channels=cfg.in_channels,
        num_classes=cfg.num_classes,
        synth_dir=synth,
    )
    ds.transform = build_transform(
        cfg.img_size,
        ds.in_channels,
        train=train,
        aug=cfg.aug,
        normalize=cfg.normalize,
    )
    return ds


def build_loader(cfg, split: str, ds: UnifiedSegDataset) -> DataLoader:
    """Wrap ``ds`` in a ``DataLoader``, using a ``DistributedSampler`` when distributed.

    Args:
        cfg: Config object. Reads ``batch_size`` and ``num_workers``.
        split: Split name. ``"train"`` enables shuffling and drops the final
            incomplete batch; any other split is iterated in order, in full.
        ds: Dataset to load from, typically produced by ``build_dataset``.

    Returns:
        The configured ``DataLoader``. When ``is_dist()`` is true its
        ``sampler`` is a ``DistributedSampler``. The caller should invoke
        ``loader.sampler.set_epoch(epoch)`` each epoch, otherwise every epoch
        reuses the same shuffle order.

    Note:
        NOTE(review): for non-train splits under distributed mode,
        ``DistributedSampler(drop_last=False)`` pads the index list with
        duplicates so every rank gets the same count. Aggregated eval metrics
        may therefore count a few samples twice; confirm whether the caller
        de-duplicates.
    """
    train = split == "train"
    sampler = None
    if is_dist():
        # The sampler owns shuffling and per-rank sharding in distributed mode.
        sampler = DistributedSampler(ds, shuffle=train, drop_last=train)
    return DataLoader(
        ds,
        batch_size=cfg.batch_size,
        # ``shuffle`` and a custom ``sampler`` are mutually exclusive in DataLoader.
        shuffle=(train and sampler is None),
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=True,
        # Unlike ``shuffle``, ``drop_last`` is compatible with a custom
        # sampler. Gating it on ``sampler is None`` let distributed training
        # emit a short final batch per rank while single-process training
        # dropped it. Drop it in both modes so training batches are uniformly
        # sized.
        drop_last=train,
        # Persistent workers require at least one worker process.
        persistent_workers=cfg.num_workers > 0,
    )