# EarthEmbeddingExplorer / models / satclip_model.py
# Uploaded by ML4RS-Anonymous ("Upload all files", commit eb1aec4, verified).
import sys
import os
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download
import warnings
# Import SatCLIP with FutureWarnings silenced. If the package cannot be
# imported, fall back to ``get_satclip = None`` so that
# SatCLIPModel.load_model() (which checks ``get_satclip is None``) can report
# that SatCLIP is unavailable instead of this whole module failing to import.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    try:
        from models.SatCLIP.satclip.load import get_satclip
    except ImportError as e:
        get_satclip = None
        print(f"Warning: could not import models.SatCLIP.satclip.load.get_satclip: {e}")
    else:
        print("Successfully imported models.SatCLIP.satclip.load.get_satclip.")
class SatCLIPModel:
    """SatCLIP checkpoint plus an optional table of pre-computed image embeddings.

    On construction the model is loaded from ``ckpt_path`` and, when
    ``embedding_path`` is truthy, a parquet table with an ``embedding`` column
    is loaded and L2-normalised for similarity search.

    Loading is best-effort: failures are printed, not raised, and leave
    ``self.model`` / ``self.image_embeddings`` as ``None``. The encode and
    search methods return ``None`` values in that case.
    """

    def __init__(self,
                 ckpt_path='./checkpoints/SatCLIP/satclip-vit16-l40.ckpt',
                 embedding_path='./embedding_datasets/10percent_satclip_encoded/all_satclip_embeddings.parquet',  # Path to pre-computed embeddings if available
                 device=None):
        """Load the model and, if ``embedding_path`` is set, the embeddings.

        Args:
            ckpt_path: Local checkpoint path. If the string contains ``'hf'``
                the checkpoint is downloaded from the Hugging Face Hub instead.
            embedding_path: Parquet file of pre-computed embeddings; pass a
                falsy value to skip loading embeddings.
            device: Torch device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): this is a plain substring test, so any local path that
        # happens to contain "hf" also triggers the Hub download. Kept as-is
        # because callers may rely on it.
        if 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt")

        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path

        self.model = None
        self.df_embed = None
        self.image_embeddings = None

        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """Load the SatCLIP model into ``self.model`` (left ``None`` on failure)."""
        if get_satclip is None:
            print("Error: SatCLIP functionality is not available.")
            return

        print(f"Loading SatCLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                return

            # return_all=True to get both visual and location encoders
            self.model = get_satclip(self.ckpt_path, self.device, return_all=True)
            self.model.eval()
            print(f"SatCLIP model loaded on {self.device}")
        except Exception as e:
            # Deliberately best-effort: report and continue without a model.
            print(f"Error loading SatCLIP model: {e}")

    def load_embeddings(self):
        """Load pre-computed embeddings from ``self.embedding_path``.

        Fills ``self.df_embed`` with the parquet table and
        ``self.image_embeddings`` with the L2-normalised ``embedding`` column
        as a float tensor on ``self.device``. Errors are printed, not raised.
        """
        print(f"Loading SatCLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return

            self.df_embed = pq.read_table(self.embedding_path).to_pandas()

            # Pre-compute the embeddings tensor once so search() is a single matmul.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SatCLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            # Deliberately best-effort: report and continue without embeddings.
            print(f"Error loading SatCLIP embeddings: {e}")

    def encode_location(self, lat, lon):
        """
        Encode a (latitude, longitude) pair into a unit-length vector.

        Returns a float tensor with one row, or ``None`` if no model is loaded.
        """
        if self.model is None:
            return None

        # The model is fed (lon, lat) -- note the swapped order relative to
        # this method's arguments -- in double precision.
        coords = torch.tensor([[lon, lat]], dtype=torch.double).to(self.device)

        with torch.no_grad():
            loc_features = self.model.encode_location(coords).float()
            # F.normalize matches load_embeddings() and avoids NaN on a zero vector.
            loc_features = F.normalize(loc_features, dim=1)

        return loc_features

    def encode_image(self, image):
        """
        Encode an RGB image into a unit-length vector using the SatCLIP visual encoder.

        The RGB image is resized to 224x224, scaled to [0, 1] and placed into a
        13-channel input: blue, green and red go to channels 1, 2 and 3, all
        other channels stay zero.

        Returns the feature tensor, or ``None`` when no model is loaded, when
        ``image`` is not a PIL image, or when encoding fails (error is printed).
        """
        if self.model is None:
            return None

        try:
            # Only PIL images are handled; anything else falls through to None.
            if isinstance(image, Image.Image):
                image = image.convert("RGB")
                image = image.resize((224, 224))
                img_np = np.array(image).astype(np.float32) / 255.0

                # Construct 13 channels
                # S2 bands: B01, B02(B), B03(G), B04(R), B05, B06, B07, B08, B8A, B09, B10, B11, B12
                # Indices: 0=B01, 1=B02, 2=B03, 3=B04 ...
                input_tensor = np.zeros((13, 224, 224), dtype=np.float32)
                input_tensor[1] = img_np[:, :, 2]  # Blue
                input_tensor[2] = img_np[:, :, 1]  # Green
                input_tensor[3] = img_np[:, :, 0]  # Red

                input_tensor = torch.from_numpy(input_tensor).unsqueeze(0).to(self.device)

                with torch.no_grad():
                    img_feature = self.model.encode_image(input_tensor)
                    # F.normalize matches load_embeddings() and avoids NaN on a zero vector.
                    img_feature = F.normalize(img_feature, dim=1)

                return img_feature
        except Exception as e:
            print(f"Error encoding image in SatCLIP: {e}")
            return None
        return None

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Rank the stored embeddings against a single query vector.

        Args:
            query_features: Query tensor holding one vector (1-D, or 2-D with
                one row). Should be unit length for the scores to be cosine
                similarities.
            top_k: Number of best matches to return; values <= 0 give an
                empty result.
            top_percent: If given, overrides ``threshold`` with the score of
                the k-th best match, where k is this fraction of all
                embeddings (clamped to between 1 and all of them).
            threshold: Minimum score for an index to appear in the filtered set.

        Returns:
            ``(scores, filtered_indices, top_indices)`` as NumPy arrays, where
            ``top_indices`` is sorted by descending score. Returns
            ``(None, None, None)`` when no embeddings are loaded or the query
            is ``None``.
        """
        # encode_location / encode_image return None when the model is missing;
        # treat that like "nothing to search" instead of crashing on .float().
        if self.image_embeddings is None or query_features is None:
            return None, None, None

        # Bring the query onto the embeddings' device/dtype before the matmul.
        query_features = query_features.float().to(self.image_embeddings.device)
        if query_features.dim() == 1:
            query_features = query_features.unsqueeze(0)

        # Stored embeddings are normalised, so the dot product is cosine similarity.
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()

        if top_percent is not None and len(probs) > 0:
            # Clamp k to [1, len(probs)] so np.partition never indexes out of range.
            k = min(max(int(len(probs) * top_percent), 1), len(probs))
            threshold = np.partition(probs, -k)[-k]

        # Filter by threshold
        filtered_indices = np.where(probs >= threshold)[0]

        # Reverse first, then take a prefix: with top_k == 0 a "[-top_k:]" slice
        # would wrongly return every index.
        top_k = max(int(top_k), 0)
        top_indices = np.argsort(probs)[::-1][:top_k]

        return probs, filtered_indices, top_indices