MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
1.48 kB
"""Build datasets / dataloaders from a Config, consistent across train & test."""
from __future__ import annotations
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from .unified_dataset import UnifiedSegDataset
from .transforms import build_transform
from ..engine.distributed import is_dist
def build_dataset(cfg, split: str) -> UnifiedSegDataset:
    """Instantiate the UnifiedSegDataset for ``split`` and attach its transform.

    Args:
        cfg: Config object. Reads ``data_root``, ``dataset``, ``protocol``,
            ``in_channels``, ``num_classes``, ``img_size``, ``aug`` and
            ``normalize``; ``synth_train_dir`` is read only for the train split.
        split: Split name. Only ``"train"`` adds the synthetic-data directory
            and builds the transform in training mode.

    Returns:
        The dataset, with its ``transform`` attribute populated.
    """
    is_train = split == "train"
    synth_dir = cfg.synth_train_dir if is_train else ""

    # Build the dataset untransformed first: the transform is sized from
    # dataset.in_channels, which the dataset resolves itself (auto-detection
    # of in_channels/num_classes happens at construction).
    dataset = UnifiedSegDataset(
        data_root=cfg.data_root,
        dataset=cfg.dataset,
        protocol=cfg.protocol,
        split=split,
        transform=None,
        in_channels=cfg.in_channels,
        num_classes=cfg.num_classes,
        synth_dir=synth_dir,
    )

    dataset.transform = build_transform(
        cfg.img_size,
        dataset.in_channels,
        train=is_train,
        aug=cfg.aug,
        normalize=cfg.normalize,
    )
    return dataset
def build_loader(cfg, split: str, ds: UnifiedSegDataset) -> DataLoader:
    """Wrap ``ds`` in a DataLoader configured for ``split``.

    When running distributed (``is_dist()`` is true), a DistributedSampler
    shards the dataset across ranks. Otherwise the DataLoader shuffles on
    its own.

    Args:
        cfg: Config object; reads ``batch_size`` and ``num_workers``.
        split: Split name. ``"train"`` enables shuffling and dropping of the
            incomplete final batch; any other split is iterated in order with
            every batch kept.
        ds: Dataset to load from.

    Returns:
        The configured DataLoader.
    """
    train = split == "train"
    sampler = None
    if is_dist():
        # With drop_last=True, DistributedSampler drops the dataset tail so it
        # divides evenly across ranks. With drop_last=False (non-train splits),
        # it pads by repeating indices instead, so metrics gathered across
        # ranks may count a few samples twice.
        # NOTE(review): confirm the eval pipeline accounts for this padding.
        # DistributedSampler also needs sampler.set_epoch(epoch) to be called
        # each epoch (not visible here) for the shuffle order to change.
        sampler = DistributedSampler(ds, shuffle=train, drop_last=train)
    return DataLoader(
        ds,
        batch_size=cfg.batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler.
        shuffle=train and sampler is None,
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=True,
        # Drop the trailing partial batch whenever training, with or without a
        # sampler. The sampler's drop_last only equalizes per-rank sample
        # counts; it does not make them a multiple of batch_size. Because every
        # rank has the same sample count, all ranks still yield the same number
        # of batches.
        drop_last=train,
        persistent_workers=cfg.num_workers > 0,
    )