# TerraMind-HYPERVIEW / datasets / hyperview_datamodule.py
# Source: KPLabs — "Upload folder using huggingface_hub" (revision 87904b0, verified)
from typing import Any, cast
import albumentations as A
from albumentations.core.composition import TransformType
from albumentations.pytorch import ToTensorV2
from torch import Tensor
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
try:
from datasets.hyperview_dataset import HyperviewNonGeo
except ModuleNotFoundError:
from hyperview_dataset import HyperviewNonGeo
# Evaluation pipeline: deterministic 224x224 resize, then tensor conversion.
eval_transforms = [A.Resize(224, 224), ToTensorV2()]
eval_transform = A.Compose(cast(list[TransformType], eval_transforms))

# Training pipeline: resize plus stochastic geometric/noise augmentations,
# finishing with conversion to a torch tensor.
train_transforms = [
    A.Resize(224, 224),
    A.GaussNoise(std_range=(0.0, 0.005), mean_range=(0.0, 0.0), p=0.5),
    A.ElasticTransform(p=0.25),
    A.RandomRotate90(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5),
    ToTensorV2(),
]
train_full_transform = A.Compose(cast(list[TransformType], train_transforms))
class HyperviewNonGeoDataModule(NonGeoDataModule):
    """DataModule for Hyperview non-geospatial hyperspectral cubes.

    Wires :class:`HyperviewNonGeo` into per-split datasets and builds the
    corresponding DataLoaders, with configurable albumentations pipelines
    for train / val / test.
    """

    def __init__(
        self,
        data_root: str,
        label_train_path: str | None = None,
        label_test_path: str | None = None,
        batch_size: int = 4,
        num_workers: int = 0,
        train_transform: A.Compose | None = None,
        val_transform: A.Compose | None = None,
        test_transform: A.Compose | None = None,
        target_index: list[int] | None = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        fill_last_with_150: bool = True,
        drop_last: bool = False,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        mask: str = "none",
        num_bands: int = 12,
        label_scaling: str = "std",
        split: str = "test",
        aoi: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize datamodule with paths, transforms and split configuration.

        Args:
            data_root: Root directory of the Hyperview data.
            label_train_path: Label file path forwarded to train/val datasets.
            label_test_path: Label file path forwarded to the test dataset.
            batch_size: Batch size handed to ``NonGeoDataModule``.
            num_workers: DataLoader worker count.
            train_transform: Training pipeline; defaults to the module-level
                ``train_full_transform`` when ``None``.
            val_transform: Validation pipeline; defaults to ``eval_transform``.
            test_transform: Test pipeline; defaults to ``eval_transform``.
            target_index: Target indices to keep; ``None`` means ``[0, 1, 2, 3]``.
            cropped_load: Forwarded to the dataset (semantics defined there).
            val_ratio: Train/val split ratio forwarded to the dataset.
            random_state: Seed forwarded to the dataset split logic.
            fill_last_with_150: Dataset option — TODO confirm semantics in
                ``HyperviewNonGeo``.
            drop_last: If True, the *training* loader drops its last partial batch.
            exclude_11x11: Dataset filtering option (see ``HyperviewNonGeo``).
            only_11x11: Dataset filtering option (see ``HyperviewNonGeo``).
            mask: Masking mode string forwarded to the dataset.
            num_bands: Number of spectral bands forwarded to the dataset.
            label_scaling: Label scaling mode forwarded to the dataset.
            split: Split name used for the *test* dataset (default ``"test"``).
            aoi: Optional AOI identifier, applied to the test dataset only.
            **kwargs: Extra keyword arguments for ``NonGeoDataModule``.
        """
        super().__init__(HyperviewNonGeo, batch_size, num_workers, **kwargs)
        # Fall back to the module-level pipelines when none were supplied.
        if train_transform is None:
            train_transform = train_full_transform
        if val_transform is None:
            val_transform = eval_transform
        if test_transform is None:
            test_transform = eval_transform
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform
        # Paths and split configuration.
        self.data_root = data_root
        self.label_train_path = label_train_path
        self.label_test_path = label_test_path
        self.split = split
        self.aoi = aoi
        # Dataset construction options, forwarded verbatim in _build_dataset.
        self.target_index = target_index if target_index is not None else [0, 1, 2, 3]
        self.cropped_load = cropped_load
        self.val_ratio = val_ratio
        self.random_state = random_state
        self.fill_last_with_150 = fill_last_with_150
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.mask = mask
        self.num_bands = num_bands
        self.label_scaling = label_scaling
        # Loader behavior.
        self.drop_last = drop_last

    def _build_dataset(
        self,
        split: str,
        label_path: str | None,
        transform: A.Compose | None,
        aoi: str | None,
    ) -> HyperviewNonGeo:
        """Instantiate :class:`HyperviewNonGeo` for one split.

        All options that are identical across splits are taken from the
        datamodule's stored configuration; only split, labels, transform
        and AOI vary per call.
        """
        shared_kwargs: dict[str, Any] = {
            "data_root": self.data_root,
            "target_index": self.target_index,
            "cropped_load": self.cropped_load,
            "val_ratio": self.val_ratio,
            "random_state": self.random_state,
            "fill_last_with_150": self.fill_last_with_150,
            "exclude_11x11": self.exclude_11x11,
            "only_11x11": self.only_11x11,
            "mask": self.mask,
            "num_bands": self.num_bands,
            "label_scaling": self.label_scaling,
        }
        return HyperviewNonGeo(
            split=split,
            label_path=label_path,
            transform=transform,
            aoi=aoi,
            **shared_kwargs,
        )

    def setup(self, stage: str | None = None) -> None:
        """Instantiate the datasets needed for *stage* (all when ``None``)."""
        if stage in (None, "fit", "train"):
            self.train_dataset = self._build_dataset(
                split="train",
                label_path=self.label_train_path,
                transform=self.train_transform,
                aoi=None,
            )
        if stage in (None, "fit", "validate", "val"):
            self.val_dataset = self._build_dataset(
                split="val",
                label_path=self.label_train_path,
                transform=self.val_transform,
                aoi=None,
            )
        if stage in (None, "test", "predict"):
            # The test dataset honors the configured split name and AOI.
            self.test_dataset = self._build_dataset(
                split=self.split,
                label_path=self.label_test_path,
                transform=self.test_transform,
                aoi=self.aoi,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Build the DataLoader for the given split name."""
        ds = self._valid_attribute(f"{split}_dataset", "dataset")
        bs = self._valid_attribute(f"{split}_batch_size", "batch_size")
        is_train = split == "train"
        return DataLoader(
            dataset=ds,
            batch_size=bs,
            # Only the training loader shuffles / may drop the last batch.
            shuffle=is_train,
            num_workers=self.num_workers,
            drop_last=is_train and self.drop_last,
            collate_fn=self.collate_fn,
        )