Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Dict, List, Tuple | |
| import numpy as np | |
| import rasterio | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from torch.utils.data import Dataset | |
| from config import DEFAULT_PATCH_SIZE, NUM_CHANNELS, NUM_CLASSES, IGNORE_INDEX, CLASS_NAMES | |
# HF dataset repo id (e.g. "user/dataset"); must be set via Space secrets.
DATASET_REPO = os.environ.get("DATASET_REPO", "")
# Auth token for private repos; empty string coerced to None for hf_hub_download.
HF_TOKEN = os.environ.get("HF_TOKEN", "") or None
# Seven single-band spectral rasters: H_1.tif .. H_7.tif.
BAND_FILES = [f"H_{i}.tif" for i in range(1, 8)]
# Label rasters: sparse training labels and the validation ground truth.
TRAIN_MASK_FILE = "TRAINING.tif"
VAL_MASK_FILE = "GROUND TRUTH.tif"
# ── File helpers ──────────────────────────────────────────────
def _download(filename: str) -> str:
    """Fetch *filename* from the configured HF dataset repo; return local path.

    Raises:
        EnvironmentError: if the DATASET_REPO secret is not configured.
    """
    if not DATASET_REPO:
        raise EnvironmentError("DATASET_REPO not set in Space secrets.")
    request = dict(
        repo_id=DATASET_REPO,
        filename=filename,
        repo_type="dataset",
        token=HF_TOKEN,
    )
    return hf_hub_download(**request)
def _read_band(path: str) -> np.ndarray:
    """Read band 1 of a raster as float32, replacing nodata pixels with NaN."""
    with rasterio.open(path) as src:
        band = src.read(1).astype(np.float32)
        nodata = src.nodata
    if nodata is not None:
        band[band == nodata] = np.nan
    return band
def _read_mask_raw(path: str) -> Tuple[np.ndarray, object, str]:
    """Returns (raw_array, nodata_value, info_string)."""
    with rasterio.open(path) as src:
        data, nodata = src.read(1), src.nodata
        info = (
            f"shape={src.shape} dtype={src.dtypes[0]} "
            f"nodata={nodata} bands={src.count}"
        )
    return data, nodata, info
| def _normalize(image: np.ndarray) -> np.ndarray: | |
| out = np.zeros_like(image, dtype=np.float32) | |
| for b in range(image.shape[0]): | |
| band = image[b] | |
| finite = band[np.isfinite(band)] | |
| if len(finite) == 0: | |
| continue | |
| lo, hi = np.percentile(finite, 2), np.percentile(finite, 98) | |
| out[b] = np.clip((np.nan_to_num(band) - lo) / max(hi - lo, 1e-6), 0.0, 1.0) | |
| return out | |
def _remap_mask(raw: np.ndarray, nodata_val) -> Tuple[np.ndarray, List[int]]:
    """
    Map raw pixel values to 0..NUM_CLASSES-1.
    Value 0 is treated as unlabeled background -> IGNORE_INDEX.
    Nodata pixels -> IGNORE_INDEX.
    Returns (remapped_mask, sorted_raw_class_values_used).
    """
    # Pixel value 0 is always unlabeled background.
    ignore_px = raw == 0
    if nodata_val is not None:
        ignore_px = ignore_px | (raw == int(nodata_val))
    labeled_values = sorted(int(v) for v in np.unique(raw[~ignore_px]))
    remapped = np.full(raw.shape, IGNORE_INDEX, dtype=np.int64)
    # Only the first NUM_CLASSES distinct raw values receive a class index;
    # any surplus values stay at IGNORE_INDEX.
    for cls_idx, raw_val in enumerate(labeled_values[:NUM_CLASSES]):
        remapped[raw == raw_val] = cls_idx
    return remapped, labeled_values
def _extract_patches(
    image: np.ndarray,
    mask: np.ndarray,
    patch_size: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Cut (image, mask) into overlapping square patches.

    Patches are sampled on a half-patch stride grid, always including the
    last valid position so raster edges are covered. Only patches containing
    at least one labeled pixel (!= IGNORE_INDEX) are kept. If nothing
    qualifies (e.g. the raster is smaller than *patch_size*), a single
    zero/IGNORE-padded patch is returned so callers never get an empty batch.

    Args:
        image: float array (C, H, W).
        mask: int array (H', W'); may differ slightly in size from *image*.
        patch_size: side length of the square patches.

    Returns:
        (patches, patch_masks): float32 (N, C, p, p) and int64 (N, p, p).
    """
    _, H_img, W_img = image.shape
    H_msk, W_msk = mask.shape
    # Image and mask rasters can disagree by a few pixels; crop to the overlap.
    H = min(H_img, H_msk)
    W = min(W_img, W_msk)
    # 50% overlap; max(1, ...) prevents a zero stride when patch_size == 1,
    # which would make range(..., step=0) raise ValueError.
    stride = max(1, patch_size // 2)
    imgs, masks = [], []
    # Build step lists that always include the last valid position (covers edges)
    def steps(size):
        s = list(range(0, size - patch_size + 1, stride))
        if not s:
            s = [0] if size >= patch_size else []
        elif s[-1] < size - patch_size:
            s.append(size - patch_size)
        return s
    for y in steps(H):
        for x in steps(W):
            pm = mask[y : y + patch_size, x : x + patch_size]
            pi = image[:, y : y + patch_size, x : x + patch_size]
            # Include any patch that contains at least one labeled pixel
            if pm.shape == (patch_size, patch_size) and (pm != IGNORE_INDEX).any():
                imgs.append(pi)
                masks.append(pm)
    # Last resort: pad with zeros/IGNORE if image is smaller than patch_size
    if not imgs:
        ph = min(patch_size, H)
        pw = min(patch_size, W)
        img_pad = np.zeros((image.shape[0], patch_size, patch_size), dtype=np.float32)
        mask_pad = np.full((patch_size, patch_size), IGNORE_INDEX, dtype=np.int64)
        img_pad[:, :ph, :pw] = image[:, :ph, :pw]
        mask_pad[:ph, :pw] = mask[:ph, :pw]
        imgs.append(img_pad)
        masks.append(mask_pad)
    return np.stack(imgs).astype(np.float32), np.stack(masks).astype(np.int64)
# ── Dataset class ─────────────────────────────────────────────
class MultiSpectralDataset(Dataset):
    """In-memory Dataset over pre-cut patch arrays.

    Holds images as float32 (N, C, p, p) and masks as int64 (N, p, p);
    items are returned as torch tensors sharing the numpy buffers.
    """

    def __init__(self, images: np.ndarray, masks: np.ndarray):
        # Coerce dtypes once up front so __getitem__ stays conversion-free.
        self.images = images.astype(np.float32)
        self.masks = masks.astype(np.int64)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int):
        patch, target = self.images[idx], self.masks[idx]
        return torch.from_numpy(patch), torch.from_numpy(target)
# ── Public API ────────────────────────────────────────────────
def load_data(patch_size: int = DEFAULT_PATCH_SIZE) -> Dict:
    """Download all rasters, normalize, remap labels, and cut patches.

    Returns a dict with keys "train_images", "train_masks", "val_images",
    "val_masks" (numpy patch arrays) and "status" (a markdown summary).

    Raises:
        ValueError: if TRAINING.tif has no labeled pixels after nodata
            and background removal.
    """
    # Download and stack bands
    band_arrays = [_read_band(_download(f)) for f in BAND_FILES]
    image = _normalize(np.stack(band_arrays, axis=0))  # (7, H, W) float32
    # Read raw masks + metadata
    raw_train, nd_train, info_train = _read_mask_raw(_download(TRAIN_MASK_FILE))
    raw_val, nd_val, info_val = _read_mask_raw(_download(VAL_MASK_FILE))
    # Remap to 0-indexed classes
    train_mask, train_vals = _remap_mask(raw_train, nd_train)
    val_mask, val_vals = _remap_mask(raw_val, nd_val)
    # Fail loudly with diagnostics when the training mask is empty —
    # downstream training is meaningless without labels.
    if not train_vals:
        raise ValueError(
            f"TRAINING.tif has no labeled pixels after nodata removal. "
            f"File info: {info_train} | Unique raw values: {np.unique(raw_train).tolist()}"
        )
    # Extract patches
    tr_imgs, tr_masks = _extract_patches(image, train_mask, patch_size)
    va_imgs, va_masks = _extract_patches(image, val_mask, patch_size)
    # Labeled-pixel totals for the status report (IGNORE_INDEX excluded).
    train_labeled = int((train_mask != IGNORE_INDEX).sum())
    val_labeled = int((val_mask != IGNORE_INDEX).sum())
    # Per-class pixel counts and percentages, formatted for markdown.
    def _class_dist(mask, total):
        parts = []
        for i, name in enumerate(CLASS_NAMES):
            n = int((mask == i).sum())
            parts.append(f"{name}: {n:,} ({n / max(1, total) * 100:.1f}%)")
        return " | ".join(parts)
    # Markdown status block rendered by the Space UI.
    status = "\n".join([
        f"Train patches: **{len(tr_imgs)}** | Val patches: **{len(va_imgs)}** | Patch: **{patch_size}Γ{patch_size}**",
        "",
        f"**TRAINING.tif** `{info_train}`",
        f"Raw values β classes: `{dict(zip(train_vals, CLASS_NAMES[:len(train_vals)]))}`",
        f"Labeled pixels: **{train_labeled:,}** β {_class_dist(train_mask, train_labeled)}",
        "",
        f"**GROUND TRUTH.tif** `{info_val}`",
        f"Raw values β classes: `{dict(zip(val_vals, CLASS_NAMES[:len(val_vals)]))}`",
        f"Labeled pixels: **{val_labeled:,}** β {_class_dist(val_mask, val_labeled)}",
    ])
    return {
        "train_images": tr_imgs,
        "train_masks": tr_masks,
        "val_images": va_imgs,
        "val_masks": va_masks,
        "status": status,
    }