"""SatCLIP model wrapper.

Loads a SatCLIP checkpoint (visual + location encoders) and, optionally, a
parquet file of pre-computed image embeddings, then exposes:

* ``encode_location`` -- embed a (lat, lon) pair,
* ``encode_image``    -- embed an RGB image via the 13-channel S2 input,
* ``search``          -- cosine-similarity search over the loaded embeddings.
"""

import os
import sys
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# Optional third-party dependencies: import defensively so this module can be
# imported (e.g. for embedding-only search) even when one is missing; the
# failure then surfaces at the point of use instead of at import time.
try:
    import pyarrow.parquet as pq
except ImportError:
    pq = None
try:
    from PIL import Image
except ImportError:
    Image = None
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None

# Project-local SatCLIP loader. NOTE: load_model() already checks for
# ``get_satclip is None``, but a bare import would make that check
# unreachable (a failed import crashed the module) -- guard it here instead.
try:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        from models.SatCLIP.satclip.load import get_satclip
    print("Successfully imported models.SatCLIP.satclip.load.get_satclip.")
except ImportError as e:
    get_satclip = None
    print(f"Warning: SatCLIP loader unavailable: {e}")


class SatCLIPModel:
    """Thin convenience wrapper around a SatCLIP checkpoint.

    Attributes set in ``__init__``:
        device           -- torch device string ("cuda" when available).
        ckpt_path        -- resolved checkpoint path (downloaded from the
                            Hugging Face Hub when the given path contains "hf").
        embedding_path   -- parquet file of pre-computed image embeddings.
        model            -- loaded SatCLIP model, or None on failure.
        df_embed         -- DataFrame backing the embeddings, or None.
        image_embeddings -- L2-normalised (N, D) float tensor, or None.
    """

    def __init__(self,
                 ckpt_path='./checkpoints/SatCLIP/satclip-vit16-l40.ckpt',
                 embedding_path='./embedding_datasets/10percent_satclip_encoded/all_satclip_embeddings.parquet',  # Path to pre-computed embeddings if available
                 device=None):
        """Resolve the device, load the model and (when ``embedding_path``
        is truthy) the pre-computed embeddings."""
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # A path containing "hf" is treated as a request to pull the official
        # checkpoint from the Hugging Face Hub.
        if 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path
        self.model = None
        self.df_embed = None
        self.image_embeddings = None
        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """Load the SatCLIP model onto ``self.device``.

        Leaves ``self.model`` as None (and prints a diagnostic) when the
        loader is unavailable, the checkpoint is missing, or loading fails.
        """
        if get_satclip is None:
            print("Error: SatCLIP functionality is not available.")
            return
        print(f"Loading SatCLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                return
            # return_all=True yields both the visual and location encoders.
            self.model = get_satclip(self.ckpt_path, self.device, return_all=True)
            self.model.eval()
            print(f"SatCLIP model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading SatCLIP model: {e}")

    def load_embeddings(self):
        """Read pre-computed image embeddings from parquet and cache them as
        an L2-normalised (N, D) tensor on ``self.device``.

        Assumes the parquet file has an ``embedding`` column of fixed-length
        vectors (same layout as the SigLIP embedding files).
        """
        print(f"Loading SatCLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return
            self.df_embed = pq.read_table(self.embedding_path).to_pandas()
            # Pre-compute the embedding matrix once so search() is a single
            # matrix multiply.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SatCLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SatCLIP embeddings: {e}")

    def encode_location(self, lat, lon):
        """Encode a (latitude, longitude) pair into a unit-norm feature vector.

        Returns a (1, D) float tensor, or None when no model is loaded.
        """
        if self.model is None:
            return None
        # SatCLIP expects (N, 2) coordinates in (lon, lat) order, in double
        # precision (as in the reference notebook).
        coords = torch.tensor([[lon, lat]], dtype=torch.double).to(self.device)
        with torch.no_grad():
            loc_features = self.model.encode_location(coords).float()
            # Normalise so dot products with the image embeddings are cosines.
            loc_features = loc_features / loc_features.norm(dim=1, keepdim=True)
        return loc_features

    def encode_image(self, image):
        """Encode an RGB image into a unit-norm feature vector.

        SatCLIP's visual encoder expects 13 Sentinel-2 bands; the RGB
        channels are placed into bands B02 (blue), B03 (green), B04 (red)
        and the remaining bands are left at zero.

        Returns a (1, D) float tensor, or None on failure / missing model.
        """
        if self.model is None:
            return None
        try:
            # PIL input: force RGB mode and the 224x224 input resolution.
            # (Non-PIL input is assumed to already be a 224x224x3 uint8-range
            # array -- the original bound img_np only in the PIL branch.)
            if Image is not None and isinstance(image, Image.Image):
                image = image.convert("RGB").resize((224, 224))
            img_np = np.asarray(image, dtype=np.float32) / 255.0
            # S2 band order: B01, B02(B), B03(G), B04(R), B05..B12.
            input_tensor = np.zeros((13, 224, 224), dtype=np.float32)
            input_tensor[1] = img_np[:, :, 2]  # Blue  -> B02
            input_tensor[2] = img_np[:, :, 1]  # Green -> B03
            input_tensor[3] = img_np[:, :, 0]  # Red   -> B04
            input_tensor = torch.from_numpy(input_tensor).unsqueeze(0).to(self.device)
            with torch.no_grad():
                img_feature = self.model.encode_image(input_tensor)
                img_feature = img_feature / img_feature.norm(dim=1, keepdim=True)
            return img_feature
        except Exception as e:
            print(f"Error encoding image in SatCLIP: {e}")
            return None

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Cosine-similarity search over the cached image embeddings.

        Args:
            query_features: (1, D) query tensor (cast to float internally).
            top_k:          number of best matches in ``top_indices``.
            top_percent:    when given, overrides ``threshold`` with the cut-off
                            that keeps the top ``top_percent`` fraction.
            threshold:      minimum similarity for ``filtered_indices``.

        Returns:
            ``(probs, filtered_indices, top_indices)`` as numpy arrays, or
            ``(None, None, None)`` when no embeddings are loaded.
        """
        if self.image_embeddings is None:
            return None, None, None
        query_features = query_features.float()
        # Both sides are L2-normalised, so the dot product is the cosine.
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()
        if top_percent is not None:
            k = max(1, int(len(probs) * top_percent))
            # The k-th largest similarity becomes the cut-off.
            threshold = np.partition(probs, -k)[-k]
        # Indices passing the (possibly overridden) threshold.
        mask = probs >= threshold
        filtered_indices = np.where(mask)[0]
        # Top-k indices, best first.
        top_indices = np.argsort(probs)[-top_k:][::-1]
        return probs, filtered_indices, top_indices