# Image Segmentation
# English
# CASWiT / dataset / factory.py
# antoine.carreaud67
# Update with new datasets
# 9367521
from pathlib import Path
from typing import Optional
from torchvision import transforms
from dataset import flairhub, urur, isic, crag, swissimage, swissimage_inference
# Dataset identifiers accepted by the builders below. ``cfg.dataset_name`` is
# upper-cased before being compared, and this tuple is listed in the
# "Unknown dataset_name" error messages.
DATASET_NAMES = ("FLAIRHUB", "URUR", "ISIC", "CRAG", "SWISSIMAGE", "SWISSIMAGEINFERENCE")
def _build_dataset_split(cfg, name: str, base: Path, split: str,
                         transform: Optional[transforms.Compose], augment):
    """Build one split (``"train"``, ``"val"`` or ``"test"``) of dataset ``name``.

    Args:
        cfg: Configuration object. Reads ``<split>_img_subdir``; every dataset
            except ``SWISSIMAGEINFERENCE`` also reads ``<split>_msk_subdir``;
            ``URUR``/``ISIC``/``CRAG``/``SWISSIMAGEINFERENCE`` also read
            ``num_classes`` and ``ignore_index``.
        name: Upper-cased dataset name.
        base: Root directory the ``*_subdir`` values are joined onto.
        split: Split tag, forwarded to builders that accept one.
        transform: Forwarded unchanged to the dataset builder.
        augment: Forwarded as ``augment=`` to builders that accept it; ignored
            for ``SWISSIMAGEINFERENCE``.

    Raises:
        ValueError: If ``name`` is not a supported dataset. This is checked
            before any ``cfg`` sub-directory attribute is read.
    """
    mask_free = name == "SWISSIMAGEINFERENCE"
    # Builders with signature (img_dir, msk_dir, transform, augment=...).
    paired = {"FLAIRHUB": flairhub, "SWISSIMAGE": swissimage}
    # Builders that also take num_classes, the split tag and ignore_index.
    labelled = {"URUR": urur, "ISIC": isic, "CRAG": crag}
    if not mask_free and name not in paired and name not in labelled:
        raise ValueError(f"Unknown dataset_name={cfg.dataset_name}. Supported: {', '.join(DATASET_NAMES)}")

    img_dir = base / getattr(cfg, f"{split}_img_subdir")
    if mask_free:
        # Image-only builder: no mask directory and no augment argument.
        return swissimage_inference.build(img_dir, cfg.num_classes, split, cfg.ignore_index, transform)

    msk_dir = base / getattr(cfg, f"{split}_msk_subdir")
    if name in paired:
        return paired[name].build(img_dir, msk_dir, transform, augment=augment)
    return labelled[name].build(img_dir, msk_dir, cfg.num_classes, split, cfg.ignore_index,
                                transform, augment=augment)


def build_datasets(cfg, transform: Optional[transforms.Compose], train_augment=None):
    """Build the train, validation and test datasets described by ``cfg``.

    Args:
        cfg: Configuration object. Reads ``data_path`` and ``dataset_name``
            (matched case-insensitively against ``DATASET_NAMES``), plus the
            per-split ``*_img_subdir`` / ``*_msk_subdir`` and, for some
            datasets, ``num_classes`` / ``ignore_index``.
        transform: Forwarded unchanged to every split's builder.
        train_augment: Passed as ``augment=`` for the train split only. Val
            and test get ``augment=None``. It is not used for
            ``SWISSIMAGEINFERENCE``, whose builder takes no augment argument.

    Returns:
        A ``(ds_train, ds_val, ds_test)`` tuple of whatever the dataset
        module's ``build`` returns.

    Raises:
        ValueError: If ``cfg.dataset_name`` is not a supported dataset.
    """
    base = Path(cfg.data_path)
    name = str(cfg.dataset_name).upper()
    # Splits are built in train -> val -> test order, as before.
    ds_train = _build_dataset_split(cfg, name, base, "train", transform, train_augment)
    ds_val = _build_dataset_split(cfg, name, base, "val", transform, None)
    ds_test = _build_dataset_split(cfg, name, base, "test", transform, None)
    return ds_train, ds_val, ds_test
def build_eval_dataset(cfg, split: str, transform: Optional[transforms.Compose]):
    """Build a single evaluation dataset (``"val"`` or ``"test"``) for ``cfg``.

    Args:
        cfg: Configuration object.
            - Always reads ``data_path``, ``dataset_name`` (case-insensitive)
              and ``<split>_img_subdir``.
            - Every dataset except ``SWISSIMAGEINFERENCE`` also reads
              ``<split>_msk_subdir``.
            - ``URUR``/``ISIC``/``CRAG``/``SWISSIMAGEINFERENCE`` also read
              ``num_classes`` and ``ignore_index``.
        split: ``"val"`` or ``"test"``, case-insensitive.
        transform: Forwarded unchanged to the dataset builder.

    Returns:
        Whatever the matching dataset module's ``build`` returns. Builders
        that accept ``augment`` are called with ``augment=None``.

    Raises:
        ValueError: If ``split`` is not ``"val"``/``"test"``, or
            ``cfg.dataset_name`` is not a supported dataset. The split is
            checked first.
    """
    base = Path(cfg.data_path)
    name = str(cfg.dataset_name).upper()
    split = str(split).lower()
    if split not in {"test", "val"}:
        raise ValueError(f"split must be 'test' or 'val', got {split!r}")
    # From here on ``split`` is exactly "val" or "test" and is passed straight
    # through as the split tag for builders that take one.

    mask_free = name == "SWISSIMAGEINFERENCE"
    # Builders with signature (img_dir, msk_dir, transform, augment=...).
    paired = {"FLAIRHUB": flairhub, "SWISSIMAGE": swissimage}
    # Builders that also take num_classes, the split tag and ignore_index.
    labelled = {"URUR": urur, "ISIC": isic, "CRAG": crag}
    # Validate the name before reading any per-split cfg attribute.
    if not mask_free and name not in paired and name not in labelled:
        raise ValueError(f"Unknown dataset_name={cfg.dataset_name}. Supported: {', '.join(DATASET_NAMES)}")

    img_dir = base / getattr(cfg, f"{split}_img_subdir")
    if mask_free:
        # Image-only builder: no mask directory and no augment argument.
        return swissimage_inference.build(img_dir, cfg.num_classes, split, cfg.ignore_index, transform)

    msk_dir = base / getattr(cfg, f"{split}_msk_subdir")
    if name in paired:
        return paired[name].build(img_dir, msk_dir, transform, augment=None)
    return labelled[name].build(img_dir, msk_dir, cfg.num_classes, split, cfg.ignore_index,
                                transform, augment=None)