"""
SatCLIP Multi-Spectral Model

This model supports both RGB input and full multi-spectral Sentinel-2 input (12/13 bands).
"""

import sys
import os
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download
import warnings

# Import SatCLIP lazily-guarded so this module can still be imported when the
# SatCLIP package is missing; FutureWarnings raised during that import are muted.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    try:
        from models.SatCLIP.satclip.load import get_satclip
        print("Successfully imported models.SatCLIP.satclip.load.get_satclip.")
    except ImportError:
        get_satclip = None
        print("Warning: SatCLIP not available. Please check installation.")


class SatCLIPMSModel:
    """
    SatCLIP model wrapper supporting multi-spectral Sentinel-2 input.

    Supports:
    - RGB PIL Image input (auto-converted to 13 channels)
    - 12-band Sentinel-2 numpy array (auto-padded to 13 channels)
    - 13-band full Sentinel-2 tensor
    """

    def __init__(self, ckpt_path='./checkpoints/SatCLIP/satclip-vit16-l40.ckpt', embedding_path=None, device=None):
        """
        Args:
            ckpt_path: Local checkpoint path. Any value containing the substring
                'hf' triggers a download of ``satclip-vit16-l40.ckpt`` from the
                ``microsoft/SatCLIP-ViT16-L40`` Hugging Face repo instead.
            embedding_path: Optional parquet file with an ``embedding`` column;
                loaded immediately when given.
            device: Torch device string; defaults to "cuda" if available, else "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): this is a plain substring test, so any local path that
        # happens to contain 'hf' (e.g. './hf_cache/x.ckpt') is replaced by the
        # Hub download. Kept as-is for backward compatibility.
        if ckpt_path and 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt")

        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path

        self.model = None
        self.df_embed = None
        self.image_embeddings = None

        # SatCLIP input size
        self.input_size = 224

        # Sentinel-2 bands mapping
        # MajorTOM 12 bands: [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12]
        # SatCLIP 13 bands:  [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B10, B11, B12]
        # B10 is missing in MajorTOM, need to insert zeros at index 10
        self.majortom_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
        self.satclip_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']

        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """
        Load the SatCLIP model from ``self.ckpt_path`` into ``self.model``.

        On any failure (SatCLIP unavailable, missing checkpoint, load error) a
        message is printed and ``self.model`` is left as ``None``.
        """
        if get_satclip is None:
            print("Error: SatCLIP functionality is not available.")
            return

        print(f"Loading SatCLIP-MS model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                return

            # Load model using get_satclip
            # return_all=True to get both visual and location encoders
            self.model = get_satclip(self.ckpt_path, self.device, return_all=True)
            self.model.eval()
            print(f"SatCLIP-MS model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading SatCLIP model: {e}")

    def load_embeddings(self):
        """
        Load pre-computed image embeddings from the parquet file at
        ``self.embedding_path``.

        Sets ``self.df_embed`` to the full table and ``self.image_embeddings``
        to an L2-normalized float tensor (one row per record) on ``self.device``.
        Failures are printed and leave the attributes unset.
        """
        print(f"Loading SatCLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return

            self.df_embed = pq.read_table(self.embedding_path).to_pandas()

            # Pre-compute image embeddings tensor
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SatCLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SatCLIP embeddings: {e}")

    def encode_location(self, lat, lon):
        """
        Encode a (latitude, longitude) pair into a vector.

        Returns:
            torch.Tensor of shape (1, D), L2-normalized along dim 1, or ``None``
            if the model is not loaded.
        """
        if self.model is None:
            return None

        # SatCLIP expects input shape (N, 2) -> (lon, lat)
        coords = torch.tensor([[lon, lat]], dtype=torch.double).to(self.device)

        with torch.no_grad():
            loc_features = self.model.encode_location(coords).float()
            loc_features = loc_features / loc_features.norm(dim=1, keepdim=True)

        return loc_features

    def _preprocess_12band_array(self, img_array: np.ndarray) -> torch.Tensor:
        """
        Preprocess a 12-band Sentinel-2 array to 13-band tensor for SatCLIP.

        Args:
            img_array: numpy array of shape (H, W, 12) with uint16 values (0-10000+)

        Returns:
            torch.Tensor of shape (13, 224, 224) normalized
        """
        # 1. Normalize (SatCLIP standard: / 10000.0)
        image = img_array.astype(np.float32) / 10000.0

        # 2. Channel First: (H, W, C) -> (C, H, W) -> (12, H, W)
        image = image.transpose(2, 0, 1)

        # 3. Insert B10 (zeros) at index 10 -> (13, H, W)
        # MajorTOM: [B01..B09(idx0-9), B11(idx10), B12(idx11)]
        # SatCLIP:  [B01..B09(idx0-9), B10(idx10), B11(idx11), B12(idx12)]
        B10 = np.zeros((1, image.shape[1], image.shape[2]), dtype=image.dtype)
        image_13 = np.concatenate([image[:10], B10, image[10:]], axis=0)

        # 4. Convert to Tensor
        image_tensor = torch.from_numpy(image_13)

        # 5. Resize to 224x224
        transforms = T.Resize((self.input_size, self.input_size),
                              interpolation=T.InterpolationMode.BICUBIC,
                              antialias=True)
        image_tensor = transforms(image_tensor)

        return image_tensor

    def _preprocess_rgb_image(self, image: Image.Image) -> torch.Tensor:
        """
        Preprocess RGB PIL Image to 13-band tensor for SatCLIP.

        Maps RGB to B04, B03, B02 and zeros for other bands.

        Args:
            image: PIL RGB Image

        Returns:
            torch.Tensor of shape (13, 224, 224)
        """
        image = image.convert("RGB")
        image = image.resize((self.input_size, self.input_size))
        img_np = np.array(image).astype(np.float32) / 255.0

        # Construct 13 channels
        # S2 bands: B01, B02(B), B03(G), B04(R), B05...
        # Indices: 0=B01, 1=B02, 2=B03, 3=B04 ...
        input_tensor = np.zeros((13, self.input_size, self.input_size), dtype=np.float32)
        input_tensor[1] = img_np[:, :, 2]  # Blue -> B02
        input_tensor[2] = img_np[:, :, 1]  # Green -> B03
        input_tensor[3] = img_np[:, :, 0]  # Red -> B04

        return torch.from_numpy(input_tensor)

    def encode_image(self, image, is_multispectral=False):
        """
        Encode an image into a feature vector.

        Args:
            image: Can be one of:
                - PIL.Image (RGB) - will be converted to 13-band
                - np.ndarray of shape (H, W, 12) - 12-band Sentinel-2 data
                - np.ndarray of shape (H, W, 3) - RGB, cast to uint8
                - torch.Tensor of shape (13, H, W) or (B, 13, H, W) - ready for model
            is_multispectral: Currently unused; the input kind is detected from
                type and shape. Kept for API compatibility.

        Returns:
            torch.Tensor: Normalized embedding vector, or ``None`` if the model
            is not loaded, the input is unsupported, or encoding fails.
        """
        if self.model is None:
            return None

        try:
            # Handle different input types
            if isinstance(image, Image.Image):
                # RGB PIL Image
                input_tensor = self._preprocess_rgb_image(image).unsqueeze(0)

            elif isinstance(image, np.ndarray):
                # Numpy array - assumed to be 12-band Sentinel-2 (H, W, 12)
                if image.ndim == 3 and image.shape[-1] == 12:
                    input_tensor = self._preprocess_12band_array(image).unsqueeze(0)
                elif image.ndim == 3 and image.shape[-1] == 3:
                    # RGB numpy array
                    pil_img = Image.fromarray(image.astype(np.uint8))
                    input_tensor = self._preprocess_rgb_image(pil_img).unsqueeze(0)
                else:
                    print(f"Unsupported numpy array shape: {image.shape}")
                    return None

            elif isinstance(image, torch.Tensor):
                # Already a tensor
                if image.dim() == 3:
                    input_tensor = image.unsqueeze(0)
                else:
                    input_tensor = image

                # Resize if needed
                if input_tensor.shape[-1] != self.input_size or input_tensor.shape[-2] != self.input_size:
                    transforms = T.Resize((self.input_size, self.input_size),
                                          interpolation=T.InterpolationMode.BICUBIC,
                                          antialias=True)
                    input_tensor = transforms(input_tensor)
            else:
                print(f"Unsupported image type: {type(image)}")
                return None

            # Move to device and encode
            input_tensor = input_tensor.to(self.device)

            with torch.no_grad():
                img_feature = self.model.encode_image(input_tensor)
                img_feature = img_feature / img_feature.norm(dim=1, keepdim=True)

            return img_feature

        except Exception as e:
            print(f"Error encoding image in SatCLIP-MS: {e}")
            import traceback
            traceback.print_exc()
            return None

    def encode_batch(self, batch_tensors: list) -> np.ndarray:
        """
        Encode a batch of pre-processed tensors.

        Args:
            batch_tensors: List of torch.Tensor, each of shape (13, H, W)

        Returns:
            np.ndarray of shape (N, embedding_dim), L2-normalized per row, or
            ``None`` if the model is not loaded or encoding fails (including an
            empty list, which ``torch.stack`` rejects).
        """
        if self.model is None:
            return None

        try:
            t_stack = torch.stack(batch_tensors).to(self.device)
            with torch.no_grad():
                feats = self.model.encode_image(t_stack)
                feats = feats / feats.norm(dim=1, keepdim=True)
            return feats.cpu().numpy()
        except Exception as e:
            print(f"Error encoding batch: {e}")
            return None

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """
        Score every loaded image embedding against a query embedding.

        Args:
            query_features: torch.Tensor query embedding; presumably (1, D) and
                already L2-normalized (as returned by ``encode_image`` /
                ``encode_location``) so the dot product is a cosine similarity
                — TODO confirm for other callers. It may live on any device.
            top_k: Number of highest-scoring indices to return. Values <= 0
                yield an empty ``top_indices``.
            top_percent: If given, overrides ``threshold`` with the score of the
                k-th best record, where k = len * top_percent clamped to [1, len].
            threshold: Minimum score for a record to appear in ``filtered_indices``.

        Returns:
            (probs, filtered_indices, top_indices):
                probs: 1-D numpy array of scores, one per record.
                filtered_indices: indices with ``probs >= threshold``.
                top_indices: up to ``top_k`` indices, best first.
            ``(None, None, None)`` if no embeddings are loaded.
        """
        if self.image_embeddings is None:
            return None, None, None

        # Align the query with the embedding matrix's device so a CPU query
        # does not crash against CUDA embeddings (or vice versa).
        query_features = query_features.to(self.image_embeddings.device).float()

        # Similarity calculation (Cosine similarity)
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()
        n_records = len(probs)

        if top_percent is not None and n_records > 0:
            # Clamp to [1, n]: top_percent > 1 would otherwise make k exceed the
            # array length and np.partition would raise.
            k = min(max(int(n_records * top_percent), 1), n_records)
            threshold = np.partition(probs, -k)[-k]

        # Filter by threshold
        mask = probs >= threshold
        filtered_indices = np.where(mask)[0]

        # Get top k. Guard non-positive values: probs[-0:] is the whole array,
        # so top_k=0 previously returned every index instead of none.
        top_k = max(int(top_k), 0)
        if top_k == 0:
            top_indices = np.empty(0, dtype=np.intp)
        else:
            top_indices = np.argsort(probs)[-top_k:][::-1]

        return probs, filtered_indices, top_indices