|
|
import sys |
|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import pyarrow.parquet as pq |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
import warnings |
|
|
|
|
|
|
|
|
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    # Import the SatCLIP loader defensively: downstream code
    # (SatCLIPModel.load_model) checks `get_satclip is None`, so a failed
    # import must degrade to None instead of crashing the whole module.
    try:
        from models.SatCLIP.satclip.load import get_satclip
        print("Successfully imported models.SatCLIP.satclip.load.get_satclip.")
    except ImportError as e:
        get_satclip = None
        print(f"Warning: could not import SatCLIP loader ({e}). SatCLIP functionality disabled.")
|
|
|
|
|
class SatCLIPModel:
    """Wrapper around a pretrained SatCLIP model and a precomputed
    image-embedding index.

    Provides location encoding, RGB-image encoding (adapted to SatCLIP's
    13-channel Sentinel-2 input), and cosine-similarity search over a
    parquet table of stored embeddings.

    Attributes:
        device: torch device string ("cuda" or "cpu").
        ckpt_path: resolved path of the SatCLIP checkpoint.
        embedding_path: path of the parquet embedding table (may be falsy).
        model: loaded SatCLIP model, or None if loading failed/skipped.
        df_embed: DataFrame of embedding records, or None.
        image_embeddings: L2-normalized (N, D) float tensor on `device`,
            or None when no index is loaded.
    """

    def __init__(self,
                 ckpt_path='./checkpoints/SatCLIP/satclip-vit16-l40.ckpt',
                 embedding_path='./embedding_datasets/10percent_satclip_encoded/all_satclip_embeddings.parquet',
                 device=None):
        """Initialize and eagerly load the model and (optionally) the index.

        Args:
            ckpt_path: local checkpoint path; if the string contains 'hf',
                the official checkpoint is downloaded from the Hugging Face
                Hub instead.
            embedding_path: parquet file with an 'embedding' column; pass a
                falsy value to skip building the search index.
            device: explicit torch device; defaults to CUDA when available.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        if 'hf' in ckpt_path:
            # An 'hf' marker in the path requests the published checkpoint.
            ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path

        self.model = None             # set by load_model() on success
        self.df_embed = None          # DataFrame backing the embedding index
        self.image_embeddings = None  # normalized (N, D) tensor, or None

        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """Load the SatCLIP checkpoint onto self.device.

        Best-effort: on any failure the model stays None and an error is
        printed; callers are expected to check `self.model`.
        """
        if get_satclip is None:
            # The optional SatCLIP import failed at module import time.
            print("Error: SatCLIP functionality is not available.")
            return

        print(f"Loading SatCLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                return

            self.model = get_satclip(self.ckpt_path, self.device, return_all=True)
            self.model.eval()
            print(f"SatCLIP model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading SatCLIP model: {e}")

    def load_embeddings(self):
        """Load the parquet embedding table and build the search index.

        Best-effort: failures leave df_embed/image_embeddings as None.
        """
        print(f"Loading SatCLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return

            self.df_embed = pq.read_table(self.embedding_path).to_pandas()

            # Stack per-row vectors into an (N, D) matrix and L2-normalize
            # so that dot products in search() are cosine similarities.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SatCLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SatCLIP embeddings: {e}")

    def encode_location(self, lat, lon):
        """Encode a (latitude, longitude) pair into a normalized vector.

        Returns:
            A (1, D) L2-normalized float tensor, or None when the model is
            unavailable.
        """
        if self.model is None:
            return None

        # SatCLIP expects (lon, lat) ordering and double-precision coords.
        coords = torch.tensor([[lon, lat]], dtype=torch.double).to(self.device)

        with torch.no_grad():
            loc_features = self.model.encode_location(coords).float()
            loc_features = loc_features / loc_features.norm(dim=1, keepdim=True)

        return loc_features

    def encode_image(self, image):
        """Encode an RGB image with the SatCLIP visual encoder.

        The encoder expects 13-channel Sentinel-2 input; the RGB channels
        are mapped to bands B2 (blue), B3 (green), B4 (red) and the other
        bands are left zero.

        Args:
            image: a PIL image (converted/resized here); presumably other
                array-like inputs are expected to already be 224x224 RGB —
                TODO confirm with callers.

        Returns:
            A (1, D) L2-normalized tensor, or None when the model is
            unavailable or encoding fails.
        """
        if self.model is None:
            return None

        try:
            if isinstance(image, Image.Image):
                image = image.convert("RGB")
                image = image.resize((224, 224))
            img_np = np.array(image).astype(np.float32) / 255.0

            # Sentinel-2 band layout: channel 1 = B2 (blue),
            # channel 2 = B3 (green), channel 3 = B4 (red).
            input_tensor = np.zeros((13, 224, 224), dtype=np.float32)
            input_tensor[1] = img_np[:, :, 2]
            input_tensor[2] = img_np[:, :, 1]
            input_tensor[3] = img_np[:, :, 0]

            input_tensor = torch.from_numpy(input_tensor).unsqueeze(0).to(self.device)

            with torch.no_grad():
                img_feature = self.model.encode_image(input_tensor)
                img_feature = img_feature / img_feature.norm(dim=1, keepdim=True)

            return img_feature
        except Exception as e:
            print(f"Error encoding image in SatCLIP: {e}")
            return None
        # (fixed: removed unreachable trailing `return None` after try/except)

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Rank stored embeddings by cosine similarity to a query vector.

        Args:
            query_features: (1, D) query tensor (normalized by the encoders).
            top_k: number of best matches returned in top_indices.
            top_percent: if given, overrides `threshold` with the score of
                the top `top_percent` fraction of the database (at least
                one record).
            threshold: minimum similarity for inclusion in filtered_indices.

        Returns:
            Tuple (probs, filtered_indices, top_indices), or
            (None, None, None) when no embedding index is loaded.
            `probs` is the full similarity array; `filtered_indices` are the
            indices with probs >= threshold; `top_indices` are the top_k
            indices in descending order of similarity, independent of the
            threshold filter.
        """
        if self.image_embeddings is None:
            return None, None, None

        query_features = query_features.float()

        # Both sides are L2-normalized, so the dot product is cosine
        # similarity.
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()

        if top_percent is not None:
            # k-th largest score becomes the cutoff; clamp k to >= 1.
            k = max(1, int(len(probs) * top_percent))
            threshold = np.partition(probs, -k)[-k]

        mask = probs >= threshold
        filtered_indices = np.where(mask)[0]

        top_indices = np.argsort(probs)[-top_k:][::-1]

        return probs, filtered_indices, top_indices
|
|
|