"""Dataset loading/patching utilities for 7-band multi-spectral segmentation.

Downloads band rasters and label masks from a Hugging Face dataset repo,
normalizes the bands per-band, remaps raw mask values to contiguous class
indices, and cuts the scene into overlapping train/val patches.
"""

import os
from typing import Dict, List, Tuple

import numpy as np
import rasterio
import torch
from huggingface_hub import hf_hub_download
from torch.utils.data import Dataset

from config import DEFAULT_PATCH_SIZE, NUM_CHANNELS, NUM_CLASSES, IGNORE_INDEX, CLASS_NAMES

DATASET_REPO = os.environ.get("DATASET_REPO", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "") or None  # empty string → no token

BAND_FILES = [f"H_{i}.tif" for i in range(1, 8)]  # H_1.tif .. H_7.tif
TRAIN_MASK_FILE = "TRAINING.tif"
VAL_MASK_FILE = "GROUND TRUTH.tif"


# ── File helpers ─────────────────────────────────────────────
def _download(filename: str) -> str:
    """Fetch *filename* from the configured HF dataset repo; return local path.

    Raises:
        EnvironmentError: if DATASET_REPO is not configured.
    """
    if not DATASET_REPO:
        raise EnvironmentError("DATASET_REPO not set in Space secrets.")
    return hf_hub_download(
        repo_id=DATASET_REPO,
        filename=filename,
        repo_type="dataset",
        token=HF_TOKEN,
    )


def _read_band(path: str) -> np.ndarray:
    """Read band 1 of a raster as float32, replacing nodata pixels with NaN."""
    with rasterio.open(path) as src:
        data = src.read(1).astype(np.float32)
        if src.nodata is not None:
            data[data == src.nodata] = np.nan
    return data


def _read_mask_raw(path: str) -> Tuple[np.ndarray, object, str]:
    """Read band 1 of a mask raster.

    Returns (raw_array, nodata_value, info_string) where info_string is a
    human-readable summary of the raster's metadata.
    """
    with rasterio.open(path) as src:
        data = src.read(1)
        nodata = src.nodata
        info = f"shape={src.shape} dtype={src.dtypes[0]} nodata={nodata} bands={src.count}"
    return data, nodata, info


def _normalize(image: np.ndarray) -> np.ndarray:
    """Per-band robust min-max normalization of a (C, H, W) array to [0, 1].

    Each band is scaled by its 2nd..98th percentile range (computed over
    finite pixels only, so NaN nodata doesn't skew the percentiles); NaNs
    are zero-filled before scaling and the result is clipped to [0, 1].
    Bands with no finite pixels stay all-zero.
    """
    out = np.zeros_like(image, dtype=np.float32)
    for b in range(image.shape[0]):
        band = image[b]
        finite = band[np.isfinite(band)]
        if len(finite) == 0:
            continue
        lo, hi = np.percentile(finite, 2), np.percentile(finite, 98)
        # max(..., 1e-6) guards against a zero range for constant bands
        out[b] = np.clip((np.nan_to_num(band) - lo) / max(hi - lo, 1e-6), 0.0, 1.0)
    return out


def _remap_mask(raw: np.ndarray, nodata_val) -> Tuple[np.ndarray, List[int]]:
    """
    Map raw pixel values to 0..NUM_CLASSES-1.
    Value 0 is treated as unlabeled background → IGNORE_INDEX.
    Nodata pixels → IGNORE_INDEX.
    Returns (remapped_mask, sorted_raw_class_values_used).
    """
    # FIX: compare against the nodata value directly instead of int(nodata_val),
    # which raised ValueError when nodata is NaN (common for float rasters) and
    # mis-compared non-integral float nodata values.
    nodata_px = np.zeros(raw.shape, dtype=bool)
    if nodata_val is not None:
        if np.issubdtype(raw.dtype, np.floating) and np.isnan(nodata_val):
            nodata_px = np.isnan(raw)
        else:
            nodata_px = raw == nodata_val

    # Treat pixel value 0 as unlabeled background
    background_px = raw == 0
    ignore_px = nodata_px | background_px
    valid = ~ignore_px

    raw_unique = sorted(int(v) for v in np.unique(raw[valid]))

    mask = np.full(raw.shape, IGNORE_INDEX, dtype=np.int64)
    # Only the first NUM_CLASSES distinct raw values get a class index; any
    # extra values remain IGNORE_INDEX.
    for cls_idx, raw_val in enumerate(raw_unique[:NUM_CLASSES]):
        mask[raw == raw_val] = cls_idx
    return mask, raw_unique


def _extract_patches(
    image: np.ndarray,
    mask: np.ndarray,
    patch_size: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Cut (image, mask) into 50%-overlapping square patches.

    Only patches containing at least one labeled (non-IGNORE_INDEX) pixel
    are kept. If no such patch exists (e.g. the scene is smaller than
    patch_size), a single zero/IGNORE-padded patch is returned so the
    result is never empty.

    Args:
        image: (C, H, W) float array.
        mask:  (H', W') int array; may differ slightly in size from image —
               both are cropped to the common overlap.
        patch_size: side length of the square patches.

    Returns:
        (patch_images, patch_masks) as float32 (N, C, p, p) and
        int64 (N, p, p) arrays.
    """
    _, H_img, W_img = image.shape
    H_msk, W_msk = mask.shape
    H = min(H_img, H_msk)
    W = min(W_img, W_msk)

    # FIX: clamp the stride to at least 1 — for patch_size == 1,
    # patch_size // 2 == 0 and range(..., 0) raises ValueError.
    stride = max(1, patch_size // 2)
    imgs, masks = [], []

    # Build step lists that always include the last valid position (covers edges)
    def steps(size):
        s = list(range(0, size - patch_size + 1, stride))
        if not s:
            s = [0] if size >= patch_size else []
        elif s[-1] < size - patch_size:
            s.append(size - patch_size)
        return s

    for y in steps(H):
        for x in steps(W):
            pm = mask[y : y + patch_size, x : x + patch_size]
            pi = image[:, y : y + patch_size, x : x + patch_size]
            # Include any patch that contains at least one labeled pixel
            if pm.shape == (patch_size, patch_size) and (pm != IGNORE_INDEX).any():
                imgs.append(pi)
                masks.append(pm)

    # Last resort: pad with zeros/IGNORE if the image is smaller than
    # patch_size (or no patch contained a labeled pixel).
    if not imgs:
        ph = min(patch_size, H)
        pw = min(patch_size, W)
        img_pad = np.zeros((image.shape[0], patch_size, patch_size), dtype=np.float32)
        mask_pad = np.full((patch_size, patch_size), IGNORE_INDEX, dtype=np.int64)
        img_pad[:, :ph, :pw] = image[:, :ph, :pw]
        mask_pad[:ph, :pw] = mask[:ph, :pw]
        imgs.append(img_pad)
        masks.append(mask_pad)

    return np.stack(imgs).astype(np.float32), np.stack(masks).astype(np.int64)


# ── Dataset class ───────────────────────────────────────────────
class MultiSpectralDataset(Dataset):
    """Wraps pre-extracted (image, mask) patch arrays as a torch Dataset."""

    def __init__(self, images: np.ndarray, masks: np.ndarray):
        self.images = images.astype(np.float32)  # (N, C, p, p)
        self.masks = masks.astype(np.int64)  # (N, p, p)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx: int):
        return torch.from_numpy(self.images[idx]), torch.from_numpy(self.masks[idx])


# ── Public API ──────────────────────────────────────────────────
def load_data(patch_size: int = DEFAULT_PATCH_SIZE) -> Dict:
    """Download, normalize, remap, and patch the full dataset.

    Args:
        patch_size: side length of the extracted square patches.

    Returns:
        Dict with keys "train_images", "train_masks", "val_images",
        "val_masks" (numpy arrays) and "status" (markdown summary string).

    Raises:
        ValueError: if TRAINING.tif has no labeled pixels after remapping.
    """
    # Download and stack bands
    band_arrays = [_read_band(_download(f)) for f in BAND_FILES]
    image = _normalize(np.stack(band_arrays, axis=0))  # (7, H, W) float32

    # Read raw masks + metadata
    raw_train, nd_train, info_train = _read_mask_raw(_download(TRAIN_MASK_FILE))
    raw_val, nd_val, info_val = _read_mask_raw(_download(VAL_MASK_FILE))

    # Remap to 0-indexed classes
    train_mask, train_vals = _remap_mask(raw_train, nd_train)
    val_mask, val_vals = _remap_mask(raw_val, nd_val)

    if not train_vals:
        raise ValueError(
            f"TRAINING.tif has no labeled pixels after nodata removal. "
            f"File info: {info_train} | Unique raw values: {np.unique(raw_train).tolist()}"
        )

    # Extract patches
    tr_imgs, tr_masks = _extract_patches(image, train_mask, patch_size)
    va_imgs, va_masks = _extract_patches(image, val_mask, patch_size)

    train_labeled = int((train_mask != IGNORE_INDEX).sum())
    val_labeled = int((val_mask != IGNORE_INDEX).sum())

    def _class_dist(mask, total):
        # Per-class pixel counts + percentage of labeled pixels, for the status text
        parts = []
        for i, name in enumerate(CLASS_NAMES):
            n = int((mask == i).sum())
            parts.append(f"{name}: {n:,} ({n / max(1, total) * 100:.1f}%)")
        return " | ".join(parts)

    status = "\n".join([
        f"Train patches: **{len(tr_imgs)}** | Val patches: **{len(va_imgs)}** | Patch: **{patch_size}×{patch_size}**",
        "",
        f"**TRAINING.tif** `{info_train}`",
        f"Raw values → classes: `{dict(zip(train_vals, CLASS_NAMES[:len(train_vals)]))}`",
        f"Labeled pixels: **{train_labeled:,}** — {_class_dist(train_mask, train_labeled)}",
        "",
        f"**GROUND TRUTH.tif** `{info_val}`",
        f"Raw values → classes: `{dict(zip(val_vals, CLASS_NAMES[:len(val_vals)]))}`",
        f"Labeled pixels: **{val_labeled:,}** — {_class_dist(val_mask, val_labeled)}",
    ])

    return {
        "train_images": tr_imgs,
        "train_masks": tr_masks,
        "val_images": va_imgs,
        "val_masks": va_masks,
        "status": status,
    }