| from pathlib import Path |
| from typing import Optional |
| from torchvision import transforms |
|
|
| from dataset import flairhub, urur, isic, crag, swissimage, swissimage_inference |
|
|
# Upper-case names accepted for ``cfg.dataset_name``. The dataset builders in
# this module upper-case the configured name before comparing, so matching is
# case-insensitive; this tuple is also listed in the "Supported: ..." part of
# the unknown-dataset error message.
DATASET_NAMES = tuple(
    "FLAIRHUB URUR ISIC CRAG SWISSIMAGE SWISSIMAGEINFERENCE".split()
)
|
|
def _build_split(
    name: str,
    cfg,
    base: Path,
    split: str,
    transform: Optional[transforms.Compose],
    augment=None,
):
    """Build one split of the dataset selected by ``name``.

    Single dispatch point shared by ``build_datasets`` and
    ``build_eval_dataset``. The dataset modules' ``build`` functions use one
    of three argument layouts, grouped below.

    Args:
        name: Upper-cased dataset name, e.g. ``"FLAIRHUB"``.
        cfg: Configuration object (presumably a namespace-like config; only
            attribute access is used). Reads ``<split>_img_subdir`` for every
            dataset, ``<split>_msk_subdir`` for datasets with masks, and
            ``num_classes`` / ``ignore_index`` for URUR, ISIC, CRAG and
            SWISSIMAGEINFERENCE.
        base: Root directory the configured sub-directories are joined onto.
        split: ``"train"``, ``"val"`` or ``"test"``. Selects which config
            attributes are read and is forwarded to builders that take a
            split name.
        transform: Transform forwarded to the dataset builder.
        augment: Forwarded as ``augment=`` to builders that accept it;
            SWISSIMAGEINFERENCE never receives it.

    Returns:
        Whatever the selected dataset module's ``build`` returns.

    Raises:
        ValueError: If ``name`` is not a supported dataset. The name is checked
            before any split-specific config attribute is read.
    """

    def subdir(kind: str) -> Path:
        # e.g. kind="img", split="val" -> base / cfg.val_img_subdir
        return base / getattr(cfg, f"{split}_{kind}_subdir")

    # Layout: build(img_dir, msk_dir, transform, augment=...)
    image_mask_builders = {"FLAIRHUB": flairhub, "SWISSIMAGE": swissimage}
    # Layout: build(img_dir, msk_dir, num_classes, split, ignore_index,
    #               transform, augment=...)
    classed_builders = {"URUR": urur, "ISIC": isic, "CRAG": crag}

    if name in image_mask_builders:
        return image_mask_builders[name].build(
            subdir("img"), subdir("msk"), transform, augment=augment
        )
    if name in classed_builders:
        return classed_builders[name].build(
            subdir("img"),
            subdir("msk"),
            cfg.num_classes,
            split,
            cfg.ignore_index,
            transform,
            augment=augment,
        )
    if name == "SWISSIMAGEINFERENCE":
        # Image-only layout: no mask directory and no augmentation argument.
        return swissimage_inference.build(
            subdir("img"), cfg.num_classes, split, cfg.ignore_index, transform
        )
    raise ValueError(
        f"Unknown dataset_name={cfg.dataset_name}. Supported: {', '.join(DATASET_NAMES)}"
    )


def build_datasets(cfg, transform: Optional[transforms.Compose], train_augment=None):
    """Build the train, validation and test datasets described by ``cfg``.

    ``cfg.dataset_name`` is upper-cased before matching, so it is
    case-insensitive. Sub-directories (``train_img_subdir``,
    ``val_msk_subdir``, ...) are resolved relative to ``cfg.data_path``.

    Args:
        cfg: Configuration object (presumably namespace-like; only attribute
            access is used).
        transform: Transform handed to all three splits.
        train_augment: Augmentation handed only to the training split (and
            not forwarded at all for SWISSIMAGEINFERENCE). Validation and test
            always get ``augment=None``.

    Returns:
        Tuple ``(ds_train, ds_val, ds_test)``.

    Raises:
        ValueError: If ``cfg.dataset_name`` is not a supported dataset.
    """
    base = Path(cfg.data_path)
    name = str(cfg.dataset_name).upper()

    # Built in train -> val -> test order, matching the returned tuple.
    ds_train = _build_split(name, cfg, base, "train", transform, augment=train_augment)
    ds_val = _build_split(name, cfg, base, "val", transform, augment=None)
    ds_test = _build_split(name, cfg, base, "test", transform, augment=None)
    return ds_train, ds_val, ds_test


def build_eval_dataset(cfg, split: str, transform: Optional[transforms.Compose]):
    """Build a single evaluation split (``"test"`` or ``"val"``) without augmentation.

    Both ``split`` (lower-cased) and ``cfg.dataset_name`` (upper-cased) are
    matched case-insensitively.

    Args:
        cfg: Configuration object (presumably namespace-like; only attribute
            access is used).
        split: ``"test"`` or ``"val"``, in any case.
        transform: Transform handed to the dataset builder.

    Returns:
        The dataset object built for the requested split.

    Raises:
        ValueError: If ``split`` is not ``"test"``/``"val"`` (checked first), or
            if ``cfg.dataset_name`` is not a supported dataset.
    """
    base = Path(cfg.data_path)
    name = str(cfg.dataset_name).upper()
    split = str(split).lower()
    if split not in {"test", "val"}:
        raise ValueError(f"split must be 'test' or 'val', got {split!r}")
    return _build_split(name, cfg, base, split, transform, augment=None)
|
|