SatFetch-ISRO-BAH2026 / src /data /preprocessing.py
silverballs's picture
Deploy SatFetch Space App - track binary files with Git LFS
0db1cb9
Raw
History Blame Contribute Delete
4.65 kB
"""
Per-modality preprocessing for satellite imagery.
Handles different channel counts:
- Optical RGB: 3 channels (R, G, B)
- SAR: 2 channels (VV, VH)
- Multispectral: 12 channels (Sentinel-2 bands)
"""
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
# ImageNet channel statistics in RGB order. They are defined on the [0, 1]
# scale that ``transforms.ToTensor`` produces for 8-bit images.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Approximate per-band statistics for a 12-band Sentinel-2 stack.
# The values are in the thousands, so they appear to be on the raw
# digital-number scale rather than [0, 1]. The band order is not recorded
# here — TODO confirm against the dataset these were computed from.
SENTINEL2_MEAN = [
    1353.0, 1117.0, 1042.0, 947.0, 1199.0, 1645.0,
    1849.0, 1793.0, 1859.0, 1008.0, 1593.0, 1064.0,
]
SENTINEL2_STD = [
    235.0, 309.0, 392.0, 597.0, 490.0, 625.0,
    736.0, 755.0, 846.0, 487.0, 561.0, 459.0,
]

# Approximate SAR statistics in dB, one entry per channel (VV, VH).
SAR_MEAN = [-12.0, -18.0]
SAR_STD = [5.0, 5.0]
def get_optical_transform(size: int = 224) -> transforms.Compose:
    """Build the preprocessing pipeline for optical RGB images.

    The pipeline resizes the shorter side to ``size``, centre-crops to
    ``size`` x ``size``, converts to a tensor with ``ToTensor`` and
    normalises with the ImageNet statistics.

    Args:
        size: Edge length of the square output.
    Returns:
        A ``transforms.Compose`` pipeline.
    """
    geometry = [transforms.Resize(size), transforms.CenterCrop(size)]
    to_model_input = [
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(geometry + to_model_input)
def get_sar_transform(size: int = 224) -> transforms.Compose:
    """Build the preprocessing pipeline for two-channel SAR images (VV/VH).

    The pipeline resizes the shorter side to ``size``, centre-crops to
    ``size`` x ``size``, converts to a tensor with ``ToTensor`` and
    normalises with ``SAR_MEAN`` / ``SAR_STD``.

    NOTE(review): ``ToTensor`` scales 8-bit input to [0, 1], while
    SAR_MEAN/SAR_STD are described as dB statistics. Confirm which input
    scale the normalisation is meant for.

    Args:
        size: Edge length of the square output.
    Returns:
        A ``transforms.Compose`` pipeline.
    """
    geometry = [transforms.Resize(size), transforms.CenterCrop(size)]
    to_model_input = [
        transforms.ToTensor(),
        transforms.Normalize(SAR_MEAN, SAR_STD),
    ]
    return transforms.Compose(geometry + to_model_input)
def _to_float_chw_tensor(image) -> torch.Tensor:
    """Convert a PIL image or an (H, W[, C]) array to a float32 (C, H, W) tensor with 12 bands.

    Values are NOT rescaled, unlike ``transforms.ToTensor``, which divides
    8-bit input by 255. SENTINEL2_MEAN/STD are in the thousands, so they
    appear to expect raw digital numbers. A channel-count mismatch is
    resolved with ``handle_channels`` (truncate, or zero-pad).
    """
    array = np.asarray(image, dtype=np.float32)
    if array.ndim == 2:
        # Single-band raster: add the channel axis explicitly.
        array = array[..., np.newaxis]
    array = handle_channels(array, len(SENTINEL2_MEAN), "multispectral")
    # HWC -> CHW. Cast back to float32 in case the padding upcast the dtype.
    array = np.ascontiguousarray(array.transpose(2, 0, 1), dtype=np.float32)
    return torch.from_numpy(array)


def get_multispectral_transform(size: int = 224) -> transforms.Compose:
    """Get transforms for multispectral images (12 channels).

    PIL cannot represent 12-band rasters, and ``Resize`` on PIL images
    rejects NumPy arrays. The input is therefore converted to a float
    tensor first, and resizing and cropping operate on that tensor.

    The pipeline:
    1. Converts a PIL image or an (H, W[, C]) array to an unscaled float32
       (C, H, W) tensor, padded or truncated to 12 bands.
    2. Resizes the shorter side to ``size``.
    3. Centre-crops to ``size`` x ``size``.
    4. Normalises with ``SENTINEL2_MEAN`` / ``SENTINEL2_STD``.

    Args:
        size: Edge length of the square output.
    Returns:
        A ``transforms.Compose`` pipeline producing a (12, size, size) tensor.
    """
    # A module-level function (not a lambda/closure) keeps the pipeline
    # picklable for multi-process data loading.
    return transforms.Compose([
        _to_float_chw_tensor,
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.Normalize(mean=SENTINEL2_MEAN, std=SENTINEL2_STD),
    ])
def preprocess_image(
    image: Image.Image,
    modality: str,
    size: int = 224
) -> torch.Tensor:
    """
    Preprocess image based on modality.

    Args:
        image: Input PIL image
        modality: "optical", "sar", or "multispectral"
        size: Output image size
    Returns:
        Preprocessed tensor
    Raises:
        ValueError: If ``modality`` is not one of the supported values.
    """
    if modality == "optical":
        # Greyscale, palette and RGBA images would otherwise yield 1 or 4
        # channels and fail the 3-channel ImageNet normalisation.
        if image.mode != "RGB":
            image = image.convert("RGB")
        transform = get_optical_transform(size)
    elif modality == "sar":
        # SAR expects 2 channels, but PIL images are typically 3 (or 4) channels.
        img_array = np.asarray(image)
        # Only slice a real channel axis: on a 2-D (H, W) array, shape[-1] is
        # the image width, not a channel count.
        if img_array.ndim == 3 and img_array.shape[-1] > 2:
            image = Image.fromarray(np.ascontiguousarray(img_array[..., :2]))
        transform = get_sar_transform(size)
    elif modality == "multispectral":
        transform = get_multispectral_transform(size)
    else:
        raise ValueError(f"Unknown modality: {modality}")
    return transform(image)
def handle_channels(
    image: np.ndarray,
    target_channels: int,
    modality: str
) -> np.ndarray:
    """
    Handle channel mismatch for different modalities.

    Deliberately simple v1 policy:
    - optical: keep the first 3 channels. A single-channel (greyscale)
      image is replicated to 3 channels.
    - sar: keep the first 2 channels.
    - multispectral: keep the first ``target_channels`` channels, or
      zero-pad up to ``target_channels``.
    Any other combination is returned unchanged.

    Args:
        image: Input image array, (H, W, C) or single-channel (H, W).
        target_channels: Expected number of channels.
        modality: Modality type ("optical", "sar" or "multispectral").
    Returns:
        Image with the requested channels. Zero-padding keeps the input dtype.
    """
    current_channels = image.shape[-1] if image.ndim == 3 else 1
    if current_channels == target_channels:
        return image

    if modality == "optical":
        if current_channels >= 3:
            # Take first 3 channels (RGB).
            return image[..., :3]
        if current_channels == 1:
            # Greyscale -> RGB by replication.
            grey = image if image.ndim == 3 else image[..., np.newaxis]
            return np.repeat(grey, 3, axis=-1)
    elif modality == "sar" and current_channels >= 2:
        # Take first 2 channels (VV, VH).
        return image[..., :2]
    elif modality == "multispectral":
        if image.ndim == 2:
            # Give a 2-D image an explicit channel axis. Otherwise padding and
            # slicing along axis -1 would act on the image width.
            image = image[..., np.newaxis]
        if current_channels < target_channels:
            # Zero-pad in the input dtype so integer imagery is not upcast to float64.
            padding = np.zeros(
                (*image.shape[:-1], target_channels - current_channels),
                dtype=image.dtype,
            )
            return np.concatenate([image, padding], axis=-1)
        # Take the first target_channels channels.
        return image[..., :target_channels]
    return image
# Self-check: run this module directly for a quick smoke test of the shapes.
if __name__ == "__main__":
    # Seeded so the smoke test is reproducible. The upper bound is exclusive,
    # so 256 covers the full uint8 range.
    rng = np.random.default_rng(0)
    dummy_rgb = Image.fromarray(rng.integers(0, 256, (256, 256, 3), dtype=np.uint8))
    dummy_sar = Image.fromarray(rng.integers(0, 256, (256, 256, 2), dtype=np.uint8))

    optical_tensor = preprocess_image(dummy_rgb, "optical")
    sar_tensor = preprocess_image(dummy_sar, "sar")
    # A 3-channel image passed as SAR should be reduced to its first two channels.
    sar_from_rgb = preprocess_image(dummy_rgb, "sar")

    print(f"Optical shape: {optical_tensor.shape}")
    print(f"SAR shape: {sar_tensor.shape}")

    # Check the shapes instead of only printing them.
    assert optical_tensor.shape == (3, 224, 224), optical_tensor.shape
    assert sar_tensor.shape == (2, 224, 224), sar_tensor.shape
    assert sar_from_rgb.shape == (2, 224, 224), sar_from_rgb.shape
    print("Preprocessing test passed!")