|
|
import numpy as np
|
|
|
from tqdm import tqdm
|
|
|
import rasterio
|
|
|
from rasterio.windows import Window
|
|
|
import os
|
|
|
from glob import glob
|
|
|
|
|
|
def compute_band_statistics(
    root,
    locations,
    modalities=("S1", "S2", "DEM", "Hillshade", "Cloudmask"),
    patch_size=256,
    stride=256,
):
    """Compute per-band mean and standard deviation for each modality.

    Every ``*.tif`` directly inside ``os.path.join(root, loc)`` is inspected
    for each ``loc`` in ``locations``. A file is assigned to a modality by its
    filename suffix (``_s1.tif``, ``_s2.tif``, ``_dem.tif``,
    ``_hillshade.tif``, ``_cloud_mask.tif``); files with any other suffix, or
    whose modality is not requested, are skipped.

    Each raster is read in ``patch_size`` x ``patch_size`` windows stepped by
    ``stride``. Only windows that fit entirely inside the raster are read, so
    trailing rows/columns that do not fill a whole patch do not contribute.
    Non-finite values (NaN, +/-inf) are excluded from the statistics.

    NOTE(review): rasters are read unmasked, so a nodata value declared in
    the file's metadata is counted like any other finite pixel -- confirm
    this is intended for the DEM / S1 inputs.

    Args:
        root: Directory containing one sub-directory per location.
        locations: Iterable of location sub-directory names to include.
        modalities: Modality keys to compute statistics for.
        patch_size: Side length, in pixels, of the square read window.
        stride: Step, in pixels, between consecutive window origins.

    Returns:
        dict mapping each modality that contributed at least one finite pixel
        to ``{"mean": [...], "std": [...]}``, with one float per band. The
        standard deviation is the population value (``E[x^2] - E[x]^2``). A
        band that never had a finite pixel, while other bands of the same
        modality did, is reported as ``nan``.

    Raises:
        ValueError: If two files of the same modality have different band
            counts, or if ``stride`` is zero.
    """
    # Filename suffix -> modality key. No suffix is a suffix of another, so
    # at most one entry can match a given path.
    suffix_to_modality = {
        "_s2.tif": "S2",
        "_s1.tif": "S1",
        "_dem.tif": "DEM",
        "_hillshade.tif": "Hillshade",
        "_cloud_mask.tif": "Cloudmask",
    }
    active_suffixes = {
        suffix: modality
        for suffix, modality in suffix_to_modality.items()
        if modality in modalities
    }

    # Running per-band accumulators; allocated lazily because the band count
    # is only known once the first usable file of a modality has been opened.
    sums = {m: None for m in modalities}
    sq_sums = {m: None for m in modalities}
    counts = {m: None for m in modalities}

    for loc in tqdm(locations, desc="Locations"):
        loc_dir = os.path.join(root, loc)
        files = sorted(glob(os.path.join(loc_dir, "*.tif")))

        for path in tqdm(files, desc=f"Files in {loc}", leave=False):
            key = next(
                (
                    modality
                    for suffix, modality in active_suffixes.items()
                    if path.endswith(suffix)
                ),
                None,
            )
            if key is None:
                continue

            with rasterio.open(path) as src:
                n_bands = src.count
                row_offsets = range(0, src.height - patch_size + 1, stride)
                col_offsets = range(0, src.width - patch_size + 1, stride)

                # Raster smaller than a single patch: nothing to read, and it
                # must not dictate the accumulator's band count.
                if not row_offsets or not col_offsets:
                    continue

                if sums[key] is None:
                    sums[key] = np.zeros(n_bands, dtype=np.float64)
                    sq_sums[key] = np.zeros(n_bands, dtype=np.float64)
                    counts[key] = np.zeros(n_bands, dtype=np.float64)
                elif sums[key].shape[0] != n_bands:
                    # Fail with a message that names the file instead of an
                    # opaque NumPy broadcasting error further down.
                    raise ValueError(
                        f"{path}: modality {key!r} has {n_bands} band(s) but "
                        f"earlier files had {sums[key].shape[0]}"
                    )

                for y in row_offsets:
                    for x in col_offsets:
                        window = Window(x, y, patch_size, patch_size)
                        patch = src.read(window=window).astype(np.float64)
                        patch = patch.reshape(n_bands, -1)

                        # Zero out non-finite pixels so they add nothing to
                        # the sums; the mask keeps them out of the counts.
                        valid_mask = np.isfinite(patch)
                        valid_values = np.where(valid_mask, patch, 0.0)

                        sums[key] += valid_values.sum(axis=1)
                        sq_sums[key] += (valid_values ** 2).sum(axis=1)
                        counts[key] += valid_mask.sum(axis=1)

    band_stats = {}
    for m in modalities:
        n_valid = counts[m]
        if n_valid is None or not n_valid.any():
            # No file, or no finite pixel at all, for this modality.
            continue

        # A band with zero valid pixels yields 0/0 -> nan; that is the
        # documented result, so silence the accompanying RuntimeWarnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            mean = sums[m] / n_valid
            sq_mean = sq_sums[m] / n_valid
            var = sq_mean - mean ** 2
            # Floating-point cancellation can push a tiny variance below 0.
            var[var < 0] = 0
            std = np.sqrt(var)

        band_stats[m] = {
            "mean": mean.tolist(),
            "std": std.tolist(),
        }

    return band_stats
|
|
|
|