# Uploaded by silverballs
# Commit 0db1cb9: "Deploy SatFetch Space App - track binary files with Git LFS"
# (original file size: 5.77 kB)
"""
Feature extraction for satellite imagery using SatCLIP.
SatCLIP is trained on Sentinel-2 data, which makes it better suited than
generic CLIP to satellite image retrieval.
"""
import torch
import torch.nn.functional as F
from PIL import Image
from typing import List, Optional, Tuple
from torchvision import transforms
from .satclip_encoder import SatCLIPEncoder
# Band-centre wavelengths (nm) for each supported input modality.
# Wavelength-conditioned encoders (e.g. DOFA-style models) use these to encode
# which spectral band each input channel carries. FeatureExtractor in this
# module does not read them.
WAVELENGTHS = {
    # RGB: Sentinel-2 B02 (blue), B03 (green), B04 (red)
    "optical": torch.tensor([492.4, 559.8, 664.6]),
    # C-band SAR, VV then VH polarisation.
    # NOTE(review): VV and VH share one physical wavelength (C-band is on the
    # order of centimetres), so these two values look like placeholders rather
    # than real wavelengths in nm -- confirm what the consuming model expects.
    "sar": torch.tensor([5400.0, 5600.0]),
    # Sentinel-2 multispectral band centres, ascending (12 entries).
    # NOTE(review): the list stops at 1613.7 nm; a ~2202 nm band (B12) is not
    # included -- confirm this is the intended band set.
    "multispectral": torch.tensor([
        442.0, 492.4, 559.8, 664.6, 704.1, 740.5, 782.8, 832.8,
        864.7, 945.1, 1373.5, 1613.7,
    ]),
}

# Number of input channels per modality, derived from WAVELENGTHS so the two
# tables cannot disagree (optical=3, sar=2, multispectral=12).
MODALITY_CHANNELS = {name: len(bands) for name, bands in WAVELENGTHS.items()}
class FeatureExtractor:
    """
    Extract embeddings from satellite images with a SatCLIP image encoder.

    Every input is coerced to a 13-channel tensor (see ``_pad_to_13ch``) and
    moved to ``self.device`` before being passed to ``SatCLIPEncoder.encode``.

    Attributes:
        device: Torch device string used for encoding. Defaults to "cuda"
            when available, otherwise "cpu".
        encoder: The wrapped ``SatCLIPEncoder``.
        embed_dim: Embedding width reported by the encoder.
        transform: PIL -> tensor pipeline (resize to 224x224, then ToTensor).
    """

    def __init__(self, device: Optional[str] = None):
        """
        Args:
            device: Torch device string; ``None`` picks "cuda" if available,
                else "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder = SatCLIPEncoder(device=self.device)
        self.embed_dim = self.encoder.embed_dim
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def _preprocess(self, image: "Image.Image", modality: str) -> torch.Tensor:
        """Convert one PIL image to a ``(1, 13, 224, 224)`` CPU tensor."""
        tensor = self.transform(image).unsqueeze(0)
        return self._pad_to_13ch(tensor, modality)

    def _pad_to_13ch(self, tensor: torch.Tensor, modality: str = "optical") -> torch.Tensor:
        """
        Coerce a ``(B, C, H, W)`` tensor to exactly 13 channels for SatCLIP.

        - C >= 13: keep the first 13 channels.
        - C == 1: replicate the channel to 3 (grayscale / single-band fallback).
        - C == 2: append a copy of channel 0 (e.g. VV, VH -> VV, VH, VV).
        - Whatever is still missing up to 13 channels is zero-filled.

        Args:
            tensor: 4-D input batch.
            modality: Accepted for API compatibility; not used by this method.

        Returns:
            A ``(B, 13, H, W)`` tensor on the same device as ``tensor``.
        """
        n_channels = tensor.shape[1]
        if n_channels >= 13:
            return tensor[:, :13, :, :]
        if n_channels == 1:
            tensor = tensor.repeat(1, 3, 1, 1)
        elif n_channels == 2:
            # Duplicate the first (VV) channel as a third channel.
            tensor = torch.cat([tensor, tensor[:, :1, :, :]], dim=1)
        batch, n_channels, height, width = tensor.shape
        # Allocate on the input's device: a CPU padding tensor would make
        # torch.cat fail for CUDA inputs. dtype stays the default float so
        # torch.cat promotes exactly as before.
        padding = torch.zeros(
            batch, 13 - n_channels, height, width, device=tensor.device)
        return torch.cat([tensor, padding], dim=1)

    def _preprocess_batch(self, images: List["Image.Image"], modality: str) -> torch.Tensor:
        """Convert a non-empty list of PIL images to one ``(N, 13, 224, 224)`` batch."""
        # Each _preprocess result already carries a leading batch dim of 1, so
        # concatenate along it; torch.stack would add a second batch dimension
        # and produce a 5-D (N, 1, 13, H, W) tensor the encoder cannot take.
        return torch.cat([self._preprocess(img, modality) for img in images], dim=0)

    @torch.no_grad()
    def extract_features(
        self,
        image: "Image.Image",
        modality: str = "optical",
        normalize: bool = True
    ) -> torch.Tensor:
        """
        Embed a single PIL image.

        Args:
            image: Input image; resized to 224x224 and padded to 13 channels.
            modality: Forwarded to preprocessing (currently informational only).
            normalize: Forwarded to ``encoder.encode``.

        Returns:
            The embedding with the batch dimension squeezed away, on whatever
            device the encoder returns it.
        """
        tensor = self._preprocess(image, modality).to(self.device)
        features = self.encoder.encode(tensor, normalize=normalize)
        return features.squeeze(0)

    @torch.no_grad()
    def extract_features_from_tensor(
        self,
        tensor: torch.Tensor,
        modality: str = "optical",
        normalize: bool = True
    ) -> torch.Tensor:
        """
        Extract features from a raw ``(C, H, W)`` or ``(B, C, H, W)`` tensor.

        No resizing is applied; only the channel count is coerced to 13
        (padding when C < 13, truncating when C > 13).

        Returns:
            Encoder output with a leading size-1 batch dimension squeezed away.
        """
        if tensor.ndim == 3:
            tensor = tensor.unsqueeze(0)
        tensor = self._pad_to_13ch(tensor.to(self.device), modality)
        features = self.encoder.encode(tensor, normalize=normalize)
        return features.squeeze(0)

    @torch.no_grad()
    def extract_batch(
        self,
        images: List["Image.Image"],
        modality: str = "optical",
        batch_size: int = 32,
        normalize: bool = True
    ) -> torch.Tensor:
        """
        Embed a list of PIL images in chunks of ``batch_size``.

        Returns:
            A CPU tensor with one row per image, in input order. An empty
            ``images`` list yields a ``(0, embed_dim)`` tensor (assumes the
            encoder's output width equals ``embed_dim``).

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        if not images:
            return torch.empty(0, self.embed_dim)
        all_features = []
        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]
            tensors = self._preprocess_batch(chunk, modality).to(self.device)
            features = self.encoder.encode(tensors, normalize=normalize)
            all_features.append(features.cpu())
        return torch.cat(all_features, dim=0)

    def embed_dataset(
        self,
        dataset,
        batch_size: int = 32,
        show_progress: bool = True
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        """
        Embed every sample of a dataset yielding ``(image, modality_label, class_label)``.

        Images are coerced to 13 channels like every other entry point, and
        embeddings are always L2-normalised (``normalize=True``).

        Returns:
            ``(embeddings, modality_labels, class_labels)``: a CPU tensor with
            one row per sample plus the two label lists in dataset order. An
            empty dataset yields a ``(0, embed_dim)`` tensor and empty lists.
        """
        from torch.utils.data import DataLoader
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
        all_embeddings = []
        all_modality_labels = []
        all_class_labels = []
        for batch_idx, (images, mod_labels, class_labels) in enumerate(loader):
            images = self._pad_to_13ch(images.to(self.device))
            with torch.no_grad():
                features = self.encoder.encode(images, normalize=True)
            all_embeddings.append(features.cpu())
            all_modality_labels.extend(mod_labels.tolist())
            all_class_labels.extend(class_labels.tolist())
            if show_progress and (batch_idx + 1) % 10 == 0:
                print(f"Embedded {batch_idx + 1}/{len(loader)} batches")
        if not all_embeddings:
            return torch.empty(0, self.embed_dim), [], []
        return torch.cat(all_embeddings, dim=0), all_modality_labels, all_class_labels
if __name__ == "__main__":
    # Smoke test: embed one random RGB image, then a batch of four copies.
    print("Testing SatCLIP FeatureExtractor...")
    extractor = FeatureExtractor()
    print(f"Embed dim: {extractor.embed_dim}")
    # Image.fromarray needs uint8 data for an RGB image, but torch.randint
    # defaults to int64. randint's upper bound is exclusive, so 256 covers
    # the full 0-255 pixel range.
    pixels = torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8)
    dummy = Image.fromarray(pixels.numpy())
    features = extractor.extract_features(dummy)
    print(f"Single shape: {features.shape}")
    print(f"L2 norm: {features.norm().item():.4f}")
    batch = [dummy] * 4
    batch_features = extractor.extract_batch(batch)
    print(f"Batch shape: {batch_features.shape}")
    print("OK")