|
|
|
|
|
""" |
|
|
SatCLIP Multi-Spectral Model |
|
|
|
|
|
This model supports both RGB input and full multi-spectral Sentinel-2 input (12/13 bands). |
|
|
""" |
|
|
import sys |
|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import pyarrow.parquet as pq |
|
|
import torch.nn.functional as F |
|
|
import torchvision.transforms as T |
|
|
from PIL import Image |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
import warnings |
|
|
|
|
|
# Import SatCLIP's loader while silencing FutureWarnings emitted during import
# (e.g. from its torch/lightning dependencies). If the package is missing,
# degrade gracefully: `get_satclip` becomes None and SatCLIPMSModel.load_model()
# reports the problem instead of crashing at import time.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    try:
        from models.SatCLIP.satclip.load import get_satclip
        print("Successfully imported models.SatCLIP.satclip.load.get_satclip.")
    except ImportError:
        # Sentinel checked by SatCLIPMSModel.load_model() before any use.
        get_satclip = None
        print("Warning: SatCLIP not available. Please check installation.")
|
|
|
|
|
|
|
|
class SatCLIPMSModel:
    """
    SatCLIP model wrapper supporting multi-spectral Sentinel-2 input.

    Supports:
    - RGB PIL Image input (auto-converted to 13 channels)
    - 12-band Sentinel-2 numpy array (auto-padded to 13 channels)
    - 13-band full Sentinel-2 tensor

    All heavy resources are loaded eagerly in ``__init__`` (model checkpoint,
    and optionally a parquet file of precomputed image embeddings used by
    :meth:`search`). Failures during loading are printed, not raised; callers
    must check that :attr:`model` / :attr:`image_embeddings` are not ``None``.
    """

    def __init__(self,
                 ckpt_path='./checkpoints/SatCLIP/satclip-vit16-l40.ckpt',
                 embedding_path=None,
                 device=None):
        """
        Initialize the wrapper and immediately load the model (and, if
        ``embedding_path`` is given, the precomputed embedding table).

        Args:
            ckpt_path: Path to the SatCLIP checkpoint. If the string contains
                ``'hf'``, the checkpoint is instead downloaded from the
                Hugging Face Hub repo ``microsoft/SatCLIP-ViT16-L40``.
            embedding_path: Optional path to a parquet file with an
                ``'embedding'`` column of per-image vectors (see
                :meth:`load_embeddings`).
            device: Torch device string; defaults to ``"cuda"`` when
                available, else ``"cpu"``.
        """
        # Prefer an explicit device; otherwise pick CUDA when available.
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): substring test — ANY path containing "hf" (e.g.
        # "/home/hf_user/ckpt.ckpt") triggers a Hub download; confirm this
        # sentinel is intentional.
        if ckpt_path and 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path

        # Filled in by load_model() / load_embeddings(); stay None on failure.
        self.model = None
        self.df_embed = None
        self.image_embeddings = None

        # Spatial resolution expected by the vision encoder (224x224).
        self.input_size = 224

        # Band orderings. MajorTOM-style tiles ship 12 bands (B10 absent);
        # the SatCLIP Sentinel-2 encoder expects all 13 bands including B10.
        # _preprocess_12band_array() inserts a zero B10 plane at index 10 to
        # convert between the two layouts.
        self.majortom_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
        self.satclip_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']

        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self) -> None:
        """
        Load the SatCLIP model from ``self.ckpt_path`` onto ``self.device``.

        Leaves ``self.model`` as ``None`` (after printing a message) when the
        SatCLIP package is unavailable, the checkpoint file is missing, or
        loading raises.
        """
        if get_satclip is None:
            # Module-level import of get_satclip failed; nothing we can do.
            print("Error: SatCLIP functionality is not available.")
            return

        print(f"Loading SatCLIP-MS model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                return

            # return_all=True asks get_satclip for the full model (image +
            # location towers), not just the location encoder.
            self.model = get_satclip(self.ckpt_path, self.device, return_all=True)
            self.model.eval()
            print(f"SatCLIP-MS model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading SatCLIP model: {e}")

    def load_embeddings(self) -> None:
        """
        Load precomputed image embeddings from the parquet file at
        ``self.embedding_path``.

        Expects an ``'embedding'`` column whose cells are equal-length
        vectors. The stacked matrix is moved to ``self.device`` and
        L2-normalized row-wise so :meth:`search` can use a plain dot product
        as cosine similarity. On any failure, prints and leaves
        ``self.df_embed`` / ``self.image_embeddings`` as ``None``.
        """
        print(f"Loading SatCLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return

            self.df_embed = pq.read_table(self.embedding_path).to_pandas()

            # Stack the per-row vectors into one (N, D) matrix.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            # Unit-normalize rows: dot product with a unit query = cosine sim.
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SatCLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SatCLIP embeddings: {e}")

    def encode_location(self, lat, lon):
        """
        Encode a (latitude, longitude) pair into a vector.

        Args:
            lat: Latitude in degrees.
            lon: Longitude in degrees.

        Returns:
            L2-normalized (1, D) float tensor on ``self.device``, or ``None``
            when the model is not loaded.
        """
        if self.model is None:
            return None

        # SatCLIP's location encoder expects (lon, lat) order, double dtype.
        coords = torch.tensor([[lon, lat]], dtype=torch.double).to(self.device)

        with torch.no_grad():
            loc_features = self.model.encode_location(coords).float()
            # Unit-normalize so downstream dot products are cosine similarity.
            loc_features = loc_features / loc_features.norm(dim=1, keepdim=True)

        return loc_features

    def _preprocess_12band_array(self, img_array: np.ndarray) -> torch.Tensor:
        """
        Preprocess a 12-band Sentinel-2 array to 13-band tensor for SatCLIP.

        Args:
            img_array: numpy array of shape (H, W, 12) with uint16 values (0-10000+)

        Returns:
            torch.Tensor of shape (13, 224, 224) normalized
        """
        # Sentinel-2 L1C/L2A reflectances are scaled by 10000; divide to get
        # approximately [0, 1] floats (values can slightly exceed 1).
        image = img_array.astype(np.float32) / 10000.0

        # (H, W, C) -> (C, H, W) channel-first layout for torch.
        image = image.transpose(2, 0, 1)

        # MajorTOM data lacks B10 (cirrus). Insert an all-zero B10 plane at
        # index 10 so the band order matches self.satclip_bands (13 bands).
        B10 = np.zeros((1, image.shape[1], image.shape[2]), dtype=image.dtype)
        image_13 = np.concatenate([image[:10], B10, image[10:]], axis=0)

        image_tensor = torch.from_numpy(image_13)

        # Resize all bands to the encoder's expected spatial size.
        transforms = T.Resize((self.input_size, self.input_size),
                              interpolation=T.InterpolationMode.BICUBIC,
                              antialias=True)
        image_tensor = transforms(image_tensor)

        return image_tensor

    def _preprocess_rgb_image(self, image: Image.Image) -> torch.Tensor:
        """
        Preprocess RGB PIL Image to 13-band tensor for SatCLIP.
        Maps RGB to B04, B03, B02 and zeros for other bands.

        Args:
            image: PIL RGB Image

        Returns:
            torch.Tensor of shape (13, 224, 224)
        """
        image = image.convert("RGB")
        image = image.resize((self.input_size, self.input_size))
        # Scale 8-bit values to [0, 1].
        # NOTE(review): the 12-band path divides by 10000 (reflectance units)
        # while this path divides by 255 — the two inputs are on different
        # scales; confirm this matches how the model was trained.
        img_np = np.array(image).astype(np.float32) / 255.0

        # Start from zeros for all 13 bands, then fill only the visible ones.
        # Indices follow self.satclip_bands: 1=B02 (blue), 2=B03 (green),
        # 3=B04 (red); PIL channel order is R=0, G=1, B=2.
        input_tensor = np.zeros((13, self.input_size, self.input_size), dtype=np.float32)
        input_tensor[1] = img_np[:, :, 2]
        input_tensor[2] = img_np[:, :, 1]
        input_tensor[3] = img_np[:, :, 0]

        return torch.from_numpy(input_tensor)

    def encode_image(self, image, is_multispectral=False):
        """
        Encode an image into a feature vector.

        Args:
            image: Can be one of:
                - PIL.Image (RGB) - will be converted to 13-band
                - np.ndarray of shape (H, W, 12) - 12-band Sentinel-2 data
                - torch.Tensor of shape (13, H, W) or (B, 13, H, W) - ready for model
            is_multispectral: Hint to indicate if numpy input is multi-spectral.
                NOTE(review): currently unused — dispatch is driven entirely
                by the input's type and shape below.

        Returns:
            torch.Tensor: Normalized embedding vector, or ``None`` when the
            model is not loaded / the input is unsupported / encoding raises.
        """
        if self.model is None:
            return None

        try:
            # --- Dispatch on input type ---------------------------------
            if isinstance(image, Image.Image):
                # RGB path: zero-filled 13-band tensor, batch dim added.
                input_tensor = self._preprocess_rgb_image(image).unsqueeze(0)

            elif isinstance(image, np.ndarray):
                if image.ndim == 3 and image.shape[-1] == 12:
                    # 12-band Sentinel-2 array -> pad to 13 bands.
                    input_tensor = self._preprocess_12band_array(image).unsqueeze(0)
                elif image.ndim == 3 and image.shape[-1] == 3:
                    # Plain RGB array: route through the PIL/RGB path.
                    pil_img = Image.fromarray(image.astype(np.uint8))
                    input_tensor = self._preprocess_rgb_image(pil_img).unsqueeze(0)
                else:
                    print(f"Unsupported numpy array shape: {image.shape}")
                    return None

            elif isinstance(image, torch.Tensor):
                # Assumed already in 13-band channel-first layout; add a
                # batch dimension if missing.
                if image.dim() == 3:
                    input_tensor = image.unsqueeze(0)
                else:
                    input_tensor = image

                # Resize only when spatial size differs from the expected 224.
                if input_tensor.shape[-1] != self.input_size or input_tensor.shape[-2] != self.input_size:
                    transforms = T.Resize((self.input_size, self.input_size),
                                          interpolation=T.InterpolationMode.BICUBIC,
                                          antialias=True)
                    input_tensor = transforms(input_tensor)
            else:
                print(f"Unsupported image type: {type(image)}")
                return None

            input_tensor = input_tensor.to(self.device)

            with torch.no_grad():
                img_feature = self.model.encode_image(input_tensor)
                # Unit-normalize per sample for cosine-similarity search.
                img_feature = img_feature / img_feature.norm(dim=1, keepdim=True)

            return img_feature

        except Exception as e:
            print(f"Error encoding image in SatCLIP-MS: {e}")
            import traceback
            traceback.print_exc()
            return None

    def encode_batch(self, batch_tensors: list) -> np.ndarray:
        """
        Encode a batch of pre-processed tensors.

        Args:
            batch_tensors: List of torch.Tensor, each of shape (13, H, W)

        Returns:
            np.ndarray of shape (N, embedding_dim)
            (``None`` on failure — e.g. model not loaded, or an empty list,
            since ``torch.stack`` raises on empty input.)
        """
        if self.model is None:
            return None

        try:
            # Stack into a single (N, 13, H, W) batch on the target device.
            t_stack = torch.stack(batch_tensors).to(self.device)

            with torch.no_grad():
                feats = self.model.encode_image(t_stack)
                # Row-wise unit normalization, same convention as encode_image.
                feats = feats / feats.norm(dim=1, keepdim=True)

            return feats.cpu().numpy()

        except Exception as e:
            print(f"Error encoding batch: {e}")
            return None

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """
        Rank the preloaded image embeddings against a query vector.

        Args:
            query_features: (1, D) tensor on ``self.device``; assumed already
                L2-normalized (as produced by encode_location/encode_image)
                so the dot product below is cosine similarity.
            top_k: Number of best matches to return in ``top_indices``.
            top_percent: If given (fraction in (0, 1]), overrides
                ``threshold`` with the score of the top ``top_percent``
                fraction of records (at least one).
            threshold: Minimum similarity for inclusion in
                ``filtered_indices``.

        Returns:
            Tuple ``(probs, filtered_indices, top_indices)``:
            - probs: (N,) numpy array of similarity scores;
            - filtered_indices: indices with score >= threshold;
            - top_indices: indices of the top_k highest scores, descending.
              NOTE(review): top_indices is NOT filtered by the threshold —
              callers get both views independently.
            Returns ``(None, None, None)`` when no embeddings are loaded.
        """
        if self.image_embeddings is None:
            return None, None, None

        query_features = query_features.float()

        # Cosine similarity (embeddings are unit-normalized at load time).
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()

        if top_percent is not None:
            # Derive the threshold as the k-th largest score, where
            # k = floor(N * top_percent), clamped to at least 1.
            k = int(len(probs) * top_percent)
            if k < 1:
                k = 1
            threshold = np.partition(probs, -k)[-k]

        # Indices passing the (possibly derived) threshold, in original order.
        mask = probs >= threshold
        filtered_indices = np.where(mask)[0]

        # Top-k indices by score, highest first.
        top_indices = np.argsort(probs)[-top_k:][::-1]

        return probs, filtered_indices, top_indices
|
|
|