| """ |
| Preprocessing pipeline for Prithvi EO V2 flood inference. |
| |
| Handles GeoTIFF loading, band selection, normalization, and tiling |
| to produce model-ready [B, 6, 224, 224] tensors. |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import List, Optional, Tuple |
|
|
| import numpy as np |
| import rasterio |
| import torch |
|
|
| |
# Zero-based positions of the six model bands inside a 13-band Sentinel-2
# stack; each entry lines up with the name at the same position in
# S2_6BAND_NAMES.
S2_6BAND_INDICES = [
    1,   # B02
    2,   # B03
    3,   # B04
    8,   # B8A
    11,  # B11
    12,  # B12
]
S2_6BAND_NAMES = [
    "B02",
    "B03",
    "B04",
    "B8A",
    "B11",
    "B12",
]

# Per-band HLS normalization statistics, listed in the same band order as
# S2_6BAND_NAMES and applied to data that is already in the 0-1 range.
HLS_MEANS = [
    0.033,  # B02
    0.055,  # B03
    0.054,  # B04
    0.197,  # B8A
    0.120,  # B11
    0.073,  # B12
]
HLS_STDS = [
    0.023,  # B02
    0.031,  # B03
    0.040,  # B04
    0.071,  # B8A
    0.058,  # B11
    0.047,  # B12
]

# Default spatial size, in pixels, of the square tiles fed to the model
# (224 for Prithvi).
INPUT_SIZE = 224
|
|
|
|
@dataclass
class GeoMetadata:
    """Georeferencing details captured while reading a GeoTIFF.

    Attributes:
        crs: Coordinate reference system rendered as a string, or None when
            the dataset does not declare one.
        transform: First six coefficients of the dataset transform as a
            flat list, or None when unavailable.
        bounds: Dataset bounds as a flat list, or None when unavailable.
        width: Raster width in pixels.
        height: Raster height in pixels.
        num_bands: Number of bands stored in the file.
    """

    # Georeferencing (each may be absent).
    crs: Optional[str]
    transform: Optional[List[float]]
    bounds: Optional[List[float]]

    # Raster dimensions.
    width: int
    height: int
    num_bands: int
|
|
|
|
def load_geotiff(path: str) -> Tuple[np.ndarray, GeoMetadata]:
    """Read every band of a GeoTIFF together with its georeferencing info.

    Args:
        path: Location of the .tif file on disk.

    Returns:
        A pair ``(pixels, metadata)`` where ``pixels`` is a float32 array
        laid out as [C, H, W] and ``metadata`` is a GeoMetadata instance
        describing the source raster.
    """
    with rasterio.open(path) as dataset:
        # Cast once on load so all later arithmetic happens in float32.
        pixels = dataset.read().astype(np.float32)

        crs = dataset.crs
        affine = dataset.transform
        extent = dataset.bounds

        metadata = GeoMetadata(
            crs=None if not crs else str(crs),
            # Keep only the first six coefficients of the transform.
            transform=None if not affine else list(affine)[:6],
            bounds=None if not extent else list(extent),
            width=dataset.width,
            height=dataset.height,
            num_bands=dataset.count,
        )
    return pixels, metadata
|
|
|
|
def select_bands(data: np.ndarray) -> np.ndarray:
    """Reduce a Sentinel-2 stack to the six bands the model consumes.

    A 6-band input is assumed to be in the right order already and is
    handed back untouched; a 13-band input is indexed with
    S2_6BAND_INDICES.

    Args:
        data: [C, H, W] array.

    Returns:
        [6, H, W] array with bands [B02, B03, B04, B8A, B11, B12].

    Raises:
        ValueError: If the leading dimension is neither 6 nor 13.
    """
    band_count = data.shape[0]

    if band_count == 6:
        # Nothing to select: return the very same array object.
        return data

    if band_count == 13:
        return data[S2_6BAND_INDICES, :, :]

    raise ValueError(
        f"Expected 6 or 13 bands, got {band_count}. "
        "Provide a 6-band or 13-band Sentinel-2 GeoTIFF."
    )
|
|
|
|
def normalize_reflectance(data: np.ndarray) -> np.ndarray:
    """Scale raw reflectance stored on a 0-10000 scale into the 0-1 range.

    The input is treated as raw (unscaled) when its largest non-NaN value
    exceeds 10.0; otherwise it is assumed to be in the 0-1 range already
    and is returned unchanged. NaN pixels are ignored when making that
    decision, so a raster containing NaNs is still rescaled (a plain
    ``max()`` would return NaN and silently skip the scaling).

    Args:
        data: [C, H, W] float32 array.

    Returns:
        [C, H, W] array in 0-1 range. A new array is returned when scaling
        is applied; the input object itself is returned otherwise.
    """
    # A zero-size array has no maximum; there is nothing to scale.
    if data.size == 0:
        return data

    # fmax skips NaNs, so the reduction yields the largest real value, or
    # NaN only when every element is NaN (in which case the comparison
    # below is False and the data is returned untouched).
    with np.errstate(invalid="ignore"):
        peak = np.fmax.reduce(data, axis=None)

    if peak > 10.0:
        return data / 10000.0
    return data
|
|
|
|
def normalize_per_band(
    data: np.ndarray,
    means: Optional[List[float]] = None,
    stds: Optional[List[float]] = None,
) -> np.ndarray:
    """Standardize each band as ``(band - mean) / (std + 1e-8)``.

    The input array is never modified; a new array is returned.

    Args:
        data: [C, H, W] float32 array in 0-1 range.
        means: Per-band means. ``None`` or an empty sequence selects
            HLS_MEANS. Lists, tuples and NumPy arrays are accepted; entries
            beyond the first C are ignored.
        stds: Per-band standard deviations. ``None`` or an empty sequence
            selects HLS_STDS. Same accepted types as ``means``.

    Returns:
        Normalized array with the same shape as ``data``. Floating-point
        inputs keep their dtype; any other dtype is converted to float32 so
        the result is not truncated.

    Raises:
        ValueError: If ``means`` or ``stds`` is not one-dimensional or has
            fewer entries than ``data`` has bands.
    """
    arr = np.asarray(data)
    num_bands = arr.shape[0]

    # Explicit None/empty checks: truth-testing a NumPy array would raise.
    if means is None or len(means) == 0:
        means = HLS_MEANS
    if stds is None or len(stds) == 0:
        stds = HLS_STDS

    mean_vec = np.asarray(means, dtype=np.float64)
    std_vec = np.asarray(stds, dtype=np.float64)
    for label, vec in (("means", mean_vec), ("stds", std_vec)):
        if vec.ndim != 1 or vec.size < num_bands:
            raise ValueError(
                f"{label} must be a 1-D sequence with at least {num_bands} "
                f"entries, got shape {vec.shape}."
            )

    out_dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.float32

    # Shape (C, 1, 1, ...) so the statistics broadcast over the trailing axes.
    stat_shape = (num_bands,) + (1,) * (arr.ndim - 1)
    offset = mean_vec[:num_bands].astype(out_dtype).reshape(stat_shape)
    # The small epsilon guards against division by a zero standard deviation.
    scale = (std_vec[:num_bands] + 1e-8).astype(out_dtype).reshape(stat_shape)

    return (arr.astype(out_dtype, copy=False) - offset) / scale
|
|
|
|
def center_crop(data: np.ndarray, size: int = INPUT_SIZE) -> np.ndarray:
    """Return the central size x size window of an image.

    Any spatial axis shorter than ``size`` is first zero-padded (split as
    evenly as possible between both sides, the extra pixel going to the
    far side) so the result always has the requested extent.

    Args:
        data: [C, H, W] array.
        size: Target spatial dimension.

    Returns:
        [C, size, size] array.
    """
    _, rows, cols = data.shape

    # Number of pixels missing along each axis (zero when large enough).
    missing_rows = max(0, size - rows)
    missing_cols = max(0, size - cols)

    if missing_rows > 0 or missing_cols > 0:
        rows_before = missing_rows // 2
        cols_before = missing_cols // 2
        pad_spec = (
            (0, 0),
            (rows_before, missing_rows - rows_before),
            (cols_before, missing_cols - cols_before),
        )
        data = np.pad(data, pad_spec, mode="constant", constant_values=0)
        _, rows, cols = data.shape

    row_start = (rows - size) // 2
    col_start = (cols - size) // 2
    row_stop = row_start + size
    col_stop = col_start + size
    return data[:, row_start:row_stop, col_start:col_stop]
|
|
|
|
def tile_image(
    data: np.ndarray, tile_size: int = INPUT_SIZE, overlap: int = 0
) -> Tuple[List[np.ndarray], List[Tuple[int, int]], Tuple[int, int]]:
    """Tile a large image into tile_size x tile_size patches.

    The image is zero-padded on the bottom and right so that tiles placed
    every ``tile_size - overlap`` pixels cover it completely, including
    when ``overlap > 0``. With ``overlap == 0`` this is the same as padding
    up to the next multiple of ``tile_size``.

    Args:
        data: [C, H, W] normalized array.
        tile_size: Patch size.
        overlap: Pixel overlap between adjacent tiles; must satisfy
            ``0 <= overlap < tile_size``.

    Returns:
        (tiles, positions, original_hw) where:
            tiles: list of [C, tile_size, tile_size] arrays
            positions: list of (top, left) offsets into the padded image
            original_hw: (H, W) of the padded image (the canvas size to
                use when stitching predictions back together)

    Raises:
        ValueError: If ``tile_size`` is not positive or ``overlap`` is
            outside ``[0, tile_size)``.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}.")
    if not 0 <= overlap < tile_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < tile_size, "
            f"got overlap={overlap} with tile_size={tile_size}."
        )

    _, h, w = data.shape
    stride = tile_size - overlap

    h_padded = _padded_extent(h, tile_size, stride)
    w_padded = _padded_extent(w, tile_size, stride)
    if h_padded != h or w_padded != w:
        data = np.pad(
            data,
            ((0, 0), (0, h_padded - h), (0, w_padded - w)),
            mode="constant",
        )

    # Because each padded extent is tile_size + k * stride, the last tile on
    # each axis ends exactly at the padded edge.
    positions = [
        (top, left)
        for top in range(0, h_padded - tile_size + 1, stride)
        for left in range(0, w_padded - tile_size + 1, stride)
    ]
    tiles = [
        data[:, top : top + tile_size, left : left + tile_size]
        for top, left in positions
    ]
    return tiles, positions, (h_padded, w_padded)


def _padded_extent(extent: int, tile_size: int, stride: int) -> int:
    """Return the smallest length >= extent of the form tile_size + k * stride."""
    if extent == 0:
        # An empty axis stays empty and yields no tiles.
        return 0
    if extent <= tile_size:
        return tile_size
    steps = -(-(extent - tile_size) // stride)  # ceiling division
    return tile_size + steps * stride
|
|
|
|
def stitch_tiles(
    tiles: List[np.ndarray],
    positions: List[Tuple[int, int]],
    canvas_hw: Tuple[int, int],
    tile_size: int = INPUT_SIZE,
) -> np.ndarray:
    """Reassemble per-tile predictions into one image, averaging overlaps.

    Each tile is added into a float64 accumulator at its (top, left)
    offset while a second array counts how many tiles touched each pixel;
    the result is the accumulator divided by that count.

    Args:
        tiles: List of [H_tile, W_tile] prediction arrays.
        positions: (top, left) for each tile.
        canvas_hw: (H, W) of the output canvas.
        tile_size: Size of each tile.

    Returns:
        [H, W] stitched float32 array. Pixels covered by no tile are 0.
    """
    canvas_h, canvas_w = canvas_hw
    accum = np.zeros((canvas_h, canvas_w), dtype=np.float64)
    hits = np.zeros_like(accum)

    for patch, (top, left) in zip(tiles, positions):
        window = (
            slice(top, top + tile_size),
            slice(left, left + tile_size),
        )
        accum[window] += patch.astype(np.float64)
        hits[window] += 1.0

    # Clamp to 1 so pixels no tile touched divide cleanly instead of by zero.
    divisor = np.maximum(hits, 1.0)
    averaged = accum / divisor
    return averaged.astype(np.float32)
|
|
|
|
def preprocess_geotiff(
    path: str,
    tile_size: int = INPUT_SIZE,
    overlap: int = 0,
    device: str = "cpu",
) -> Tuple[torch.Tensor, GeoMetadata, Optional[List[Tuple[int, int]]], Optional[Tuple[int, int]]]:
    """Run the whole chain: load -> select bands -> normalize -> tile/crop.

    An image that fits inside one tile (both sides <= tile_size) is
    center-cropped (padding as needed) into a single-item batch. Anything
    larger is split into tiles that are stacked along the batch axis.

    Args:
        path: Path to GeoTIFF.
        tile_size: Target tile size (224 for Prithvi).
        overlap: Overlap for tiling large images.
        device: Target tensor device.

    Returns:
        (tensor [B, 6, H, W], geo_metadata, positions_or_None, canvas_hw_or_None)
        The last two entries are None on the single-tile path.
    """
    raw, meta = load_geotiff(path)
    prepared = normalize_per_band(normalize_reflectance(select_bands(raw)))

    _, height, width = prepared.shape
    needs_tiling = height > tile_size or width > tile_size

    if needs_tiling:
        tiles, positions, canvas_hw = tile_image(prepared, tile_size, overlap)
        stacked = np.stack(tiles, axis=0)
        return torch.from_numpy(stacked).to(device), meta, positions, canvas_hw

    # Small image: one crop, with a leading batch axis of length 1.
    single = center_crop(prepared, tile_size)
    batch = torch.from_numpy(single).unsqueeze(0).to(device)
    return batch, meta, None, None
|
|