File size: 3,169 Bytes
984cdba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
from tqdm import tqdm
import rasterio
from rasterio.windows import Window
import os
from glob import glob

def compute_band_statistics(
    root,
    locations,
    modalities=("S1", "S2", "DEM", "Hillshade", "Cloudmask"),
    patch_size=256,
    stride=256
):
    """Compute per-band mean and standard deviation for each modality.

    Every ``*.tif`` directly inside ``<root>/<location>`` is assigned to a
    modality by its filename suffix (``_s1.tif``, ``_s2.tif``, ``_dem.tif``,
    ``_hillshade.tif``, ``_cloud_mask.tif``). Each matching raster is read in
    ``patch_size`` x ``patch_size`` windows placed every ``stride`` pixels;
    only windows that fit entirely inside the raster are read, so trailing
    rows/columns that do not fill a complete window are not included.
    Non-finite values (NaN, +inf, -inf) are excluded from the statistics.

    Progress is reported with tqdm (one bar over locations, one nested bar
    over the files of the current location).

    NOTE(review): rasters are read unmasked, so a finite nodata value
    (e.g. -9999) would be counted like any other pixel -- confirm that the
    inputs encode missing data as NaN.

    Args:
        root: Directory containing one sub-directory per location.
        locations: Names of the location sub-directories to scan. The
            original documentation states these should be the training
            locations only.
        modalities: Modalities to compute statistics for. Files whose
            modality is not listed are skipped.
        patch_size: Side length, in pixels, of the square read window.
        stride: Step, in pixels, between consecutive windows in both axes.

    Returns:
        dict mapping modality name to ``{"mean": [...], "std": [...]}``, with
        one float per band. ``std`` is the population standard deviation.
        Modalities for which no finite pixel was read are omitted. A band
        with no finite pixel inside an otherwise valid modality is reported
        as NaN.

    Raises:
        ValueError: If ``patch_size`` or ``stride`` is smaller than 1, or if
            two files of the same modality have different band counts.
    """
    if patch_size < 1 or stride < 1:
        raise ValueError(
            f"patch_size and stride must be >= 1, "
            f"got patch_size={patch_size}, stride={stride}"
        )

    # Materialize once: the collection is iterated several times below.
    modalities = tuple(modalities)

    # Filename suffix -> modality, restricted to the requested modalities.
    # The suffixes are mutually exclusive, so lookup order is irrelevant.
    suffix_to_modality = {
        suffix: name
        for suffix, name in (
            ("_s2.tif", "S2"),
            ("_s1.tif", "S1"),
            ("_dem.tif", "DEM"),
            ("_hillshade.tif", "Hillshade"),
            ("_cloud_mask.tif", "Cloudmask"),
        )
        if name in modalities
    }

    # Running per-band accumulators; allocated lazily once the band count of
    # a modality is known from its first file.
    sums = {m: None for m in modalities}
    sq_sums = {m: None for m in modalities}
    counts = {m: None for m in modalities}

    for loc in tqdm(locations, desc="Locations"):
        loc_dir = os.path.join(root, loc)

        # Only TIFFs, in a deterministic order.
        files = sorted(glob(os.path.join(loc_dir, "*.tif")))

        for path in tqdm(files, desc=f"Files in {loc}", leave=False):
            key = next(
                (
                    name
                    for suffix, name in suffix_to_modality.items()
                    if path.endswith(suffix)
                ),
                None,
            )
            if key is None:
                continue

            with rasterio.open(path) as src:
                height, width = src.height, src.width
                n_bands = src.count

                if sums[key] is None:
                    sums[key] = np.zeros(n_bands, dtype=np.float64)
                    sq_sums[key] = np.zeros(n_bands, dtype=np.float64)
                    counts[key] = np.zeros(n_bands, dtype=np.float64)
                elif sums[key].shape[0] != n_bands:
                    # Without this check a single-band file would be silently
                    # broadcast into every band of the accumulators.
                    raise ValueError(
                        f"Band count mismatch for modality {key!r}: "
                        f"expected {sums[key].shape[0]}, "
                        f"got {n_bands} in {path}"
                    )

                # Visit every window that fits completely inside the raster.
                for y in range(0, height - patch_size + 1, stride):
                    for x in range(0, width - patch_size + 1, stride):
                        window = Window(x, y, patch_size, patch_size)
                        patch = src.read(window=window).astype(np.float64)

                        # Flatten the spatial axes: one row per band.
                        patch = patch.reshape(n_bands, -1)

                        valid_mask = np.isfinite(patch)
                        # Zero out non-finite entries so they add nothing to
                        # the sums; they are excluded from the counts too.
                        valid_values = np.where(valid_mask, patch, 0.0)

                        sums[key] += valid_values.sum(axis=1)
                        sq_sums[key] += (valid_values ** 2).sum(axis=1)
                        counts[key] += valid_mask.sum(axis=1)

    # ----- Final stats -----
    band_stats = {}

    for m in modalities:
        if counts[m] is None or not counts[m].any():
            continue

        # A band with zero valid pixels yields 0/0 = NaN; that is the
        # reported value, so silence the accompanying RuntimeWarnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            mean = sums[m] / counts[m]
            sq_mean = sq_sums[m] / counts[m]
            # Var = E[x^2] - E[x]^2; clamp tiny negatives from round-off.
            var = np.maximum(sq_mean - mean ** 2, 0.0)
            std = np.sqrt(var)

        band_stats[m] = {
            "mean": mean.tolist(),
            "std": std.tolist()
        }

    return band_stats