""" Per-modality preprocessing for satellite imagery. Handles different channel counts: - Optical RGB: 3 channels (R, G, B) - SAR: 2 channels (VV, VH) - Multispectral: 12 channels (Sentinel-2 bands) """ import torch import torch.nn.functional as F from torchvision import transforms from PIL import Image import numpy as np # ImageNet normalization for RGB IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] # Sentinel-2 band statistics (approximate) SENTINEL2_MEAN = [1353.0, 1117.0, 1042.0, 947.0, 1199.0, 1645.0, 1849.0, 1793.0, 1859.0, 1008.0, 1593.0, 1064.0] SENTINEL2_STD = [235.0, 309.0, 392.0, 597.0, 490.0, 625.0, 736.0, 755.0, 846.0, 487.0, 561.0, 459.0] # SAR statistics (approximate, in dB) SAR_MEAN = [-12.0, -18.0] SAR_STD = [5.0, 5.0] def get_optical_transform(size: int = 224) -> transforms.Compose: """Get transforms for optical RGB images.""" return transforms.Compose([ transforms.Resize(size), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) ]) def get_sar_transform(size: int = 224) -> transforms.Compose: """Get transforms for SAR images (VV/VH channels).""" return transforms.Compose([ transforms.Resize(size), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(mean=SAR_MEAN, std=SAR_STD) ]) def get_multispectral_transform(size: int = 224) -> transforms.Compose: """Get transforms for multispectral images (12 channels).""" return transforms.Compose([ transforms.Resize(size), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(mean=SENTINEL2_MEAN, std=SENTINEL2_STD) ]) def preprocess_image( image: Image.Image, modality: str, size: int = 224 ) -> torch.Tensor: """ Preprocess image based on modality. Args: image: Input PIL image modality: "optical", "sar", or "multispectral" size: Output image size Returns: Preprocessed tensor """ # Handle channel mismatch before applying transform if modality == "sar": # SAR expects 2 channels, but PIL images are typically 3 channels # Convert to numpy, take first 2 channels, convert back img_array = np.array(image) if img_array.shape[-1] == 3: img_array = img_array[..., :2] image = Image.fromarray(img_array) transform = get_sar_transform(size) elif modality == "optical": transform = get_optical_transform(size) elif modality == "multispectral": transform = get_multispectral_transform(size) else: raise ValueError(f"Unknown modality: {modality}") return transform(image) def handle_channels( image: np.ndarray, target_channels: int, modality: str ) -> np.ndarray: """ Handle channel mismatch for different modalities. Args: image: Input image array (H, W, C) target_channels: Expected number of channels modality: Modality type Returns: Image with correct number of channels """ current_channels = image.shape[-1] if len(image.shape) == 3 else 1 if current_channels == target_channels: return image # ponytail: simple channel handling, not perfect but works for v1 if modality == "optical" and current_channels >= 3: # Take first 3 channels (RGB) return image[..., :3] elif modality == "sar" and current_channels >= 2: # Take first 2 channels (VV, VH) return image[..., :2] elif modality == "multispectral": if current_channels < target_channels: # Pad with zeros padding = np.zeros((*image.shape[:-1], target_channels - current_channels)) return np.concatenate([image, padding], axis=-1) else: # Take first 12 channels return image[..., :target_channels] return image # Self-check if __name__ == "__main__": # Create dummy images for testing dummy_rgb = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)) dummy_sar = Image.fromarray(np.random.randint(0, 255, (256, 256, 2), dtype=np.uint8)) # Test preprocessing optical_tensor = preprocess_image(dummy_rgb, "optical") sar_tensor = preprocess_image(dummy_sar, "sar") print(f"Optical shape: {optical_tensor.shape}") # Should be [3, 224, 224] print(f"SAR shape: {sar_tensor.shape}") # Should be [2, 224, 224] print("Preprocessing test passed!")