| """ |
| FLAIR French Land Cover Dataset |
| """ |
|
|
| import os |
| import numpy as np |
| from PIL import Image |
| import torch |
| from torch.utils.data import Dataset |
|
|
|
|
class FLAIRDataset(Dataset):
    """
    FLAIR French Land Cover dataset.

    15 classes (0-14); mask values >= 15 are mapped to ``ignore_index`` (255 by default).

    Images are read from ``<data_root>/<split>/img``. The mask for an image is the file in
    ``<data_root>/<split>/msk`` whose name has ``_RGBI_`` replaced by ``_LABEL-COSIA_``.
    Only the first three image bands are used.

    Args:
        data_root: Path to dataset root
        split: 'train', 'valid', or 'test'
        crop_size: Side length of the square crop (random when augmenting, centered
            otherwise). A falsy value disables cropping. Images smaller than
            ``crop_size`` in either spatial dimension are returned uncropped.
        augment: Whether to apply augmentations (auto-disabled for non-train splits)
        ignore_index: Label value to use for ignored classes

    Each item is an ``(img, msk)`` pair of tensors:
        - ``img``: float32, shape (3, H, W), normalized with ``MEAN``/``STD``.
        - ``msk``: int64, shape (H, W).
    """

    # Per-channel RGB mean/std on the 0-255 pixel scale (the common ImageNet statistics).
    MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

    # Class names; presumably list position == label value in the masks (TODO confirm).
    CLASSES = [
        'building', 'pervious', 'impervious', 'bare_soil', 'water',
        'coniferous', 'deciduous', 'brushwood', 'vineyard', 'herbaceous',
        'agricultural', 'plowed_land', 'swimming_pool', 'snow', 'greenhouse',
    ]

    # Random-crop search when augmenting:
    # - Try up to CROP_ATTEMPTS windows.
    # - Accept the first window that contains at least two classes and whose most
    #   frequent class covers less than CAT_MAX_RATIO of the non-ignored pixels.
    # - If no window qualifies, use the last one drawn.
    CROP_ATTEMPTS = 10
    CAT_MAX_RATIO = 0.75

    def __init__(self, data_root, split='train', crop_size=512, augment=True, ignore_index=255):
        self.data_root = data_root
        self.split = split
        self.crop_size = crop_size
        # Augmentation only ever applies to the training split.
        self.augment = augment and split == 'train'
        self.ignore_index = ignore_index

        self.img_dir = os.path.join(data_root, split, 'img')
        self.msk_dir = os.path.join(data_root, split, 'msk')
        # Sorted so the sample order does not depend on os.listdir's arbitrary order.
        self.img_files = sorted(os.listdir(self.img_dir))

    def __len__(self):
        return len(self.img_files)

    def _photometric_distortion(self, img):
        """Apply random brightness, contrast, saturation and hue jitter.

        Each of the four distortions is applied independently with probability 0.5.

        Args:
            img: HxWx3 float RGB array on the 0-255 scale.

        Returns:
            float32 HxWx3 RGB array clipped to [0, 255].
        """
        # Brightness: additive shift.
        if np.random.rand() > 0.5:
            delta = np.random.uniform(-32, 32)
            img = img + delta

        # Contrast: multiplicative scale.
        if np.random.rand() > 0.5:
            alpha = np.random.uniform(0.5, 1.5)
            img = img * alpha

        # Saturation and hue are edited in PIL's HSV space, whose channels are uint8 (0-255).
        img_uint8 = np.clip(img, 0, 255).astype(np.uint8)
        img_hsv = np.array(Image.fromarray(img_uint8).convert('HSV')).astype(np.float32)

        # Saturation: scale the S channel.
        if np.random.rand() > 0.5:
            img_hsv[:, :, 1] = img_hsv[:, :, 1] * np.random.uniform(0.5, 1.5)

        # Hue: shift the H channel, wrapping around the 0-255 hue circle.
        if np.random.rand() > 0.5:
            img_hsv[:, :, 0] = (img_hsv[:, :, 0] + np.random.uniform(-18, 18)) % 256

        img_hsv = np.clip(img_hsv, 0, 255).astype(np.uint8)
        img = np.array(Image.fromarray(img_hsv, mode='HSV').convert('RGB')).astype(np.float32)

        return np.clip(img, 0, 255)

    def _random_rotate(self, img, msk):
        """Rotate img and msk together by a random multiple of 90 degrees (possibly 0)."""
        k = np.random.choice([0, 1, 2, 3])
        if k > 0:
            # .copy() because rot90 returns a negative-stride view that torch cannot wrap.
            img = np.rot90(img, k).copy()
            msk = np.rot90(msk, k).copy()
        return img, msk

    def _pick_random_crop(self, msk, h, w):
        """Return the (top, left) corner of a random crop window, preferring class-diverse ones."""
        size = self.crop_size
        for _ in range(max(1, self.CROP_ATTEMPTS)):
            top = np.random.randint(0, h - size + 1)
            left = np.random.randint(0, w - size + 1)
            window = msk[top:top + size, left:left + size]
            _, counts = np.unique(window[window != self.ignore_index], return_counts=True)
            if len(counts) > 1 and counts.max() / counts.sum() < self.CAT_MAX_RATIO:
                break
        return top, left

    def _crop(self, img, msk):
        """Crop img (HxWxC) and msk (HxW) to the same crop_size x crop_size window.

        When augmenting, the window is random and prefers class-diverse regions.
        Otherwise the window is centered. The inputs are returned unchanged when
        cropping is disabled or when either spatial side is smaller than crop_size.
        """
        size = self.crop_size
        h, w = img.shape[:2]
        # Both sides must fit. A too-narrow width would otherwise make randint fail
        # (train) or yield a negative slice offset (eval).
        if not size or h < size or w < size:
            return img, msk
        if self.augment:
            top, left = self._pick_random_crop(msk, h, w)
        else:
            top, left = (h - size) // 2, (w - size) // 2
        window = (slice(top, top + size), slice(left, left + size))
        return img[window], msk[window]

    def __getitem__(self, idx):
        """Load sample ``idx``, then augment (train only), crop and normalize it."""
        name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, name)
        msk_path = os.path.join(self.msk_dir, name.replace('_RGBI_', '_LABEL-COSIA_'))

        # Context managers close the underlying files as soon as the pixels are read.
        with Image.open(img_path) as im:
            # Keep only the first three bands (presumably R, G, B given the _RGBI_ naming).
            img = np.asarray(im)[:, :, :3].astype(np.float32)
        with Image.open(msk_path) as im:
            msk = np.asarray(im).astype(np.int64)

        # Labels outside the known classes are mapped to ignore_index.
        msk[msk >= len(self.CLASSES)] = self.ignore_index

        # Photometric jitter works on 0-255 pixel values, so it runs before normalization.
        if self.augment:
            img = self._photometric_distortion(img)

        img, msk = self._crop(img, msk)

        # Normalization is per-pixel, so applying it after the crop yields identical
        # values while processing fewer pixels.
        img = (img - self.MEAN) / self.STD

        if self.augment and np.random.rand() > 0.5:
            img, msk = self._random_rotate(img, msk)

        if self.augment and np.random.rand() > 0.5:
            img = np.fliplr(img).copy()
            msk = np.fliplr(msk).copy()

        # HWC -> CHW. Contiguous copies give torch compact buffers, and the mask no longer
        # keeps the full-size pre-crop array alive.
        img = np.ascontiguousarray(img.transpose(2, 0, 1), dtype=np.float32)
        msk = np.ascontiguousarray(msk)
        return torch.from_numpy(img), torch.from_numpy(msk)
|
|