"""
Per-modality preprocessing for satellite imagery.

Handles different channel counts:
- Optical RGB: 3 channels (R, G, B)
- SAR: 2 channels (VV, VH)
- Multispectral: 12 channels (Sentinel-2 bands)
"""
|
|
| import torch |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from PIL import Image |
| import numpy as np |
|
|
|
|
| |
# Per-channel ImageNet normalization statistics used for optical RGB imagery.
IMAGENET_MEAN = [
    0.485,  # R
    0.456,  # G
    0.406,  # B
]
IMAGENET_STD = [
    0.229,  # R
    0.224,  # G
    0.225,  # B
]
|
|
| |
# Per-band normalization statistics for the 12 Sentinel-2 channels.
# NOTE(review): the magnitudes suggest sensor-scale values rather than data
# rescaled to [0, 1] — confirm the scaling of the inputs these are applied to.
SENTINEL2_MEAN = [
    1353.0, 1117.0, 1042.0, 947.0,
    1199.0, 1645.0, 1849.0, 1793.0,
    1859.0, 1008.0, 1593.0, 1064.0,
]
SENTINEL2_STD = [
    235.0, 309.0, 392.0, 597.0,
    490.0, 625.0, 736.0, 755.0,
    846.0, 487.0, 561.0, 459.0,
]
|
|
| |
# Per-channel normalization statistics for the two SAR channels.
# NOTE(review): negative means suggest decibel-scaled backscatter — confirm
# the units of the inputs these are applied to.
SAR_MEAN = [
    -12.0,  # VV
    -18.0,  # VH
]
SAR_STD = [
    5.0,  # VV
    5.0,  # VH
]
|
|
|
|
def get_optical_transform(size: int = 224) -> transforms.Compose:
    """Build the preprocessing pipeline for optical RGB images.

    Args:
        size: Value passed to both the resize and the centre-crop step.

    Returns:
        A composed transform that resizes, centre-crops, converts to a
        tensor and normalizes with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    # Geometry and tensor conversion first; normalization needs a tensor.
    steps = [
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
    ]
    steps.append(transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return transforms.Compose(steps)
|
|
|
|
def get_sar_transform(size: int = 224) -> transforms.Compose:
    """Build the preprocessing pipeline for two-channel SAR images.

    Args:
        size: Value passed to both the resize and the centre-crop step.

    Returns:
        A composed transform that resizes, centre-crops, converts to a
        tensor and normalizes with ``SAR_MEAN`` / ``SAR_STD``.
    """
    resize = transforms.Resize(size)
    crop = transforms.CenterCrop(size)
    # NOTE(review): ToTensor rescales 8-bit PIL input to [0, 1], while the SAR
    # statistics look decibel-scaled — confirm inputs arrive in matching units.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=SAR_MEAN, std=SAR_STD)
    return transforms.Compose([resize, crop, to_tensor, normalize])
|
|
|
|
def _to_float_chw(pic):
    """Convert an (H, W, C) array-like or a (C, H, W) tensor to a float32 (C, H, W) tensor."""
    if isinstance(pic, torch.Tensor):
        return pic.to(torch.float32)
    array = np.asarray(pic)
    if array.ndim == 2:
        # Single-band input: add the trailing channel axis.
        array = array[:, :, np.newaxis]
    # Cast before handing over to torch so that dtypes torch.from_numpy may
    # reject (e.g. uint16) are still accepted.
    chw = np.ascontiguousarray(array.transpose(2, 0, 1).astype(np.float32))
    return torch.from_numpy(chw)


def get_multispectral_transform(size: int = 224) -> transforms.Compose:
    """Build the preprocessing pipeline for 12-channel multispectral images.

    The input is converted to a float tensor *before* resizing. PIL images
    cannot hold 12 bands, ``Resize`` rejects NumPy arrays and ``ToTensor``
    rejects tensors, so a Resize -> CenterCrop -> ToTensor ordering cannot
    process any 12-band input. Working on a tensor throughout avoids that.

    Pixel values are cast to float32 but not rescaled; they are normalized
    directly with ``SENTINEL2_MEAN`` / ``SENTINEL2_STD``.

    Args:
        size: Value passed to both the resize and the centre-crop step.

    Returns:
        A composed transform accepting an (H, W, 12) array-like or a
        (12, H, W) tensor and returning a normalized float32 tensor.
    """
    return transforms.Compose([
        # Module-level function (not a lambda) so the pipeline stays picklable.
        transforms.Lambda(_to_float_chw),
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.Normalize(mean=SENTINEL2_MEAN, std=SENTINEL2_STD),
    ])
|
|
|
|
def preprocess_image(
    image: Image.Image,
    modality: str,
    size: int = 224
) -> torch.Tensor:
    """
    Preprocess image based on modality.

    For ``"sar"`` the image is reduced to its first two bands when it
    carries more than two; the other modalities are handed to their
    transform unchanged.

    Args:
        image: Input PIL image
        modality: "optical", "sar", or "multispectral"
        size: Output image size

    Returns:
        Preprocessed tensor

    Raises:
        ValueError: If ``modality`` is not one of the supported names.
    """
    builders = {
        "sar": get_sar_transform,
        "optical": get_optical_transform,
        "multispectral": get_multispectral_transform,
    }
    try:
        build_transform = builders[modality]
    except KeyError:
        raise ValueError(f"Unknown modality: {modality}") from None

    if modality == "sar":
        bands = np.array(image)
        # Trim only genuine multi-band arrays. Requiring ndim == 3 keeps a
        # 2-D greyscale image that happens to be 3 pixels wide intact, and
        # "> 2" also covers 4-band inputs such as RGBA.
        if bands.ndim == 3 and bands.shape[-1] > 2:
            bands = bands[..., :2]
        image = Image.fromarray(bands)

    return build_transform(size)(image)
|
|
|
|
def handle_channels(
    image: np.ndarray,
    target_channels: int,
    modality: str
) -> np.ndarray:
    """
    Handle channel mismatch for different modalities.

    - ``"optical"`` with at least 3 channels: keep the first 3.
    - ``"sar"`` with at least 2 channels: keep the first 2.
    - ``"multispectral"``: truncate to ``target_channels`` or zero-pad up to
      it. Padding keeps the input dtype, and a 2-D input is treated as a
      single-channel image.

    Any other combination (for example a single-channel optical image or an
    unknown modality) returns the input unchanged.

    Args:
        image: Input image array (H, W, C); a non-3-D array counts as one channel
        target_channels: Expected number of channels
        modality: Modality type

    Returns:
        Image with the adjusted number of channels, or the input itself when
        no rule applies
    """
    current_channels = image.shape[-1] if image.ndim == 3 else 1

    if current_channels == target_channels:
        return image

    if modality == "optical" and current_channels >= 3:
        return image[..., :3]
    if modality == "sar" and current_channels >= 2:
        return image[..., :2]
    if modality == "multispectral":
        if current_channels > target_channels:
            return image[..., :target_channels]
        if image.ndim == 2:
            # Give (H, W) input an explicit channel axis; otherwise the
            # padding would be appended along the width axis.
            image = image[..., np.newaxis]
        # Match the input dtype so integer imagery is not upcast to float64.
        padding = np.zeros(
            (*image.shape[:-1], target_channels - current_channels),
            dtype=image.dtype,
        )
        return np.concatenate([image, padding], axis=-1)

    return image
|
|
|
|
| |
if __name__ == "__main__":
    # Smoke test: push random 8-bit images through the optical and SAR paths.
    # randint's upper bound is exclusive, so 256 covers the full uint8 range.
    dummy_rgb = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
    dummy_sar = Image.fromarray(np.random.randint(0, 256, (256, 256, 2), dtype=np.uint8))

    optical_tensor = preprocess_image(dummy_rgb, "optical")
    sar_tensor = preprocess_image(dummy_sar, "sar")

    print(f"Optical shape: {optical_tensor.shape}")
    print(f"SAR shape: {sar_tensor.shape}")

    # Verify the output shapes before reporting success. Explicit raises are
    # used instead of `assert` so the checks survive `python -O`.
    if tuple(optical_tensor.shape) != (3, 224, 224):
        raise AssertionError(f"Unexpected optical shape: {tuple(optical_tensor.shape)}")
    if tuple(sar_tensor.shape) != (2, 224, 224):
        raise AssertionError(f"Unexpected SAR shape: {tuple(sar_tensor.shape)}")

    print("Preprocessing test passed!")