code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified | """Build datasets / dataloaders from a Config, consistent across train & test.""" | |
| from __future__ import annotations | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data.distributed import DistributedSampler | |
| from .unified_dataset import UnifiedSegDataset | |
| from .transforms import build_transform | |
| from ..engine.distributed import is_dist | |
def build_dataset(cfg, split: str) -> UnifiedSegDataset:
    """Create the segmentation dataset for ``split`` and attach its transform.

    Synthetic images from ``cfg.synth_train_dir`` are only mixed in for the
    ``"train"`` split; every other split is given an empty ``synth_dir``.

    The dataset is instantiated with ``transform=None`` first and the
    transform is attached afterwards, so that the dataset's own
    in_channels/num_classes auto-detection runs before the transform is
    built from the resulting ``in_channels``.
    """
    is_train = split == "train"
    synth_dir = cfg.synth_train_dir if is_train else ""

    dataset = UnifiedSegDataset(
        data_root=cfg.data_root,
        dataset=cfg.dataset,
        protocol=cfg.protocol,
        split=split,
        transform=None,
        in_channels=cfg.in_channels,
        num_classes=cfg.num_classes,
        synth_dir=synth_dir,
    )

    # Build the transform from the dataset's resolved channel count rather
    # than the raw config value, which may have been auto-detected above.
    dataset.transform = build_transform(
        cfg.img_size,
        dataset.in_channels,
        train=is_train,
        aug=cfg.aug,
        normalize=cfg.normalize,
    )
    return dataset
def build_loader(cfg, split: str, ds: UnifiedSegDataset) -> DataLoader:
    """Wrap ``ds`` in a DataLoader configured for ``split``.

    Training (``split == "train"``) shuffles and drops the trailing partial
    batch; every other split iterates in order and keeps every sample.

    When ``is_dist()`` is true, a ``DistributedSampler`` shards the data
    across ranks. For training, callers should invoke
    ``loader.sampler.set_epoch(epoch)`` at the start of every epoch;
    otherwise every epoch reuses the same shuffle order.

    NOTE(review): for non-train splits under ``is_dist()``, the sampler is
    built with ``drop_last=False``, which pads the index list with repeated
    samples so it divides evenly across ranks. Distributed evaluation may
    therefore see a few samples twice — confirm the eval code de-duplicates
    or evaluates on a single rank.
    """
    train = split == "train"
    sampler = None
    if is_dist():
        sampler = DistributedSampler(ds, shuffle=train, drop_last=train)
    return DataLoader(
        ds,
        batch_size=cfg.batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler;
        # when a DistributedSampler is present it does the shuffling itself.
        shuffle=(train and sampler is None),
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=True,
        # Drop the incomplete final batch whenever training, with or without a
        # sampler. DistributedSampler(drop_last=True) only trims the dataset to
        # a multiple of the world size, so without this each rank could still
        # end an epoch on a tiny batch that single-process training would drop.
        drop_last=train,
        persistent_workers=cfg.num_workers > 0,
    )