File size: 7,413 Bytes
400ed1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""
Preprocessing pipeline for Prithvi EO V2 flood inference.

Handles GeoTIFF loading, band selection, normalization, and tiling
to produce model-ready [B, 6, 224, 224] tensors.
"""

from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import rasterio
import torch

# Sentinel-2 6-band indices from 13-band product (0-indexed).
# select_bands() uses these positions along the band axis to reduce a
# 13-band array to the six bands listed in S2_6BAND_NAMES (same order).
S2_6BAND_INDICES = [1, 2, 3, 8, 11, 12]
S2_6BAND_NAMES = ["B02", "B03", "B04", "B8A", "B11", "B12"]

# HLS normalization statistics (for 6 bands: Blue, Green, Red, NIR, SWIR1, SWIR2).
# These are the default per-band means / standard deviations used by
# normalize_per_band(), which expects reflectance already scaled to 0-1.
HLS_MEANS = [0.033, 0.055, 0.054, 0.197, 0.120, 0.073]
HLS_STDS = [0.023, 0.031, 0.040, 0.071, 0.058, 0.047]

# Default edge length, in pixels, for crops and tiles produced by this module.
INPUT_SIZE = 224


@dataclass
class GeoMetadata:
    """Geospatial metadata extracted from a GeoTIFF.

    Instances are built by load_geotiff() from the opened rasterio dataset.
    """

    # String form of the dataset's CRS; None when the dataset reports none.
    crs: Optional[str]
    # First six values of the dataset's transform; None when it is falsy.
    transform: Optional[List[float]]
    # Dataset bounds converted to a list; None when falsy.
    # NOTE(review): presumably ordered (left, bottom, right, top) - confirm
    # against the rasterio version in use.
    bounds: Optional[List[float]]
    # Raster width and height as reported by the dataset.
    width: int
    height: int
    # Band count of the file as stored, i.e. before select_bands() runs.
    num_bands: int


def load_geotiff(path: str) -> Tuple[np.ndarray, GeoMetadata]:
    """Read every band of a GeoTIFF together with its georeferencing info.

    Args:
        path: Location of the .tif file.

    Returns:
        A pair ``(pixels, geo)`` where ``pixels`` is the whole raster cast to
        float32 in [C, H, W] layout and ``geo`` is the matching GeoMetadata.
    """
    with rasterio.open(path) as dataset:
        pixels = dataset.read().astype(np.float32)  # [C, H, W]

        # Each georeferencing field degrades to None when the dataset
        # reports a falsy value for it.
        crs_text = str(dataset.crs) if dataset.crs else None
        affine_coeffs = list(dataset.transform)[:6] if dataset.transform else None
        extent = list(dataset.bounds) if dataset.bounds else None

        geo = GeoMetadata(
            crs=crs_text,
            transform=affine_coeffs,
            bounds=extent,
            width=dataset.width,
            height=dataset.height,
            num_bands=dataset.count,
        )

    return pixels, geo


def select_bands(data: np.ndarray) -> np.ndarray:
    """Reduce a Sentinel-2 array to the six bands used by the model.

    A 6-band input is assumed to be in the right order already and is handed
    back untouched (same object). A 13-band input is indexed with
    S2_6BAND_INDICES, which yields a new array.

    Args:
        data: [C, H, W] array with C equal to 6 or 13.

    Returns:
        [6, H, W] array with bands [B02, B03, B04, B8A, B11, B12].

    Raises:
        ValueError: If the band count is neither 6 nor 13.
    """
    band_count = data.shape[0]

    if band_count == 6:
        return data

    if band_count != 13:
        raise ValueError(
            f"Expected 6 or 13 bands, got {data.shape[0]}. "
            f"Provide a 6-band or 13-band Sentinel-2 GeoTIFF."
        )

    # Fancy indexing on the leading axis keeps H and W intact.
    return data[S2_6BAND_INDICES]


def normalize_reflectance(data: np.ndarray) -> np.ndarray:
    """Scale raw UInt16-style reflectance (0-10000) to the 0-1 range.

    The input is treated as raw when its largest value exceeds 10.0; in that
    case a new array equal to ``data / 10000.0`` is returned. Otherwise the
    input is assumed to be scaled already and is returned unchanged (same
    object).

    NaN pixels (e.g. nodata) are ignored when looking for the largest value,
    so a raster containing NaNs is still detected as raw and scaled. An empty
    array is returned as-is instead of raising.

    Args:
        data: [C, H, W] float32 array.

    Returns:
        [C, H, W] array in 0-1 range.
    """
    if data.size == 0:
        # Nothing to inspect; a max-reduction over zero elements would raise.
        return data

    # fmax skips NaNs, so one NaN pixel cannot hide the true peak the way a
    # plain max() does. All-NaN input gives NaN, which fails the test below.
    peak = np.fmax.reduce(data, axis=None)
    if peak > 10.0:
        return data / 10000.0
    return data


def normalize_per_band(
    data: np.ndarray,
    means: Optional[List[float]] = None,
    stds: Optional[List[float]] = None,
) -> np.ndarray:
    """Apply per-band HLS normalization: ``(x - mean) / (std + 1e-8)``.

    The input array is never modified; a new array is returned. Floating
    point inputs keep their dtype, any other dtype is promoted to float32 so
    the result is not truncated.

    Args:
        data: [C, H, W] float32 array in 0-1 range.
        means: Per-band means (defaults to HLS statistics). Only the first C
            entries are used.
        stds: Per-band standard deviations (defaults to HLS statistics). Only
            the first C entries are used.

    Returns:
        Normalized [C, H, W] array.

    Raises:
        ValueError: If fewer than C means or stds are available.
    """
    # `is None` / len() instead of truthiness so ndarray statistics work too;
    # an empty sequence still falls back to the defaults as before.
    if means is None or len(means) == 0:
        means = HLS_MEANS
    if stds is None or len(stds) == 0:
        stds = HLS_STDS

    num_bands = data.shape[0]
    if len(means) < num_bands or len(stds) < num_bands:
        raise ValueError(
            f"Need at least {num_bands} per-band statistics, got "
            f"{len(means)} means and {len(stds)} stds."
        )

    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float32

    # Shape the statistics as [C, 1, 1, ...] so they broadcast over the
    # remaining axes. The epsilon is added in float64 before the cast.
    stat_shape = (num_bands,) + (1,) * (data.ndim - 1)
    mean_arr = (
        np.asarray(means[:num_bands], dtype=np.float64)
        .astype(out_dtype)
        .reshape(stat_shape)
    )
    std_arr = (
        (np.asarray(stds[:num_bands], dtype=np.float64) + 1e-8)
        .astype(out_dtype)
        .reshape(stat_shape)
    )

    return (data.astype(out_dtype, copy=False) - mean_arr) / std_arr


def center_crop(data: np.ndarray, size: int = INPUT_SIZE) -> np.ndarray:
    """Return the central size x size window of an image.

    An axis shorter than ``size`` is first zero-padded, with the padding split
    as evenly as possible between both sides (the extra pixel, if any, goes
    after the data). An axis longer than ``size`` is cropped around its middle.

    Args:
        data: [C, H, W] array.
        size: Target spatial dimension.

    Returns:
        [C, size, size] array.
    """
    _, rows, cols = data.shape

    missing_rows = max(size - rows, 0)
    missing_cols = max(size - cols, 0)

    if missing_rows or missing_cols:
        rows_before = missing_rows // 2
        cols_before = missing_cols // 2
        pad_spec = (
            (0, 0),  # never pad the channel axis
            (rows_before, missing_rows - rows_before),
            (cols_before, missing_cols - cols_before),
        )
        data = np.pad(data, pad_spec, mode="constant", constant_values=0)
        _, rows, cols = data.shape

    row_start = (rows - size) // 2
    col_start = (cols - size) // 2
    row_stop = row_start + size
    col_stop = col_start + size
    return data[:, row_start:row_stop, col_start:col_stop]


def tile_image(
    data: np.ndarray, tile_size: int = INPUT_SIZE, overlap: int = 0
) -> Tuple[List[np.ndarray], List[Tuple[int, int]], Tuple[int, int]]:
    """Tile a large image into tile_size x tile_size patches.

    The image is zero-padded on the bottom/right so that tiles placed every
    ``tile_size - overlap`` pixels cover it completely: each padded extent is
    ``tile_size + k * stride`` for the smallest ``k >= 0`` that reaches the
    original extent. With ``overlap == 0`` this is the usual round-up to a
    multiple of ``tile_size``.

    Args:
        data: [C, H, W] normalized array.
        tile_size: Patch size. Must be positive.
        overlap: Pixel overlap between adjacent tiles. Must satisfy
            ``0 <= overlap < tile_size``.

    Returns:
        (tiles, positions, original_hw) where:
            tiles: list of [C, tile_size, tile_size] arrays, row-major order
            positions: list of (top, left) offsets
            original_hw: (H, W) of the padded image

    Raises:
        ValueError: If ``tile_size`` or ``overlap`` is out of range.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}.")
    if not 0 <= overlap < tile_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < tile_size, "
            f"got overlap={overlap} with tile_size={tile_size}."
        )

    _, h, w = data.shape
    stride = tile_size - overlap

    def _padded_extent(extent: int) -> int:
        # Stride steps needed after the first tile so the last tile reaches
        # `extent` (ceil division; zero when one tile already suffices).
        steps = max(0, -(-(extent - tile_size) // stride))
        return tile_size + steps * stride

    h_padded = _padded_extent(h)
    w_padded = _padded_extent(w)

    # Pad bottom/right only, so tile offsets stay valid in original coordinates.
    pad_h = h_padded - h
    pad_w = w_padded - w
    if pad_h > 0 or pad_w > 0:
        data = np.pad(data, ((0, 0), (0, pad_h), (0, pad_w)), mode="constant")

    positions = [
        (top, left)
        for top in range(0, h_padded - tile_size + 1, stride)
        for left in range(0, w_padded - tile_size + 1, stride)
    ]
    tiles = [
        data[:, top : top + tile_size, left : left + tile_size]
        for top, left in positions
    ]

    return tiles, positions, (h_padded, w_padded)


def stitch_tiles(
    tiles: List[np.ndarray],
    positions: List[Tuple[int, int]],
    canvas_hw: Tuple[int, int],
    tile_size: int = INPUT_SIZE,
) -> np.ndarray:
    """Reassemble per-tile predictions into one image, averaging overlaps.

    Every tile is added into a float64 accumulator at its offset while a
    second array counts how many tiles touched each pixel; the result is the
    per-pixel mean. Pixels no tile touched stay 0.

    Args:
        tiles: List of [H_tile, W_tile] prediction arrays.
        positions: (top, left) for each tile.
        canvas_hw: (H, W) of the output canvas.
        tile_size: Size of each tile.

    Returns:
        [H, W] stitched array, float32.
    """
    height, width = canvas_hw
    accumulated = np.zeros((height, width), dtype=np.float64)
    hits = np.zeros((height, width), dtype=np.float64)

    for patch, (row, col) in zip(tiles, positions):
        window = (slice(row, row + tile_size), slice(col, col + tile_size))
        accumulated[window] += patch.astype(np.float64)
        hits[window] += 1.0

    # Clamp the divisor to 1 so untouched pixels do not divide by zero.
    averaged = accumulated / np.maximum(hits, 1.0)
    return averaged.astype(np.float32)


def preprocess_geotiff(
    path: str,
    tile_size: int = INPUT_SIZE,
    overlap: int = 0,
    device: str = "cpu",
) -> Tuple[torch.Tensor, GeoMetadata, Optional[List[Tuple[int, int]]], Optional[Tuple[int, int]]]:
    """Run the whole pipeline: load -> select bands -> normalize -> tile/crop.

    An image that fits inside one tile (both H and W <= tile_size) goes
    through center_crop() and comes back as a batch of one; in that case the
    last two return values are None. Anything larger is split by tile_image()
    and the patches are stacked along a new batch axis.

    Args:
        path: Path to GeoTIFF.
        tile_size: Target tile size (224 for Prithvi).
        overlap: Overlap for tiling large images.
        device: Target tensor device.

    Returns:
        (tensor [B, 6, H, W], geo_metadata, positions_or_None, canvas_hw_or_None)
    """
    raw, meta = load_geotiff(path)
    prepared = normalize_per_band(normalize_reflectance(select_bands(raw)))

    _, height, width = prepared.shape
    fits_in_one_tile = height <= tile_size and width <= tile_size

    if fits_in_one_tile:
        # Single patch: add the batch axis on the torch side.
        patch = center_crop(prepared, tile_size)
        single = torch.from_numpy(patch).unsqueeze(0).to(device)
        return single, meta, None, None

    # Large image: one batch entry per tile, in the order tile_image() emits.
    patches, positions, canvas_hw = tile_image(prepared, tile_size, overlap)
    stacked = np.stack(patches, axis=0)  # [N, C, H, W]
    return torch.from_numpy(stacked).to(device), meta, positions, canvas_hw