import os
from glob import glob

import numpy as np
import rasterio
from rasterio.windows import Window
from tqdm import tqdm

# Filename suffix -> modality name. Matching is case-sensitive, and no suffix
# is a suffix of another, so at most one entry can match a given path.
_SUFFIX_TO_MODALITY = {
    "_s2.tif": "S2",
    "_s1.tif": "S1",
    "_dem.tif": "DEM",
    "_hillshade.tif": "Hillshade",
    "_cloud_mask.tif": "Cloudmask",
}


def _modality_for(path, modalities):
    """Return the modality of ``path`` if it is one of ``modalities``, else None."""
    for suffix, modality in _SUFFIX_TO_MODALITY.items():
        if path.endswith(suffix):
            return modality if modality in modalities else None
    return None


def _patch_moments(patch):
    """Return per-band ``(count, mean, M2)`` over the finite values of ``patch``.

    ``patch`` is a 2-D float64 array of shape (bands, pixels). Non-finite
    entries (NaN, +/-inf) are ignored. A band with no finite value gets
    count 0, mean 0 and M2 0, which is the neutral element for
    ``_merge_moments``.
    """
    valid = np.isfinite(patch)
    # float64 counts: the merge multiplies two counts together, which could
    # overflow int64 for very large datasets.
    count = valid.sum(axis=1).astype(np.float64)
    total = np.where(valid, patch, 0.0).sum(axis=1)
    mean = np.divide(total, count, out=np.zeros_like(total), where=count > 0)
    # Deviations from the patch's own mean keep M2 well-conditioned.
    dev = np.where(valid, patch - mean[:, None], 0.0)
    m2 = (dev * dev).sum(axis=1)
    return count, mean, m2


def _merge_moments(acc, batch):
    """Combine two per-band ``(count, mean, M2)`` triples (Chan et al. update)."""
    n_a, mean_a, m2_a = acc
    n_b, mean_b, m2_b = batch
    n = n_a + n_b
    delta = mean_b - mean_a
    # Where n == 0 both inputs are empty (delta == 0), so any non-zero divisor
    # yields the correct (zero) result without a 0/0.
    safe_n = np.where(n > 0, n, 1.0)
    mean = mean_a + delta * (n_b / safe_n)
    m2 = m2_a + m2_b + delta * delta * (n_a * n_b / safe_n)
    return n, mean, m2


def compute_band_statistics(
    root,
    locations,
    modalities=("S1", "S2", "DEM", "Hillshade", "Cloudmask"),
    patch_size=256,
    stride=256,
):
    """Compute the per-band mean and standard deviation for each modality.

    For every ``loc`` in ``locations``, the directory ``os.path.join(root, loc)``
    is scanned non-recursively for ``*.tif`` files. Each file is assigned a
    modality by its filename suffix:

    - ``_s1.tif`` -> "S1"
    - ``_s2.tif`` -> "S2"
    - ``_dem.tif`` -> "DEM"
    - ``_hillshade.tif`` -> "Hillshade"
    - ``_cloud_mask.tif`` -> "Cloudmask"

    Files whose modality is not in ``modalities`` are skipped. The function is
    intended to be called with the training locations only. Progress is shown
    with tqdm bars.

    Pixels are read in ``patch_size`` x ``patch_size`` windows stepped by
    ``stride``. Only windows lying fully inside the raster are used, so two
    kinds of pixels are never counted:

    - a right or bottom strip narrower than ``patch_size``;
    - any raster smaller than ``patch_size`` in either dimension.

    With ``stride < patch_size``, pixels shared by overlapping windows are
    counted more than once.

    Non-finite values (NaN, +/-inf) are ignored. A raster's declared nodata
    value is *not* masked unless it is NaN.

    Statistics are accumulated with a pairwise mean/M2 merge rather than raw
    sums of squares. This avoids catastrophic cancellation when a band's mean
    is large relative to its spread. The reported std is the population
    standard deviation (ddof=0).

    Parameters
    ----------
    root : str
        Directory containing one sub-directory per location.
    locations : iterable of str
        Location sub-directory names to scan.
    modalities : sequence of str
        Modalities to compute statistics for.
    patch_size : int
        Side length, in pixels, of each square read window.
    stride : int
        Step, in pixels, between consecutive windows along each axis.

    Returns
    -------
    dict
        ``{modality: {"mean": [...], "std": [...]}}`` with one float per band.
        A modality is omitted if none of its bands had any valid pixel. An
        individual band with no valid pixel reports NaN for mean and std.

    Raises
    ------
    ValueError
        If two files of the same modality have different band counts.
    """
    # Per-modality running (count, mean, M2) arrays, one element per band.
    # Each entry stays None until the first file of that modality is opened.
    moments = {m: None for m in modalities}

    for loc in tqdm(locations, desc="Locations"):
        loc_dir = os.path.join(root, loc)
        files = sorted(glob(os.path.join(loc_dir, "*.tif")))

        for path in tqdm(files, desc=f"Files in {loc}", leave=False):
            key = _modality_for(path, modalities)
            if key is None:
                continue

            with rasterio.open(path) as src:
                H, W, C = src.height, src.width, src.count

                if moments[key] is None:
                    zeros = np.zeros(C, dtype=np.float64)
                    moments[key] = (zeros, zeros.copy(), zeros.copy())
                elif moments[key][0].shape[0] != C:
                    expected = moments[key][0].shape[0]
                    raise ValueError(
                        f"{path}: has {C} bands, but earlier {key!r} files "
                        f"had {expected}"
                    )

                # Only full windows are read; see the docstring for what
                # this leaves out at the raster edges.
                for y in range(0, H - patch_size + 1, stride):
                    for x in range(0, W - patch_size + 1, stride):
                        window = Window(x, y, patch_size, patch_size)
                        patch = src.read(window=window).astype(np.float64)
                        batch = _patch_moments(patch.reshape(C, -1))
                        moments[key] = _merge_moments(moments[key], batch)

    # ----- Final stats -----
    band_stats = {}
    for m in modalities:
        if moments[m] is None:
            continue
        count, mean, m2 = moments[m]
        has_data = count > 0
        if not has_data.any():
            continue
        # Bands without any valid pixel report NaN, which is the same result
        # 0/0 would give, but without a divide-by-zero RuntimeWarning.
        mean = np.where(has_data, mean, np.nan)
        var = np.divide(m2, count, out=np.full_like(m2, np.nan), where=has_data)
        std = np.sqrt(var)
        band_stats[m] = {"mean": mean.tolist(), "std": std.tolist()}

    return band_stats