# EarthEmbeddingExplorer / models / satclip_ms_model.py
# Uploaded by ML4RS-Anonymous ("Upload all files", commit eb1aec4, verified).
"""
SatCLIP Multi-Spectral Model
This model supports both RGB input and full multi-spectral Sentinel-2 input (12/13 bands).
"""
import sys
import os
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download
import warnings
# Import the SatCLIP loader while suppressing FutureWarnings emitted during import.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    try:
        from models.SatCLIP.satclip.load import get_satclip
        print("Successfully imported models.SatCLIP.satclip.load.get_satclip.")
    except ImportError:
        # Sentinel for "SatCLIP unavailable"; SatCLIPMSModel.load_model checks it.
        get_satclip = None
        print("Warning: SatCLIP not available. Please check installation.")
class SatCLIPMSModel:
    """
    SatCLIP model wrapper supporting multi-spectral Sentinel-2 input.
    Supports:
    - RGB PIL Image input (auto-converted to 13 channels)
    - 12-band Sentinel-2 numpy array (auto-padded to 13 channels)
    - 13-band full Sentinel-2 tensor
    """
    def __init__(self,
                 ckpt_path='./checkpoints/SatCLIP/satclip-vit16-l40.ckpt',
                 embedding_path=None,
                 device=None):
        """
        Args:
            ckpt_path: Local checkpoint path; any path containing 'hf' is
                replaced by a download from the Hugging Face Hub.
            embedding_path: Optional parquet file of pre-computed image
                embeddings (loaded eagerly when given).
            device: Explicit torch device string; defaults to CUDA if available.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): substring test — ANY path containing "hf" triggers a
        # Hub download (e.g. "./ckpts/hf_old.ckpt" would too); confirm callers
        # use 'hf' purely as a sentinel value.
        if ckpt_path and 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path
        self.model = None            # populated by load_model()
        self.df_embed = None         # populated by load_embeddings()
        self.image_embeddings = None  # (N, D) L2-normalized tensor, see load_embeddings()
        # SatCLIP input size
        self.input_size = 224
        # Sentinel-2 bands mapping
        # MajorTOM 12 bands: [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12]
        # SatCLIP 13 bands: [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B10, B11, B12]
        # B10 is missing in MajorTOM, need to insert zeros at index 10
        self.majortom_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
        self.satclip_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
        # Model (and optional embedding table) are loaded eagerly at construction.
        self.load_model()
        if self.embedding_path:
            self.load_embeddings()
def load_model(self):
if get_satclip is None:
print("Error: SatCLIP functionality is not available.")
return
print(f"Loading SatCLIP-MS model from {self.ckpt_path}...")
try:
if not os.path.exists(self.ckpt_path):
print(f"Warning: Checkpoint not found at {self.ckpt_path}")
return
# Load model using get_satclip
# return_all=True to get both visual and location encoders
self.model = get_satclip(self.ckpt_path, self.device, return_all=True)
self.model.eval()
print(f"SatCLIP-MS model loaded on {self.device}")
except Exception as e:
print(f"Error loading SatCLIP model: {e}")
def load_embeddings(self):
print(f"Loading SatCLIP embeddings from {self.embedding_path}...")
try:
if not os.path.exists(self.embedding_path):
print(f"Warning: Embedding file not found at {self.embedding_path}")
return
self.df_embed = pq.read_table(self.embedding_path).to_pandas()
# Pre-compute image embeddings tensor
image_embeddings_np = np.stack(self.df_embed['embedding'].values)
self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
print(f"SatCLIP Data loaded: {len(self.df_embed)} records")
except Exception as e:
print(f"Error loading SatCLIP embeddings: {e}")
def encode_location(self, lat, lon):
"""
Encode a (latitude, longitude) pair into a vector.
"""
if self.model is None:
return None
# SatCLIP expects input shape (N, 2) -> (lon, lat)
coords = torch.tensor([[lon, lat]], dtype=torch.double).to(self.device)
with torch.no_grad():
loc_features = self.model.encode_location(coords).float()
loc_features = loc_features / loc_features.norm(dim=1, keepdim=True)
return loc_features
def _preprocess_12band_array(self, img_array: np.ndarray) -> torch.Tensor:
"""
Preprocess a 12-band Sentinel-2 array to 13-band tensor for SatCLIP.
Args:
img_array: numpy array of shape (H, W, 12) with uint16 values (0-10000+)
Returns:
torch.Tensor of shape (13, 224, 224) normalized
"""
# 1. Normalize (SatCLIP standard: / 10000.0)
image = img_array.astype(np.float32) / 10000.0
# 2. Channel First: (H, W, C) -> (C, H, W) -> (12, H, W)
image = image.transpose(2, 0, 1)
# 3. Insert B10 (zeros) at index 10 -> (13, H, W)
# MajorTOM: [B01..B09(idx0-9), B11(idx10), B12(idx11)]
# SatCLIP: [B01..B09(idx0-9), B10(idx10), B11(idx11), B12(idx12)]
B10 = np.zeros((1, image.shape[1], image.shape[2]), dtype=image.dtype)
image_13 = np.concatenate([image[:10], B10, image[10:]], axis=0)
# 4. Convert to Tensor
image_tensor = torch.from_numpy(image_13)
# 5. Resize to 224x224
transforms = T.Resize((self.input_size, self.input_size),
interpolation=T.InterpolationMode.BICUBIC,
antialias=True)
image_tensor = transforms(image_tensor)
return image_tensor
def _preprocess_rgb_image(self, image: Image.Image) -> torch.Tensor:
"""
Preprocess RGB PIL Image to 13-band tensor for SatCLIP.
Maps RGB to B04, B03, B02 and zeros for other bands.
Args:
image: PIL RGB Image
Returns:
torch.Tensor of shape (13, 224, 224)
"""
image = image.convert("RGB")
image = image.resize((self.input_size, self.input_size))
img_np = np.array(image).astype(np.float32) / 255.0
# Construct 13 channels
# S2 bands: B01, B02(B), B03(G), B04(R), B05...
# Indices: 0=B01, 1=B02, 2=B03, 3=B04 ...
input_tensor = np.zeros((13, self.input_size, self.input_size), dtype=np.float32)
input_tensor[1] = img_np[:, :, 2] # Blue -> B02
input_tensor[2] = img_np[:, :, 1] # Green -> B03
input_tensor[3] = img_np[:, :, 0] # Red -> B04
return torch.from_numpy(input_tensor)
def encode_image(self, image, is_multispectral=False):
"""
Encode an image into a feature vector.
Args:
image: Can be one of:
- PIL.Image (RGB) - will be converted to 13-band
- np.ndarray of shape (H, W, 12) - 12-band Sentinel-2 data
- torch.Tensor of shape (13, H, W) or (B, 13, H, W) - ready for model
is_multispectral: Hint to indicate if numpy input is multi-spectral
Returns:
torch.Tensor: Normalized embedding vector
"""
if self.model is None:
return None
try:
# Handle different input types
if isinstance(image, Image.Image):
# RGB PIL Image
input_tensor = self._preprocess_rgb_image(image).unsqueeze(0)
elif isinstance(image, np.ndarray):
# Numpy array - assumed to be 12-band Sentinel-2 (H, W, 12)
if image.ndim == 3 and image.shape[-1] == 12:
input_tensor = self._preprocess_12band_array(image).unsqueeze(0)
elif image.ndim == 3 and image.shape[-1] == 3:
# RGB numpy array
pil_img = Image.fromarray(image.astype(np.uint8))
input_tensor = self._preprocess_rgb_image(pil_img).unsqueeze(0)
else:
print(f"Unsupported numpy array shape: {image.shape}")
return None
elif isinstance(image, torch.Tensor):
# Already a tensor
if image.dim() == 3:
input_tensor = image.unsqueeze(0)
else:
input_tensor = image
# Resize if needed
if input_tensor.shape[-1] != self.input_size or input_tensor.shape[-2] != self.input_size:
transforms = T.Resize((self.input_size, self.input_size),
interpolation=T.InterpolationMode.BICUBIC,
antialias=True)
input_tensor = transforms(input_tensor)
else:
print(f"Unsupported image type: {type(image)}")
return None
# Move to device and encode
input_tensor = input_tensor.to(self.device)
with torch.no_grad():
img_feature = self.model.encode_image(input_tensor)
img_feature = img_feature / img_feature.norm(dim=1, keepdim=True)
return img_feature
except Exception as e:
print(f"Error encoding image in SatCLIP-MS: {e}")
import traceback
traceback.print_exc()
return None
def encode_batch(self, batch_tensors: list) -> np.ndarray:
"""
Encode a batch of pre-processed tensors.
Args:
batch_tensors: List of torch.Tensor, each of shape (13, H, W)
Returns:
np.ndarray of shape (N, embedding_dim)
"""
if self.model is None:
return None
try:
t_stack = torch.stack(batch_tensors).to(self.device)
with torch.no_grad():
feats = self.model.encode_image(t_stack)
feats = feats / feats.norm(dim=1, keepdim=True)
return feats.cpu().numpy()
except Exception as e:
print(f"Error encoding batch: {e}")
return None
def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
if self.image_embeddings is None:
return None, None, None
query_features = query_features.float()
# Similarity calculation (Cosine similarity)
probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()
if top_percent is not None:
k = int(len(probs) * top_percent)
if k < 1:
k = 1
threshold = np.partition(probs, -k)[-k]
# Filter by threshold
mask = probs >= threshold
filtered_indices = np.where(mask)[0]
# Get top k
top_indices = np.argsort(probs)[-top_k:][::-1]
return probs, filtered_indices, top_indices