""" Feature extraction using SatCLIP for satellite imagery. SatCLIP is trained on Sentinel-2 data - better than generic CLIP for satellite image retrieval. """ import torch import torch.nn.functional as F from PIL import Image from typing import List, Optional, Tuple from torchvision import transforms from .satclip_encoder import SatCLIPEncoder # Wavelength centroids (nm) per modality — needed by DOFA-style models # These match Sentinel-2 band centers and are used for positional encoding WAVELENGTHS = { "optical": torch.tensor([492.4, 559.8, 664.6]), # RGB: B02, B03, B04 "sar": torch.tensor([5400.0, 5600.0]), # C-band VV, VH (approx, in nm-equivalent) "multispectral": torch.tensor([ # Sentinel-2 MS bands 442.0, 492.4, 559.8, 664.6, 704.1, 740.5, 782.8, 832.8, 864.7, 945.1, 1373.5, 1613.7 ]), } MODALITY_CHANNELS = { "optical": 3, "sar": 2, "multispectral": 12, } class FeatureExtractor: """ Extract features from satellite images using SatCLIP. Uses SatCLIP's ViT trained on Sentinel-2 imagery. """ def __init__(self, device: Optional[str] = None): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.encoder = SatCLIPEncoder(device=self.device) self.embed_dim = self.encoder.embed_dim self.transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), ]) def _preprocess(self, image: Image.Image, modality: str) -> torch.Tensor: tensor = self.transform(image).unsqueeze(0) return self._pad_to_13ch(tensor, modality) def _pad_to_13ch(self, tensor: torch.Tensor, modality: str = "optical") -> torch.Tensor: """Pad tensor to 13 channels for SatCLIP. Handles 1-13 channels.""" n_channels = tensor.shape[1] if n_channels >= 13: return tensor[:, :13, :, :] # Repeat single channel to 3 (grayscale SAR fallback) if n_channels == 1: tensor = tensor.repeat(1, 3, 1, 1) n_channels = 3 # Repeat 2 channels to 3 (SAR VV/VH) if n_channels == 2: third = tensor[:, :1, :, :] # duplicate VV as 3rd channel tensor = torch.cat([tensor, third], dim=1) n_channels = 3 pad_channels = 13 - n_channels padding = torch.zeros( tensor.shape[0], pad_channels, tensor.shape[2], tensor.shape[3]) return torch.cat([tensor, padding], dim=1) def _preprocess_batch(self, images: List[Image.Image], modality: str) -> torch.Tensor: return torch.stack([self._preprocess(img, modality) for img in images]) @torch.no_grad() def extract_features( self, image: Image.Image, modality: str = "optical", normalize: bool = True ) -> torch.Tensor: tensor = self._preprocess(image, modality) features = self.encoder.encode(tensor, normalize=normalize) return features.squeeze(0) @torch.no_grad() def extract_features_from_tensor( self, tensor: torch.Tensor, modality: str = "optical", normalize: bool = True ) -> torch.Tensor: """Extract features from a raw (C, H, W) tensor with arbitrary channels.""" if tensor.ndim == 3: tensor = tensor.unsqueeze(0) if tensor.shape[1] < 13: tensor = self._pad_to_13ch(tensor, modality) tensor = tensor.to(self.device) features = self.encoder.encode(tensor, normalize=normalize) return features.squeeze(0) @torch.no_grad() def extract_batch( self, images: List[Image.Image], modality: str = "optical", batch_size: int = 32, normalize: bool = True ) -> torch.Tensor: all_features = [] for i in range(0, len(images), batch_size): batch = images[i:i + batch_size] tensors = self._preprocess_batch(batch, modality) features = self.encoder.encode(tensors, normalize=normalize) all_features.append(features.cpu()) return torch.cat(all_features, dim=0) def embed_dataset( self, dataset, batch_size: int = 32, show_progress: bool = True ) -> Tuple[torch.Tensor, List[int], List[int]]: from torch.utils.data import DataLoader loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0) all_embeddings = [] all_modality_labels = [] all_class_labels = [] for batch_idx, (images, mod_labels, class_labels) in enumerate(loader): images = images.to(self.device) with torch.no_grad(): features = self.encoder.encode(images, normalize=True) all_embeddings.append(features.cpu()) all_modality_labels.extend(mod_labels.numpy().tolist()) all_class_labels.extend(class_labels.numpy().tolist()) if show_progress and (batch_idx + 1) % 10 == 0: print(f"Embedded {batch_idx + 1}/{len(loader)} batches") return torch.cat(all_embeddings, dim=0), all_modality_labels, all_class_labels if __name__ == "__main__": print("Testing SatCLIP FeatureExtractor...") extractor = FeatureExtractor() print(f"Embed dim: {extractor.embed_dim}") dummy = Image.fromarray(torch.randint(0, 255, (224, 224, 3)).numpy()) features = extractor.extract_features(dummy) print(f"Single shape: {features.shape}") print(f"L2 norm: {features.norm().item():.4f}") batch = [dummy] * 4 batch_features = extractor.extract_batch(batch) print(f"Batch shape: {batch_features.shape}") print("OK")