prithvi-flood-detection / app /preprocessing.py
Tushar365's picture
Deploy Prithvi flood detection app
400ed1e verified
"""
Preprocessing pipeline for Prithvi EO V2 flood inference.
Handles GeoTIFF loading, band selection, normalization, and tiling
to produce model-ready [B, 6, 224, 224] tensors.
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
import rasterio
import torch
# Positions (0-indexed) of the six model bands inside a 13-band Sentinel-2
# stack; select_bands() uses this list to index the channel axis.
S2_6BAND_INDICES = [1, 2, 3, 8, 11, 12]
# Band names, in the same order as S2_6BAND_INDICES.
S2_6BAND_NAMES = ["B02", "B03", "B04", "B8A", "B11", "B12"]
# HLS normalization statistics, one entry per band in the order
# Blue, Green, Red, NIR, SWIR1, SWIR2. normalize_per_band() applies them to
# data that has already been scaled to the 0-1 reflectance range.
HLS_MEANS = [0.033, 0.055, 0.054, 0.197, 0.120, 0.073]
HLS_STDS = [0.023, 0.031, 0.040, 0.071, 0.058, 0.047]
# Side length, in pixels, of the square model input; used as the default
# crop/tile size by the functions in this module.
INPUT_SIZE = 224
@dataclass
class GeoMetadata:
    """Geospatial metadata extracted from a GeoTIFF.

    Instances are built by ``load_geotiff`` from the opened rasterio dataset.
    """

    # String form of the dataset's CRS, or None when the dataset has none.
    crs: Optional[str]
    # First six values of the dataset's affine transform, or None.
    transform: Optional[List[float]]
    # Dataset bounds converted to a plain list, or None.
    bounds: Optional[List[float]]
    # Raster width as reported by the dataset.
    width: int
    # Raster height as reported by the dataset.
    height: int
    # Number of bands in the file (the dataset's ``count``).
    num_bands: int
def load_geotiff(path: str) -> Tuple[np.ndarray, GeoMetadata]:
    """Read every band of a GeoTIFF together with its georeferencing info.

    Args:
        path: Path to the .tif file.

    Returns:
        A ``(pixels, metadata)`` pair. ``pixels`` is the whole raster cast to
        float32 and shaped [C, H, W]; ``metadata`` is a ``GeoMetadata`` whose
        optional fields are None when the dataset does not provide them.
    """
    with rasterio.open(path) as dataset:
        pixels = dataset.read().astype(np.float32)  # [C, H, W]

        crs = dataset.crs
        affine = dataset.transform
        extent = dataset.bounds

        geo = GeoMetadata(
            crs=None if not crs else str(crs),
            # Keep only the six leading values of the affine transform.
            transform=None if not affine else list(affine)[:6],
            bounds=None if not extent else list(extent),
            width=dataset.width,
            height=dataset.height,
            num_bands=dataset.count,
        )
    return pixels, geo
def select_bands(data: np.ndarray) -> np.ndarray:
    """Reduce a Sentinel-2 stack to the six bands the model consumes.

    Args:
        data: [C, H, W] array with either 6 or 13 channels.

    Returns:
        [6, H, W] array ordered [B02, B03, B04, B8A, B11, B12]. A 6-band
        input is handed back untouched (same object); a 13-band input is
        indexed with ``S2_6BAND_INDICES``.

    Raises:
        ValueError: If the channel count is neither 6 nor 13.
    """
    band_count = data.shape[0]

    # Already in model layout: nothing to do.
    if band_count == 6:
        return data

    if band_count != 13:
        raise ValueError(
            f"Expected 6 or 13 bands, got {data.shape[0]}. "
            f"Provide a 6-band or 13-band Sentinel-2 GeoTIFF."
        )

    return data[S2_6BAND_INDICES, :, :]
def normalize_reflectance(data: np.ndarray) -> np.ndarray:
    """Scale raw reflectance (0-10000) to the 0-1 range when needed.

    The input is treated as raw digital numbers when its largest non-NaN
    value exceeds 10; otherwise it is assumed to be in 0-1 already and is
    returned unchanged.

    Args:
        data: [C, H, W] array (float32 in the pipeline).

    Returns:
        [C, H, W] array in 0-1 range. The input array is never modified; a
        new array is returned when scaling is applied, otherwise the same
        object is returned.
    """
    # Nothing to inspect or scale in an empty raster (and reductions over an
    # empty array would raise).
    if data.size == 0:
        return data

    # fmax ignores NaN, so NaN nodata pixels cannot hide the real maximum.
    # A plain max() would return NaN, the comparison below would be False,
    # and a 0-10000 raster would silently stay unscaled.
    largest = np.fmax.reduce(data, axis=None)
    if largest > 10.0:
        data = data / 10000.0
    return data
def normalize_per_band(
    data: np.ndarray,
    means: Optional[List[float]] = None,
    stds: Optional[List[float]] = None,
) -> np.ndarray:
    """Apply per-band normalization: ``(x - mean) / (std + 1e-8)``.

    Args:
        data: [C, H, W] array in 0-1 range. Non-floating input is converted
            to float32 first so the result is not truncated.
        means: Per-band means (defaults to HLS statistics). Lists, tuples and
            NumPy arrays are accepted; entries beyond C are ignored.
        stds: Per-band standard deviations (defaults to HLS statistics).

    Returns:
        New normalized [C, H, W] array with the (floating) dtype of ``data``.
        The input array is left untouched.

    Raises:
        ValueError: If fewer means or stds than bands are supplied.
    """
    # Use explicit None/empty checks: `means or DEFAULT` raises on NumPy
    # arrays because their truth value is ambiguous.
    if means is None or len(means) == 0:
        means = HLS_MEANS
    if stds is None or len(stds) == 0:
        stds = HLS_STDS

    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float32)

    num_bands = data.shape[0]
    if len(means) < num_bands or len(stds) < num_bands:
        raise ValueError(
            f"Need at least {num_bands} means and stds, "
            f"got {len(means)} means and {len(stds)} stds."
        )

    # Shape the statistics as [C, 1, 1, ...] so they broadcast over the
    # spatial axes, and cast them to the data dtype so float32 stays float32.
    stat_shape = (num_bands,) + (1,) * (data.ndim - 1)
    mean_arr = np.asarray(means[:num_bands], dtype=np.float64)
    std_arr = np.asarray(stds[:num_bands], dtype=np.float64) + 1e-8
    mean_arr = mean_arr.astype(data.dtype).reshape(stat_shape)
    std_arr = std_arr.astype(data.dtype).reshape(stat_shape)

    # Out-of-place arithmetic: the caller's array is not modified.
    return (data - mean_arr) / std_arr
def center_crop(data: np.ndarray, size: int = INPUT_SIZE) -> np.ndarray:
    """Return the central size x size window of an image, zero-padding first
    if either spatial dimension is shorter than ``size``.

    Args:
        data: [C, H, W] array.
        size: Target spatial dimension.

    Returns:
        [C, size, size] array.
    """
    rows, cols = data.shape[1:]

    # How many pixels each axis is short of the target (0 when large enough).
    short_rows = max(0, size - rows)
    short_cols = max(0, size - cols)

    if short_rows or short_cols:
        # Split the padding as evenly as possible; the extra pixel, if any,
        # goes to the bottom/right side.
        lead_rows = short_rows // 2
        lead_cols = short_cols // 2
        data = np.pad(
            data,
            (
                (0, 0),
                (lead_rows, short_rows - lead_rows),
                (lead_cols, short_cols - lead_cols),
            ),
            mode="constant",
            constant_values=0,
        )
        rows, cols = data.shape[1:]

    row0 = (rows - size) // 2
    col0 = (cols - size) // 2
    return data[:, row0 : row0 + size, col0 : col0 + size]
def tile_image(
    data: np.ndarray, tile_size: int = INPUT_SIZE, overlap: int = 0
) -> Tuple[List[np.ndarray], List[Tuple[int, int]], Tuple[int, int]]:
    """Tile a large image into tile_size x tile_size patches.

    The image is zero-padded on the bottom/right so that a whole number of
    tiles, stepped by ``tile_size - overlap``, covers every input pixel.

    Args:
        data: [C, H, W] normalized array.
        tile_size: Patch size (must be positive).
        overlap: Pixel overlap between adjacent tiles; must satisfy
            ``0 <= overlap < tile_size``.

    Returns:
        (tiles, positions, original_hw) where:
            tiles: list of [C, tile_size, tile_size] arrays (views into the
                padded image)
            positions: list of (top, left) offsets
            original_hw: (H, W) of the padded image

    Raises:
        ValueError: If ``tile_size`` or ``overlap`` is out of range.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}.")
    if not 0 <= overlap < tile_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < tile_size, "
            f"got overlap={overlap}, tile_size={tile_size}."
        )

    _, h, w = data.shape
    stride = tile_size - overlap

    # Tiles needed per axis: one tile, plus one more for every `stride`
    # pixels (rounded up) beyond the first tile. -(-a // b) is ceil(a / b).
    n_rows = max(1, -(-(h - tile_size) // stride) + 1)
    n_cols = max(1, -(-(w - tile_size) // stride) + 1)

    # Pad so the last tile ends exactly at the padded border. Padding only
    # to a multiple of `stride` would leave the bottom/right edge uncovered
    # whenever overlap > 0, because tiles are wider than the stride.
    h_padded = (n_rows - 1) * stride + tile_size
    w_padded = (n_cols - 1) * stride + tile_size
    pad_h = h_padded - h
    pad_w = w_padded - w
    if pad_h > 0 or pad_w > 0:
        data = np.pad(data, ((0, 0), (0, pad_h), (0, pad_w)), mode="constant")

    tiles = []
    positions = []
    for row in range(n_rows):
        top = row * stride
        for col in range(n_cols):
            left = col * stride
            tiles.append(data[:, top : top + tile_size, left : left + tile_size])
            positions.append((top, left))
    return tiles, positions, (h_padded, w_padded)
def stitch_tiles(
    tiles: List[np.ndarray],
    positions: List[Tuple[int, int]],
    canvas_hw: Tuple[int, int],
    tile_size: int = INPUT_SIZE,
) -> np.ndarray:
    """Reassemble per-tile predictions into one image, averaging overlaps.

    Args:
        tiles: List of [H_tile, W_tile] prediction arrays.
        positions: (top, left) for each tile.
        canvas_hw: (H, W) of the output canvas.
        tile_size: Size of each tile.

    Returns:
        [H, W] float32 stitched array. Pixels not covered by any tile are 0.
    """
    rows, cols = canvas_hw
    # Accumulate in float64, then divide by how often each pixel was written.
    total = np.zeros((rows, cols), dtype=np.float64)
    hits = np.zeros((rows, cols), dtype=np.float64)

    for patch, (row0, col0) in zip(tiles, positions):
        window = (
            slice(row0, row0 + tile_size),
            slice(col0, col0 + tile_size),
        )
        total[window] += patch.astype(np.float64)
        hits[window] += 1.0

    # Clamp to 1 so uncovered pixels divide cleanly instead of by zero.
    divisor = np.maximum(hits, 1.0)
    averaged = total / divisor
    return averaged.astype(np.float32)
def preprocess_geotiff(
    path: str,
    tile_size: int = INPUT_SIZE,
    overlap: int = 0,
    device: str = "cpu",
) -> Tuple[torch.Tensor, GeoMetadata, Optional[List[Tuple[int, int]]], Optional[Tuple[int, int]]]:
    """Run the whole pipeline: load -> select bands -> normalize -> crop/tile.

    An image whose height and width are both <= tile_size is passed through
    ``center_crop`` and returned as a batch of one. Anything larger is split
    into patches with ``tile_image``.

    Args:
        path: Path to GeoTIFF.
        tile_size: Target tile size (224 for Prithvi).
        overlap: Overlap for tiling large images.
        device: Target tensor device.

    Returns:
        (tensor [B, 6, H, W], geo_metadata, positions_or_None, canvas_hw_or_None)
        The last two entries are None on the single-crop path.
    """
    raw, geo = load_geotiff(path)
    bands = normalize_per_band(normalize_reflectance(select_bands(raw)))

    height, width = bands.shape[1:]
    needs_tiling = not (height <= tile_size and width <= tile_size)

    if needs_tiling:
        patches, offsets, padded_hw = tile_image(bands, tile_size, overlap)
        stacked = np.stack(patches, axis=0)  # [N, C, H, W]
        return torch.from_numpy(stacked).to(device), geo, offsets, padded_hw

    # Small image: one crop, with a leading batch axis added.
    single = center_crop(bands, tile_size)
    batch_of_one = torch.from_numpy(single).unsqueeze(0).to(device)
    return batch_of_one, geo, None, None