# segspace_app / data.py
# Commit da69452: fix mask interpretation — treat value 0 as background, not Water.
import os
from typing import Dict, List, Tuple
import numpy as np
import rasterio
import torch
from huggingface_hub import hf_hub_download
from torch.utils.data import Dataset
from config import DEFAULT_PATCH_SIZE, NUM_CHANNELS, NUM_CLASSES, IGNORE_INDEX, CLASS_NAMES
DATASET_REPO = os.environ.get("DATASET_REPO", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "") or None
BAND_FILES = [f"H_{i}.tif" for i in range(1, 8)]
TRAIN_MASK_FILE = "TRAINING.tif"
VAL_MASK_FILE = "GROUND TRUTH.tif"
# ── File helpers ─────────────────────────────────────────────
def _download(filename: str) -> str:
if not DATASET_REPO:
raise EnvironmentError("DATASET_REPO not set in Space secrets.")
return hf_hub_download(
repo_id=DATASET_REPO,
filename=filename,
repo_type="dataset",
token=HF_TOKEN,
)
def _read_band(path: str) -> np.ndarray:
with rasterio.open(path) as src:
data = src.read(1).astype(np.float32)
if src.nodata is not None:
data[data == src.nodata] = np.nan
return data
def _read_mask_raw(path: str) -> Tuple[np.ndarray, object, str]:
"""Returns (raw_array, nodata_value, info_string)."""
with rasterio.open(path) as src:
data = src.read(1)
nodata = src.nodata
info = f"shape={src.shape} dtype={src.dtypes[0]} nodata={nodata} bands={src.count}"
return data, nodata, info
def _normalize(image: np.ndarray) -> np.ndarray:
out = np.zeros_like(image, dtype=np.float32)
for b in range(image.shape[0]):
band = image[b]
finite = band[np.isfinite(band)]
if len(finite) == 0:
continue
lo, hi = np.percentile(finite, 2), np.percentile(finite, 98)
out[b] = np.clip((np.nan_to_num(band) - lo) / max(hi - lo, 1e-6), 0.0, 1.0)
return out
def _remap_mask(raw: np.ndarray, nodata_val) -> Tuple[np.ndarray, List[int]]:
"""
Map raw pixel values to 0..NUM_CLASSES-1.
Value 0 is treated as unlabeled background β†’ IGNORE_INDEX.
Nodata pixels β†’ IGNORE_INDEX.
Returns (remapped_mask, sorted_raw_class_values_used).
"""
if nodata_val is not None:
nodata_px = raw == int(nodata_val)
else:
nodata_px = np.zeros(raw.shape, dtype=bool)
# Treat pixel value 0 as unlabeled background
background_px = raw == 0
ignore_px = nodata_px | background_px
valid = ~ignore_px
raw_unique = sorted(int(v) for v in np.unique(raw[valid]))
mask = np.full(raw.shape, IGNORE_INDEX, dtype=np.int64)
for cls_idx, raw_val in enumerate(raw_unique[:NUM_CLASSES]):
mask[raw == raw_val] = cls_idx
return mask, raw_unique
def _extract_patches(
image: np.ndarray,
mask: np.ndarray,
patch_size: int,
) -> Tuple[np.ndarray, np.ndarray]:
_, H_img, W_img = image.shape
H_msk, W_msk = mask.shape
H = min(H_img, H_msk)
W = min(W_img, W_msk)
stride = patch_size // 2
imgs, masks = [], []
# Build step lists that always include the last valid position (covers edges)
def steps(size):
s = list(range(0, size - patch_size + 1, stride))
if not s:
s = [0] if size >= patch_size else []
elif s[-1] < size - patch_size:
s.append(size - patch_size)
return s
for y in steps(H):
for x in steps(W):
pm = mask[y : y + patch_size, x : x + patch_size]
pi = image[:, y : y + patch_size, x : x + patch_size]
# Include any patch that contains at least one labeled pixel
if pm.shape == (patch_size, patch_size) and (pm != IGNORE_INDEX).any():
imgs.append(pi)
masks.append(pm)
# Last resort: pad with zeros/IGNORE if image is smaller than patch_size
if not imgs:
ph = min(patch_size, H)
pw = min(patch_size, W)
img_pad = np.zeros((image.shape[0], patch_size, patch_size), dtype=np.float32)
mask_pad = np.full((patch_size, patch_size), IGNORE_INDEX, dtype=np.int64)
img_pad[:, :ph, :pw] = image[:, :ph, :pw]
mask_pad[:ph, :pw] = mask[:ph, :pw]
imgs.append(img_pad)
masks.append(mask_pad)
return np.stack(imgs).astype(np.float32), np.stack(masks).astype(np.int64)
# ── Dataset class ─────────────────────────────────────────────
class MultiSpectralDataset(Dataset):
def __init__(self, images: np.ndarray, masks: np.ndarray):
self.images = images.astype(np.float32)
self.masks = masks.astype(np.int64)
def __len__(self):
return len(self.images)
def __getitem__(self, idx: int):
return torch.from_numpy(self.images[idx]), torch.from_numpy(self.masks[idx])
# ── Public API ────────────────────────────────────────────────
def load_data(patch_size: int = DEFAULT_PATCH_SIZE) -> Dict:
# Download and stack bands
band_arrays = [_read_band(_download(f)) for f in BAND_FILES]
image = _normalize(np.stack(band_arrays, axis=0)) # (7, H, W) float32
# Read raw masks + metadata
raw_train, nd_train, info_train = _read_mask_raw(_download(TRAIN_MASK_FILE))
raw_val, nd_val, info_val = _read_mask_raw(_download(VAL_MASK_FILE))
# Remap to 0-indexed classes
train_mask, train_vals = _remap_mask(raw_train, nd_train)
val_mask, val_vals = _remap_mask(raw_val, nd_val)
if not train_vals:
raise ValueError(
f"TRAINING.tif has no labeled pixels after nodata removal. "
f"File info: {info_train} | Unique raw values: {np.unique(raw_train).tolist()}"
)
# Extract patches
tr_imgs, tr_masks = _extract_patches(image, train_mask, patch_size)
va_imgs, va_masks = _extract_patches(image, val_mask, patch_size)
train_labeled = int((train_mask != IGNORE_INDEX).sum())
val_labeled = int((val_mask != IGNORE_INDEX).sum())
def _class_dist(mask, total):
parts = []
for i, name in enumerate(CLASS_NAMES):
n = int((mask == i).sum())
parts.append(f"{name}: {n:,} ({n / max(1, total) * 100:.1f}%)")
return " | ".join(parts)
status = "\n".join([
f"Train patches: **{len(tr_imgs)}** | Val patches: **{len(va_imgs)}** | Patch: **{patch_size}Γ—{patch_size}**",
"",
f"**TRAINING.tif** `{info_train}`",
f"Raw values β†’ classes: `{dict(zip(train_vals, CLASS_NAMES[:len(train_vals)]))}`",
f"Labeled pixels: **{train_labeled:,}** β€” {_class_dist(train_mask, train_labeled)}",
"",
f"**GROUND TRUTH.tif** `{info_val}`",
f"Raw values β†’ classes: `{dict(zip(val_vals, CLASS_NAMES[:len(val_vals)]))}`",
f"Labeled pixels: **{val_labeled:,}** β€” {_class_dist(val_mask, val_labeled)}",
])
return {
"train_images": tr_imgs,
"train_masks": tr_masks,
"val_images": va_imgs,
"val_masks": va_masks,
"status": status,
}