| """ |
| Feature extraction using SatCLIP for satellite imagery. |
| |
| SatCLIP is trained on Sentinel-2 data - better than generic CLIP |
| for satellite image retrieval. |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from typing import List, Optional, Tuple |
| from torchvision import transforms |
|
|
| from .satclip_encoder import SatCLIPEncoder |
|
|
| |
| |
# Raw per-modality band centre values; converted to float32 tensors below.
# NOTE(review): the "optical" and "multispectral" values look like Sentinel-2
# band centre wavelengths in nm (multispectral appears to cover B1-B11 incl.
# B8A but not B12); the "sar" values are not optical wavelengths and their
# unit cannot be determined from this file — confirm before relying on them.
_BAND_CENTRES = {
    "optical": [492.4, 559.8, 664.6],
    "sar": [5400.0, 5600.0],
    "multispectral": [
        442.0, 492.4, 559.8, 664.6,
        704.1, 740.5, 782.8, 832.8,
        864.7, 945.1, 1373.5, 1613.7,
    ],
}

WAVELENGTHS = {
    modality: torch.tensor(centres) for modality, centres in _BAND_CENTRES.items()
}

# Input channel count per modality: one channel per band listed above
# (optical=3, sar=2, multispectral=12).
MODALITY_CHANNELS = {
    modality: int(bands.numel()) for modality, bands in WAVELENGTHS.items()
}
|
|
|
|
class FeatureExtractor:
    """
    Extract features from satellite images using SatCLIP.

    Uses SatCLIP's ViT trained on Sentinel-2 imagery. Every input is brought
    to exactly 13 channels (truncated or zero-padded) before it is passed to
    the encoder.
    """

    def __init__(self, device: Optional[str] = None):
        """Create the SatCLIP encoder and the PIL preprocessing pipeline.

        Args:
            device: Torch device string. Defaults to "cuda" when CUDA is
                available, otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder = SatCLIPEncoder(device=self.device)
        self.embed_dim = self.encoder.embed_dim
        # Resize to 224x224, then convert the PIL image to a (C, H, W) tensor.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def _preprocess(self, image: "Image.Image", modality: str) -> torch.Tensor:
        """Convert one PIL image into a single-item (1, 13, H, W) batch."""
        tensor = self.transform(image).unsqueeze(0)
        return self._pad_to_13ch(tensor, modality)

    def _pad_to_13ch(self, tensor: torch.Tensor, modality: str = "optical") -> torch.Tensor:
        """Pad or truncate a (B, C, H, W) tensor to the 13 channels SatCLIP expects.

        - C >= 13: only the first 13 channels are kept.
        - C == 1: the channel is repeated to 3 channels before padding.
        - C == 2: the first channel is appended as a third before padding.
        - Any remaining channels are filled with zeros.

        The zero padding is allocated with the input's dtype and device, so
        the concatenation also works for GPU tensors and non-float32 inputs.

        Args:
            tensor: Batch of images, channel dimension at index 1.
            modality: Currently unused; kept for interface compatibility.
        """
        n_channels = tensor.shape[1]
        if n_channels >= 13:
            return tensor[:, :13, :, :]

        if n_channels == 1:
            tensor = tensor.repeat(1, 3, 1, 1)
        elif n_channels == 2:
            tensor = torch.cat([tensor, tensor[:, :1, :, :]], dim=1)

        pad_channels = 13 - tensor.shape[1]
        padding = tensor.new_zeros(
            (tensor.shape[0], pad_channels, tensor.shape[2], tensor.shape[3])
        )
        return torch.cat([tensor, padding], dim=1)

    def _preprocess_batch(self, images: "List[Image.Image]", modality: str) -> torch.Tensor:
        """Preprocess several PIL images into one (N, 13, H, W) batch.

        Each ``_preprocess`` result already carries a leading batch dimension
        of 1, so the results are concatenated along dim 0. Stacking them
        would produce an invalid 5-D (N, 1, 13, H, W) tensor.
        """
        return torch.cat([self._preprocess(img, modality) for img in images], dim=0)

    @torch.no_grad()
    def extract_features(
        self,
        image: "Image.Image",
        modality: str = "optical",
        normalize: bool = True
    ) -> torch.Tensor:
        """Embed a single PIL image.

        Args:
            image: Input image.
            modality: Forwarded to the channel-padding step.
            normalize: Forwarded to ``encoder.encode``.

        Returns:
            The encoder output with the batch dimension squeezed away.
        """
        tensor = self._preprocess(image, modality).to(self.device)
        features = self.encoder.encode(tensor, normalize=normalize)
        return features.squeeze(0)

    @torch.no_grad()
    def extract_features_from_tensor(
        self,
        tensor: torch.Tensor,
        modality: str = "optical",
        normalize: bool = True
    ) -> torch.Tensor:
        """Extract features from a raw (C, H, W) or (B, C, H, W) tensor.

        Inputs with fewer than 13 channels are padded; inputs with 13 or more
        are passed through unchanged. A leading dimension of size 1 is
        squeezed from the result.
        """
        if tensor.ndim == 3:
            tensor = tensor.unsqueeze(0)
        if tensor.shape[1] < 13:
            tensor = self._pad_to_13ch(tensor, modality)
        tensor = tensor.to(self.device)
        features = self.encoder.encode(tensor, normalize=normalize)
        return features.squeeze(0)

    @torch.no_grad()
    def extract_batch(
        self,
        images: "List[Image.Image]",
        modality: str = "optical",
        batch_size: int = 32,
        normalize: bool = True
    ) -> torch.Tensor:
        """Embed a list of PIL images in chunks of ``batch_size``.

        Returns:
            All embeddings concatenated along dim 0, on the CPU. An empty
            input yields an empty (0, embed_dim) tensor. This assumes the
            encoder returns (batch, embed_dim) features; TODO confirm.
        """
        if not images:
            return torch.empty(0, self.embed_dim)

        chunks = []
        for start in range(0, len(images), batch_size):
            batch = self._preprocess_batch(images[start:start + batch_size], modality)
            features = self.encoder.encode(batch.to(self.device), normalize=normalize)
            chunks.append(features.cpu())
        return torch.cat(chunks, dim=0)

    def embed_dataset(
        self,
        dataset,
        batch_size: int = 32,
        show_progress: bool = True
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        """Embed every item of a dataset yielding (image, modality_label, class_label).

        Images with fewer than 13 channels are zero-padded, matching
        ``extract_features_from_tensor``.

        Returns:
            A tuple of (CPU embeddings concatenated along dim 0, modality
            labels, class labels). An empty dataset yields an empty
            (0, embed_dim) tensor and two empty lists.
        """
        from torch.utils.data import DataLoader

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
        all_embeddings = []
        all_modality_labels = []
        all_class_labels = []

        with torch.no_grad():
            for batch_idx, (images, mod_labels, class_labels) in enumerate(loader):
                images = images.to(self.device)
                if images.shape[1] < 13:
                    images = self._pad_to_13ch(images)
                features = self.encoder.encode(images, normalize=True)
                all_embeddings.append(features.cpu())
                all_modality_labels.extend(mod_labels.numpy().tolist())
                all_class_labels.extend(class_labels.numpy().tolist())
                if show_progress and (batch_idx + 1) % 10 == 0:
                    print(f"Embedded {batch_idx + 1}/{len(loader)} batches")

        if not all_embeddings:
            return torch.empty(0, self.embed_dim), all_modality_labels, all_class_labels
        return torch.cat(all_embeddings, dim=0), all_modality_labels, all_class_labels
|
|
|
|
if __name__ == "__main__":
    # Smoke test: embed one random RGB image, then a small batch of copies.
    print("Testing SatCLIP FeatureExtractor...")
    extractor = FeatureExtractor()
    print(f"Embed dim: {extractor.embed_dim}")

    # PIL cannot build an RGB image from an int64 array, so generate uint8
    # pixels; randint's upper bound is exclusive, hence 256 for the full range.
    dummy = Image.fromarray(
        torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy()
    )
    features = extractor.extract_features(dummy)
    print(f"Single shape: {features.shape}")
    print(f"L2 norm: {features.norm().item():.4f}")

    batch = [dummy] * 4
    batch_features = extractor.extract_batch(batch)
    print(f"Batch shape: {batch_features.shape}")
    print("OK")
|
|