# Antoine1091's picture
# Upload folder using huggingface_hub
# 49d2955 verified
"""
FLAIR French Land Cover Dataset
"""
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
class FLAIRDataset(Dataset):
    """
    FLAIR French Land Cover dataset.

    15 classes (0-14), classes >= 15 are mapped to ignore_index (255).

    Images are read from ``<data_root>/<split>/img`` and masks from
    ``<data_root>/<split>/msk``; the mask file name is the image file name
    with ``_RGBI_`` replaced by ``_LABEL-COSIA_``.

    Each item is a ``(image, mask)`` pair of tensors: ``image`` is float32
    with layout (3, H, W), normalized with ``MEAN``/``STD``; ``mask`` is
    int64 with layout (H, W).

    Args:
        data_root: Path to dataset root
        split: 'train', 'valid', or 'test'
        crop_size: Size of random/center crop. A falsy value disables
            cropping; tiles smaller than the crop in either dimension are
            returned uncropped.
        augment: Whether to apply augmentations (auto-disabled for non-train splits)
        ignore_index: Label value to use for ignored classes
    """

    # ImageNet normalization (per-channel, on the 0-255 scale)
    MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

    # Class names, indexed by label value
    CLASSES = [
        'building', 'pervious', 'impervious', 'bare_soil', 'water',
        'coniferous', 'deciduous', 'brushwood', 'vineyard', 'herbaceous',
        'agricultural', 'plowed_land', 'swimming_pool', 'snow', 'greenhouse'
    ]

    # Random-crop search: number of attempts, and the largest share of valid
    # pixels a single class may occupy for a crop to be accepted early.
    CROP_ATTEMPTS = 10
    CAT_MAX_RATIO = 0.75

    def __init__(self, data_root, split='train', crop_size=512, augment=True, ignore_index=255):
        self.data_root = data_root
        self.split = split
        self.crop_size = crop_size
        # Augmentations are only ever active on the training split.
        self.augment = augment and split == 'train'
        self.ignore_index = ignore_index
        self.img_dir = os.path.join(data_root, split, 'img')
        self.msk_dir = os.path.join(data_root, split, 'msk')
        # Sorted so that index -> file mapping is deterministic across runs.
        self.img_files = sorted(os.listdir(self.img_dir))

    def __len__(self) -> int:
        """Return the number of image files found in the split directory."""
        return len(self.img_files)

    def _photometric_distortion(self, img):
        """Apply photometric distortion (brightness, contrast, saturation, hue).

        Each of the four distortions is applied independently with
        probability 0.5. ``img`` is expected on the 0-255 scale (it is
        clipped to that range before the HSV round trip).

        Returns:
            float32 array of the same shape, clipped to [0, 255].
        """
        # Random brightness: additive shift
        if np.random.rand() > 0.5:
            delta = np.random.uniform(-32, 32)
            img = img + delta
        # Random contrast: multiplicative scale
        if np.random.rand() > 0.5:
            alpha = np.random.uniform(0.5, 1.5)
            img = img * alpha
        # Convert to HSV for saturation and hue. Note the image is quantized
        # to uint8 here even when neither saturation nor hue ends up changed.
        img_uint8 = np.clip(img, 0, 255).astype(np.uint8)
        img_hsv = np.array(Image.fromarray(img_uint8).convert('HSV')).astype(np.float32)
        # Random saturation
        if np.random.rand() > 0.5:
            img_hsv[:, :, 1] = img_hsv[:, :, 1] * np.random.uniform(0.5, 1.5)
        # Random hue: the H channel is stored in 0-255, so wrap around at 256
        if np.random.rand() > 0.5:
            img_hsv[:, :, 0] = (img_hsv[:, :, 0] + np.random.uniform(-18, 18)) % 256
        # Convert back to RGB
        img_hsv = np.clip(img_hsv, 0, 255).astype(np.uint8)
        img = np.array(Image.fromarray(img_hsv, mode='HSV').convert('RGB')).astype(np.float32)
        return np.clip(img, 0, 255)

    def _random_rotate(self, img, msk):
        """Rotate image and mask together by a random multiple of 90 degrees.

        The multiple is drawn uniformly from {0, 1, 2, 3}; 0 leaves both
        arrays untouched. Rotated arrays are copied so they own their memory.
        """
        k = np.random.choice([0, 1, 2, 3])
        if k > 0:
            img = np.rot90(img, k).copy()
            msk = np.rot90(msk, k).copy()
        return img, msk

    def _crop(self, img, msk):
        """Crop image and mask to ``crop_size`` x ``crop_size``.

        With augmentation on, up to ``CROP_ATTEMPTS`` random windows are
        drawn and the first one that contains more than one valid class with
        no class covering ``CAT_MAX_RATIO`` or more of the valid pixels is
        used; if none qualifies, the last window drawn is used. Without
        augmentation a center crop is taken.

        The inputs are returned unchanged when ``crop_size`` is falsy or the
        tile is smaller than the crop in either dimension.
        """
        size = self.crop_size
        h, w = img.shape[:2]
        # Both dimensions must fit: checking the height alone would make
        # np.random.randint raise (random crop) or yield a negative offset
        # (center crop) on tiles narrower than the crop.
        if not size or h < size or w < size:
            return img, msk

        if self.augment:
            # Try to find a crop with good class coverage (cat_max_ratio logic)
            for _ in range(self.CROP_ATTEMPTS):
                top = np.random.randint(0, h - size + 1)
                left = np.random.randint(0, w - size + 1)
                candidate = msk[top:top + size, left:left + size]
                valid = candidate[candidate != self.ignore_index]
                if valid.size == 0:
                    continue
                _, counts = np.unique(valid, return_counts=True)
                if len(counts) > 1 and counts.max() / counts.sum() < self.CAT_MAX_RATIO:
                    break
        else:
            # Center crop for validation
            top = (h - size) // 2
            left = (w - size) // 2

        window = (slice(top, top + size), slice(left, left + size))
        return img[window], msk[window]

    def __getitem__(self, idx):
        """Load, augment and normalize the ``idx``-th image/mask pair.

        Returns:
            Tuple ``(image, mask)`` of contiguous tensors: float32 (3, H, W)
            and int64 (H, W).
        """
        img_name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        msk_path = os.path.join(self.msk_dir, img_name.replace('_RGBI_', '_LABEL-COSIA_'))

        # Context managers close the underlying files right after decoding
        # instead of leaving the handles open until garbage collection.
        with Image.open(img_path) as img_file:
            # Keep only the first three channels (presumably RGB of an RGBI
            # tile, judging by the file naming -- confirm against the data).
            img = np.array(img_file).astype(np.float32)[:, :, :3]
        with Image.open(msk_path) as msk_file:
            msk = np.array(msk_file).astype(np.int64)

        # Remap classes: keep 0-14, map >=15 to ignore_index
        msk[msk >= 15] = self.ignore_index

        # Apply photometric distortion BEFORE normalization (it works on the
        # 0-255 scale)
        if self.augment:
            img = self._photometric_distortion(img)

        # Random/center crop
        img, msk = self._crop(img, msk)

        # Normalize. Done after the crop: the operation is element-wise, so
        # the values are identical but only the kept pixels are processed.
        img = (img - self.MEAN) / self.STD

        # Random rotation
        if self.augment and np.random.rand() > 0.5:
            img, msk = self._random_rotate(img, msk)

        # Random horizontal flip
        if self.augment and np.random.rand() > 0.5:
            img = np.fliplr(img).copy()
            msk = np.fliplr(msk).copy()

        # HWC -> CHW. Contiguous copies give dense tensors and release the
        # full-size arrays that the cropped views would otherwise keep alive.
        img_chw = np.ascontiguousarray(img.transpose(2, 0, 1), dtype=np.float32)
        return torch.from_numpy(img_chw), torch.from_numpy(np.ascontiguousarray(msk))