import os import random import numpy as np from glob import glob import rasterio from rasterio.windows import Window import torch from torch.utils.data import Dataset # ----------------------------- # Utility Functions # ----------------------------- def normalize_band(img, mean, std): """Min-max normalization using mean ± 2sigma .""" min_v = mean - 2 * std max_v = mean + 2 * std img = (img - min_v) / (max_v - min_v + 1e-6) return np.clip(img, 0, 1).astype(np.float32) class GeoAugment: """Random flips + 90° rotations.""" def __init__(self, rotate=True, flip=True): self.rotate = rotate self.flip = flip def __call__(self, x, y): # Horizontal flip if self.flip and random.random() < 0.5: x = np.flip(x, axis=2).copy() y = np.flip(y, axis=2).copy() # Vertical flip if self.flip and random.random() < 0.5: x = np.flip(x, axis=1).copy() y = np.flip(y, axis=1).copy() # Rotations if self.rotate: k = random.choice([0, 1, 2, 3]) if k > 0: x = np.rot90(x, k, axes=(1, 2)).copy() y = np.rot90(y, k, axes=(1, 2)).copy() return x, y # ----------------------------- # Dataset Class # ----------------------------- class SatellitePatchDataset(Dataset): """ Multi-modal satellite dataset loader (S1, S2, DEM). Train/val/test split must be performed by selecting `locations`. """ def __init__( self, root, locations, patch_size=256, stride=None, skip_empty=True, empty_tile_ratio=0.0, task='segmentation', dates=None, masking_ratio=0.5, transform=None, band_stats=None, # { 'S1': {mean:[], std:[]}, 'S2': {...}, 'DEM': {...} } ch_s1=[0, 1], # chanell 0 is VV, channel 1 is VH ch_s2=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # channel 0-11 are B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, NDWI, NDSI ch_dem=[0], # channel 0 is elevation ch_hillshade=[0], # channel 0 is hillshade ch_cloudmask=[0], ): self.root = root self.locations = locations self.patch_size = patch_size self.stride = stride or patch_size self.skip_empty = skip_empty self.empty_tile_ratio = empty_tile_ratio self.task = task self.masking_ratio = masking_ratio self.transform = transform self.dates = dates self.ch_s1 = ch_s1 self.ch_s2 = ch_s2 self.ch_dem = ch_dem self.ch_hillshade = ch_hillshade self.ch_cloudmask = ch_cloudmask self.band_stats = band_stats self.samples = [] # (date, location) self.patch_index = [] # (sample_id, x, y) if task not in ['segmentation', 'mae']: raise ValueError(f"Unsupported task: {task}") self._discover_samples() self._index_patches() # --------------------------------------------------------- # Scan dataset and find all valid (date, location) pairs # --------------------------------------------------------- def _discover_samples(self): for loc in self.locations: loc_dir = os.path.join(self.root, loc) mask_paths = sorted(glob(os.path.join(loc_dir, "*_lake_mask.tif"))) mask_paths = [p for p in mask_paths if os.path.basename(p).split("_")[0] in self.dates] if self.dates is not None else mask_paths for path in mask_paths: basename = os.path.basename(path) date = basename.split("_")[0] if self.ch_s1 is not None and len(self.ch_s1) > 0: s1_path = os.path.join(loc_dir, f"{date}_{loc}_s1.tif") if not os.path.exists(s1_path): continue if self.ch_s2 is not None and len(self.ch_s2) > 0: s2_path = os.path.join(loc_dir, f"{date}_{loc}_s2.tif") if not os.path.exists(s2_path): continue if self.ch_dem is not None and len(self.ch_dem) > 0: dem_path = os.path.join(loc_dir, f"{loc}_dem.tif") if not os.path.exists(dem_path): continue if self.ch_hillshade is not None and len(self.ch_hillshade) > 0: hillshade_path = os.path.join(loc_dir, f"{date}_{loc}_hillshade.tif") if not os.path.exists(hillshade_path): continue if self.ch_cloudmask is not None and len(self.ch_cloudmask) > 0: cloudmask_path = os.path.join(loc_dir, f"{date}_{loc}_cloud_mask.tif") if not os.path.exists(cloudmask_path): continue self.samples.append((date, loc)) # --------------------------------------------------------- # Build patch index (optionally skip empty mask) # --------------------------------------------------------- def _index_patches(self): for i, (date, loc) in enumerate(self.samples): mask_path = os.path.join(self.root, loc, f"{date}_{loc}_lake_mask.tif") ordered_patches = [] empty_indices = [] with rasterio.open(mask_path) as msk: H, W = msk.height, msk.width for y in range(0, H, self.stride): for x in range(0, W, self.stride): patch = msk.read( 1, window=Window(x, y, self.patch_size, self.patch_size), boundless=True, fill_value=0, ) is_empty = np.all(patch == 0) ordered_patches.append((i, x, y, is_empty)) if is_empty: empty_indices.append(len(ordered_patches) - 1) # Decide which empty patches to keep keep_empty = set() if not self.skip_empty: if self.empty_tile_ratio > 0: k = int((len(ordered_patches) - len(empty_indices)) * self.empty_tile_ratio) k = min(k, len(empty_indices)) keep_empty = set(empty_indices[:k]) # deterministic else: keep_empty = set(empty_indices) for i, (idx, x, y, is_empty) in enumerate(ordered_patches): if is_empty and i not in keep_empty: continue self.patch_index.append((idx, x, y)) def reconstruct_image(self, patches, sample_id): """Reconstruct full image from patches for a given sample_id.""" date, loc = self.samples[sample_id] loc_dir = os.path.join(self.root, loc) # Load one band to get image size with rasterio.open(os.path.join(loc_dir, f"{date}_{loc}_lake_mask.tif")) as src: H, W = src.height + self.patch_size - 1, src.width + self.patch_size - 1 full_image = np.zeros((patches.shape[1], H, W), dtype=patches.dtype) count_image = np.zeros((H, W), dtype=np.float32) patch_idx = 0 for y in range(0, H - self.patch_size + 1, self.stride): for x in range(0, W - self.patch_size + 1, self.stride): if patch_idx >= patches.shape[0]: break full_image[:, y:y+self.patch_size, x:x+self.patch_size] += patches[patch_idx] count_image[y:y+self.patch_size, x:x+self.patch_size] += 1.0 patch_idx += 1 count_image[count_image == 0] = 1.0 # avoid division by zero full_image /= count_image[None, :, :] return full_image[:, :src.height, :src.width] # --------------------------------------------------------- # PyTorch Dataset API # --------------------------------------------------------- def __len__(self): return len(self.patch_index) def __getitem__(self, idx): sample_id, x0, y0 = self.patch_index[idx] date, loc = self.samples[sample_id] loc_dir = os.path.join(self.root, loc) window = Window(x0, y0, self.patch_size, self.patch_size) # --------------------------- # Load modalities # --------------------------- channels = [] if self.ch_s1 is not None and len(self.ch_s1) > 0: channels.append(self._load_and_normalize(os.path.join(loc_dir, f"{date}_{loc}_s1.tif"), "S1", self.ch_s1, window)) if self.ch_s2 is not None and len(self.ch_s2) > 0: channels.append(self._load_and_normalize(os.path.join(loc_dir, f"{date}_{loc}_s2.tif"), "S2", self.ch_s2, window)) if self.ch_dem is not None and len(self.ch_dem) > 0: channels.append(self._load_and_normalize(os.path.join(loc_dir, f"{loc}_dem.tif"), "DEM", self.ch_dem, window)) if self.ch_hillshade is not None and len(self.ch_hillshade) > 0: channels.append(self._load_and_normalize(os.path.join(loc_dir, f"{date}_{loc}_hillshade.tif"), "Hillshade", self.ch_hillshade, window)) if self.ch_cloudmask is not None and len(self.ch_cloudmask) > 0: channels.append(self._load_and_normalize(os.path.join(loc_dir, f"{date}_{loc}_cloud_mask.tif"), "Cloudmask", self.ch_cloudmask, window)) x = np.concatenate(channels, axis=0) # Mask mask_path = os.path.join(loc_dir, f"{date}_{loc}_lake_mask.tif") with rasterio.open(mask_path) as src: y = src.read(1, window=window, boundless=True, fill_value=0).astype(np.float32)[None, ...] y = (y > 0).astype(np.float32) # Apply augmentation if self.transform: x, y = self.transform(x, y) if self.task == 'segmentation': return torch.from_numpy(x), torch.from_numpy(y) if self.task == 'mae': B, H, W = x.shape # mask the image with the given ratio mask_size = 8 num_patches = (H // mask_size) * (W // mask_size) num_masked = int(num_patches * self.masking_ratio) mask = np.hstack([ np.ones(num_masked, dtype=np.float32), np.zeros(num_patches - num_masked, dtype=np.float32), ]) np.random.shuffle(mask) mask = mask.reshape(H // mask_size, W // mask_size) mask = np.kron(mask, np.ones((mask_size, mask_size), dtype=np.float32)) # Upsample to pixel level masked_image = x * (1 - mask[None, :, :]) return torch.from_numpy(masked_image), torch.from_numpy(x) return torch.from_numpy(x), torch.from_numpy(y) # --------------------------------------------------------- # Loading with optional per-band normalization # --------------------------------------------------------- def _load_and_normalize(self, path, key, channels, window): with rasterio.open(path) as src: arr = src.read([c + 1 for c in channels], window=window, boundless=True, fill_value=0).astype(np.float32) arr[~np.isfinite(arr)] = 0 if self.band_stats and key in self.band_stats: means = [self.band_stats[key]["mean"][c] for c in channels] stds = [self.band_stats[key]["std"][c] for c in channels] for i in range(arr.shape[0]): arr[i] = normalize_band(arr[i], means[i], stds[i]) return arr