"""
Preprocessing pipeline for Prithvi EO V2 flood inference.

Handles GeoTIFF loading, band selection, normalization, and tiling to produce
model-ready [B, 6, 224, 224] tensors.
"""

from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import torch

# Sentinel-2 6-band indices from 13-band product (0-indexed)
S2_6BAND_INDICES = [1, 2, 3, 8, 11, 12]
S2_6BAND_NAMES = ["B02", "B03", "B04", "B8A", "B11", "B12"]

# HLS normalization statistics (for 6 bands: Blue, Green, Red, NIR, SWIR1, SWIR2)
HLS_MEANS = [0.033, 0.055, 0.054, 0.197, 0.120, 0.073]
HLS_STDS = [0.023, 0.031, 0.040, 0.071, 0.058, 0.047]

# Spatial size (pixels per side) of each model input patch.
INPUT_SIZE = 224

# Data whose peak exceeds this is treated as raw scaled integers (0-10000)
# rather than 0-1 reflectance.
_RAW_REFLECTANCE_THRESHOLD = 10.0
_REFLECTANCE_SCALE = 10000.0


@dataclass
class GeoMetadata:
    """Geospatial metadata extracted from a GeoTIFF."""

    crs: Optional[str]  # CRS rendered with str(), or None if the file has none
    transform: Optional[List[float]]  # first 6 coefficients of the affine transform
    bounds: Optional[List[float]]  # dataset bounds converted to a list
    width: int  # raster width in pixels
    height: int  # raster height in pixels
    num_bands: int  # number of bands in the source file


def load_geotiff(path: str) -> Tuple[np.ndarray, GeoMetadata]:
    """Load a GeoTIFF and return pixel data + geo metadata.

    Args:
        path: Path to the .tif file.

    Returns:
        (data [C, H, W] float32, GeoMetadata)
    """
    # Imported lazily: only file loading needs GDAL/rasterio, so the
    # pure-NumPy helpers in this module stay importable (and testable)
    # without it.
    import rasterio

    with rasterio.open(path) as src:
        data = src.read().astype(np.float32)  # [C, H, W]
        meta = GeoMetadata(
            crs=str(src.crs) if src.crs else None,
            transform=list(src.transform)[:6] if src.transform else None,
            bounds=list(src.bounds) if src.bounds else None,
            width=src.width,
            height=src.height,
            num_bands=src.count,
        )
    return data, meta


def select_bands(data: np.ndarray) -> np.ndarray:
    """Select 6 Sentinel-2 bands from a 13-band product.

    If already 6 bands, returns as-is.

    Args:
        data: [C, H, W] array.

    Returns:
        [6, H, W] array with bands [B02, B03, B04, B8A, B11, B12].

    Raises:
        ValueError: if the band count is neither 6 nor 13.
    """
    if data.shape[0] == 13:
        return data[S2_6BAND_INDICES, :, :]
    if data.shape[0] == 6:
        return data
    raise ValueError(
        f"Expected 6 or 13 bands, got {data.shape[0]}. "
        f"Provide a 6-band or 13-band Sentinel-2 GeoTIFF."
    )


def normalize_reflectance(data: np.ndarray) -> np.ndarray:
    """Scale raw UInt16 reflectance (0-10000) to 0-1 range.

    If the largest non-NaN value is <= 10.0, the data is assumed to already
    be 0-1 reflectance. In that case the input object is returned unchanged.
    NaN pixels (e.g. nodata in float GeoTIFFs) are ignored when making that
    decision and remain NaN in the output.

    Args:
        data: [C, H, W] float32 array.

    Returns:
        [C, H, W] array in 0-1 range.
    """
    if data.size == 0:
        return data
    # np.fmax ignores NaN whereas np.max propagates it; with max(), a single
    # NaN made the comparison False and left raw 0-10000 values unscaled.
    # An all-NaN input reduces to NaN, which compares False -> unchanged.
    peak = np.fmax.reduce(data, axis=None)
    if peak > _RAW_REFLECTANCE_THRESHOLD:
        data = data / _REFLECTANCE_SCALE
    return data


def normalize_per_band(
    data: np.ndarray,
    means: Optional[List[float]] = None,
    stds: Optional[List[float]] = None,
) -> np.ndarray:
    """Apply per-band HLS normalization: ``(x - mean) / (std + 1e-8)``.

    A new array is returned and the input is not modified. Floating inputs
    keep their dtype, and any other dtype is promoted to float32 so that
    fractional results are not truncated.

    Args:
        data: [C, ...] array (normally [C, H, W]) in 0-1 range.
        means: Per-band means (HLS statistics when None or empty). Only the
            first C entries are used.
        stds: Per-band standard deviations (HLS statistics when None or
            empty). Only the first C entries are used.

    Returns:
        Normalized array with the same shape as ``data``.

    Raises:
        ValueError: if fewer than C means or stds are available.
    """
    # Explicit None/len checks instead of `or` so NumPy arrays are accepted
    # (truth-testing a multi-element array raises).
    if means is None or len(means) == 0:
        means = HLS_MEANS
    if stds is None or len(stds) == 0:
        stds = HLS_STDS

    num_bands = data.shape[0]
    if len(means) < num_bands or len(stds) < num_bands:
        raise ValueError(
            f"Need at least {num_bands} means/stds for {num_bands} bands, "
            f"got {len(means)} means and {len(stds)} stds."
        )

    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float32
    # The epsilon is added in float64 before casting, matching the original
    # per-band formulation, which used Python-float scalars.
    mean_arr = np.asarray(means[:num_bands], dtype=np.float64).astype(out_dtype)
    std_arr = (np.asarray(stds[:num_bands], dtype=np.float64) + 1e-8).astype(out_dtype)
    # Reshape the stats to [C, 1, 1, ...] so they broadcast along the band axis.
    band_shape = (num_bands,) + (1,) * (data.ndim - 1)
    centered = data.astype(out_dtype, copy=False) - mean_arr.reshape(band_shape)
    return centered / std_arr.reshape(band_shape)


def center_crop(data: np.ndarray, size: int = INPUT_SIZE) -> np.ndarray:
    """Center-crop spatial dimensions to size x size.

    Any dimension smaller than ``size`` is first zero-padded symmetrically.

    Args:
        data: [C, H, W] array.
        size: Target spatial dimension.

    Returns:
        [C, size, size] array.
    """
    _, h, w = data.shape
    if h < size or w < size:
        # Pad if smaller
        pad_h = max(0, size - h)
        pad_w = max(0, size - w)
        data = np.pad(
            data,
            ((0, 0), (pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),
            mode="constant",
            constant_values=0,
        )
        _, h, w = data.shape
    top = (h - size) // 2
    left = (w - size) // 2
    return data[:, top : top + size, left : left + size]


def _padded_extent(length: int, tile_size: int, stride: int) -> int:
    """Return the smallest ``tile_size + k * stride`` (k >= 0) that is >= length."""
    extra = max(length - tile_size, 0)
    steps = -(-extra // stride)  # ceil(extra / stride) with integer math
    return tile_size + steps * stride


def tile_image(
    data: np.ndarray, tile_size: int = INPUT_SIZE, overlap: int = 0
) -> Tuple[List[np.ndarray], List[Tuple[int, int]], Tuple[int, int]]:
    """Tile a large image into tile_size x tile_size patches.

    The image is zero-padded on the bottom/right so that the tile grid (step
    ``tile_size - overlap``) ends exactly at the padded edge. Every original
    pixel is therefore covered by at least one tile, for any valid overlap.

    Args:
        data: [C, H, W] normalized array.
        tile_size: Patch size.
        overlap: Pixel overlap between adjacent tiles (0 <= overlap < tile_size).

    Returns:
        (tiles, positions, original_hw) where:
            tiles: list of [C, tile_size, tile_size] arrays
            positions: list of (top, left) offsets
            original_hw: (H, W) of the padded image (pass to stitch_tiles)

    Raises:
        ValueError: if tile_size <= 0 or overlap is outside [0, tile_size).
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}.")
    if not 0 <= overlap < tile_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < tile_size ({tile_size}), got {overlap}."
        )

    _, h, w = data.shape
    stride = tile_size - overlap

    # Pad so that (padded - tile_size) is a multiple of stride. Padding to a
    # multiple of stride alone could leave the last rows/cols untiled when
    # overlap > 0. For overlap == 0 both rules give the same size.
    h_padded = _padded_extent(h, tile_size, stride)
    w_padded = _padded_extent(w, tile_size, stride)
    pad_h, pad_w = h_padded - h, w_padded - w
    if pad_h > 0 or pad_w > 0:
        data = np.pad(data, ((0, 0), (0, pad_h), (0, pad_w)), mode="constant")

    positions = [
        (top, left)
        for top in range(0, h_padded - tile_size + 1, stride)
        for left in range(0, w_padded - tile_size + 1, stride)
    ]
    tiles = [
        data[:, top : top + tile_size, left : left + tile_size]
        for top, left in positions
    ]
    return tiles, positions, (h_padded, w_padded)


def stitch_tiles(
    tiles: List[np.ndarray],
    positions: List[Tuple[int, int]],
    canvas_hw: Tuple[int, int],
    tile_size: int = INPUT_SIZE,
) -> np.ndarray:
    """Stitch tiles back into a full-size image using averaging for overlaps.

    Canvas pixels that no tile covers are left at 0.

    Args:
        tiles: List of [H_tile, W_tile] prediction arrays.
        positions: (top, left) for each tile.
        canvas_hw: (H, W) of the output canvas.
        tile_size: Size of each tile.

    Returns:
        [H, W] stitched float32 array.
    """
    h, w = canvas_hw
    # Accumulate in float64 to limit rounding error when many tiles overlap.
    canvas = np.zeros((h, w), dtype=np.float64)
    counts = np.zeros((h, w), dtype=np.float64)
    for tile, (top, left) in zip(tiles, positions):
        canvas[top : top + tile_size, left : left + tile_size] += tile.astype(np.float64)
        counts[top : top + tile_size, left : left + tile_size] += 1.0
    # Avoid division by zero for uncovered pixels (their sum is 0 anyway).
    counts = np.maximum(counts, 1.0)
    return (canvas / counts).astype(np.float32)


def preprocess_geotiff(
    path: str,
    tile_size: int = INPUT_SIZE,
    overlap: int = 0,
    device: str = "cpu",
) -> Tuple[torch.Tensor, GeoMetadata, Optional[List[Tuple[int, int]]], Optional[Tuple[int, int]]]:
    """Full preprocessing pipeline: load -> select bands -> normalize -> tile/crop.

    For images <= tile_size in both dimensions, returns a single
    center-cropped (zero-padded) tensor. For larger images, returns tiled
    patches.

    Args:
        path: Path to GeoTIFF.
        tile_size: Target tile size (224 for Prithvi).
        overlap: Overlap for tiling large images.
        device: Target tensor device.

    Returns:
        (tensor [B, 6, H, W], geo_metadata, positions_or_None, canvas_hw_or_None)
    """
    data, meta = load_geotiff(path)
    data = select_bands(data)
    data = normalize_reflectance(data)
    data = normalize_per_band(data)
    # NOTE: padding in center_crop/tile_image happens after normalization, so
    # padded pixels hold 0 in normalized space (i.e. the band mean), not zero
    # reflectance. NaN nodata pixels, if present, are passed through as NaN.

    _, h, w = data.shape
    if h <= tile_size and w <= tile_size:
        # Single center crop
        cropped = center_crop(data, tile_size)
        tensor = torch.from_numpy(cropped).unsqueeze(0).to(device)
        return tensor, meta, None, None

    # Tile large image
    tiles, positions, canvas_hw = tile_image(data, tile_size, overlap)
    batch = np.stack(tiles, axis=0)  # [N, C, H, W]
    tensor = torch.from_numpy(batch).to(device)
    return tensor, meta, positions, canvas_hw