# Timerns's picture
# Upload folder using huggingface_hub
# 984cdba verified
import numpy as np
from tqdm import tqdm
import rasterio
from rasterio.windows import Window
import os
from glob import glob
# Filename suffix -> modality key. The suffixes are mutually exclusive, so the
# lookup order does not affect which key is chosen.
_MODALITY_SUFFIXES = (
    ("_s2.tif", "S2"),
    ("_s1.tif", "S1"),
    ("_dem.tif", "DEM"),
    ("_hillshade.tif", "Hillshade"),
    ("_cloud_mask.tif", "Cloudmask"),
)


def _detect_modality(path, modalities):
    """Return the modality key for *path*, or None if unknown or not requested."""
    for suffix, key in _MODALITY_SUFFIXES:
        if path.endswith(suffix):
            return key if key in modalities else None
    return None


def compute_band_statistics(
    root,
    locations,
    modalities=("S1", "S2", "DEM", "Hillshade", "Cloudmask"),
    patch_size=256,
    stride=256
):
    """
    Compute per-band mean and std for each modality over the given locations.

    For every location directory ``root/<loc>``, each ``*.tif`` file whose name
    ends with a known modality suffix (``_s1.tif``, ``_s2.tif``, ``_dem.tif``,
    ``_hillshade.tif``, ``_cloud_mask.tif``) is read in square windows of
    ``patch_size`` pixels, stepping by ``stride``. Only complete windows are
    read: pixels past the last full window, and files smaller than
    ``patch_size`` in either dimension, contribute nothing.

    Non-finite values (NaN and +/-inf) are excluded from the statistics.
    NOTE(review): a nodata value declared in the file metadata is NOT
    excluded unless it is stored as NaN/inf -- confirm this matches the data.

    Progress is reported with tqdm (one bar over locations, one per location
    over its files).

    Args:
        root: Directory containing one sub-directory per location.
        locations: Iterable of location sub-directory names to include
            (pass only training locations to avoid leaking statistics).
        modalities: Modality keys to compute statistics for.
        patch_size: Side length, in pixels, of each square read window.
        stride: Step, in pixels, between consecutive windows.

    Returns:
        dict mapping modality key -> ``{"mean": [...], "std": [...]}`` with one
        float per band. Modalities with no valid pixels at all are omitted.
        A band with no valid pixels (while other bands of the same modality
        have some) gets NaN for both mean and std.

    Raises:
        ValueError: If ``patch_size`` or ``stride`` is not positive, or if two
            files of the same modality have different band counts.
    """
    # Fail early: stride == 0 would otherwise surface as an opaque range()
    # error, and negative values would silently produce no statistics.
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size!r}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride!r}")

    # Running per-band accumulators, allocated lazily once the band count of
    # the first file of each modality is known.
    sums = {m: None for m in modalities}
    sq_sums = {m: None for m in modalities}
    counts = {m: None for m in modalities}

    for loc in tqdm(locations, desc="Locations"):
        loc_dir = os.path.join(root, loc)
        # Sorted so the processing order is deterministic.
        files = sorted(glob(os.path.join(loc_dir, "*.tif")))

        for path in tqdm(files, desc=f"Files in {loc}", leave=False):
            key = _detect_modality(path, modalities)
            if key is None:
                continue

            with rasterio.open(path) as src:
                height, width = src.height, src.width
                n_bands = src.count

                if sums[key] is None:
                    sums[key] = np.zeros(n_bands, dtype=np.float64)
                    sq_sums[key] = np.zeros(n_bands, dtype=np.float64)
                    counts[key] = np.zeros(n_bands, dtype=np.float64)
                elif sums[key].shape[0] != n_bands:
                    # Without this check the accumulation below fails with an
                    # unhelpful NumPy broadcasting error.
                    raise ValueError(
                        f"Band count mismatch for modality {key!r}: "
                        f"expected {sums[key].shape[0]}, got {n_bands} in {path}"
                    )

                for y in range(0, height - patch_size + 1, stride):
                    for x in range(0, width - patch_size + 1, stride):
                        window = Window(x, y, patch_size, patch_size)
                        # float64 before squaring to avoid integer overflow.
                        patch = src.read(window=window).astype(np.float64)
                        patch = patch.reshape(n_bands, -1)

                        valid_mask = np.isfinite(patch)
                        # Zero out invalid pixels so they add nothing to the sums.
                        valid_values = np.where(valid_mask, patch, 0)

                        sums[key] += valid_values.sum(axis=1)
                        sq_sums[key] += (valid_values ** 2).sum(axis=1)
                        counts[key] += valid_mask.sum(axis=1)

    # ----- Final stats -----
    band_stats = {}
    for m in modalities:
        if counts[m] is None or (counts[m] == 0).all():
            continue

        # Bands with zero valid pixels evaluate 0/0; the resulting NaN is the
        # intended "no data" marker, so silence the floating-point warning.
        with np.errstate(divide="ignore", invalid="ignore"):
            mean = sums[m] / counts[m]
            sq_mean = sq_sums[m] / counts[m]
            # Var = E[x^2] - E[x]^2; clamp tiny negatives caused by rounding.
            var = np.maximum(sq_mean - mean ** 2, 0.0)
            std = np.sqrt(var)

        band_stats[m] = {
            "mean": mean.tolist(),
            "std": std.tolist(),
        }
    return band_stats