import sys
import os
import warnings

import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image

# Import the SatCLIP loader. Bug fix: SatCLIPModel.load_model() checks
# `get_satclip is None`, but the original unguarded import raised on failure
# and could never produce None — so a missing/broken SatCLIP package crashed
# the whole module instead of degrading gracefully. Guard it here.
try:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        from models.SatCLIP.satclip.load import get_satclip
    print("Successfully imported models.SatCLIP.satclip.load.get_satclip.")
except ImportError as e:
    get_satclip = None
    print(f"Warning: SatCLIP loader unavailable ({e}); model loading disabled.")
class SatCLIPModel:
    """Wrapper around a SatCLIP checkpoint.

    Provides location encoding (lat/lon -> feature vector), image encoding
    (RGB -> feature vector via the 13-band Sentinel-2 visual encoder), and
    cosine-similarity search over pre-computed image embeddings loaded from
    a parquet file.
    """

    def __init__(self,
                 ckpt_path='./checkpoints/SatCLIP/satclip-vit16-l40.ckpt',
                 embedding_path='./embedding_datasets/10percent_satclip_encoded/all_satclip_embeddings.parquet', # Path to pre-computed embeddings if available
                 device=None):
        # Prefer an explicit device; otherwise pick CUDA when available.
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): substring test also matches any local path containing
        # "hf" (e.g. ".../hf_cache/model.ckpt"); kept for backward compatibility.
        if 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path
        self.model = None             # SatCLIP model (visual + location encoders), or None
        self.df_embed = None          # DataFrame backing the embedding store, or None
        self.image_embeddings = None  # (N, D) L2-normalized float tensor on self.device, or None
        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """Load the SatCLIP checkpoint onto self.device.

        On any failure (loader unavailable, missing file, load error) prints a
        message and leaves self.model as None; callers check for None.
        """
        if get_satclip is None:
            print("Error: SatCLIP functionality is not available.")
            return
        print(f"Loading SatCLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                return
            # return_all=True to get both visual and location encoders.
            self.model = get_satclip(self.ckpt_path, self.device, return_all=True)
            self.model.eval()
            print(f"SatCLIP model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading SatCLIP model: {e}")

    def load_embeddings(self):
        """Load pre-computed image embeddings from the parquet file.

        Caches both the raw DataFrame (self.df_embed) and an (N, D)
        L2-normalized tensor (self.image_embeddings) so search() is a single
        matmul. On failure, prints a message and leaves both as None.
        """
        print(f"Loading SatCLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return
            self.df_embed = pq.read_table(self.embedding_path).to_pandas()
            # Stack the per-row 'embedding' arrays into one (N, D) matrix.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SatCLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SatCLIP embeddings: {e}")

    def encode_location(self, lat, lon):
        """
        Encode a (latitude, longitude) pair into an L2-normalized vector.

        Returns a (1, D) float tensor, or None if the model is unavailable.
        """
        if self.model is None:
            return None
        # SatCLIP expects shape (N, 2) in (lon, lat) order, double precision
        # (per the reference notebook).
        coords = torch.tensor([[lon, lat]], dtype=torch.double).to(self.device)
        with torch.no_grad():
            # Normalize as per notebook: x / x.norm()
            loc_features = self.model.encode_location(coords).float()
            loc_features = loc_features / loc_features.norm(dim=1, keepdim=True)
        return loc_features

    def encode_image(self, image):
        """
        Encode an RGB image into an L2-normalized vector via the visual encoder.

        SatCLIP's visual encoder expects 13-channel Sentinel-2 input; the RGB
        channels are mapped into bands B02/B03/B04 and the remaining bands are
        left zero. Returns a (1, D) tensor, or None on failure.
        """
        if self.model is None:
            return None
        # Bug fix: the original fell through to an accidental NameError for
        # non-PIL input (swallowed by the broad except); reject it explicitly.
        if not isinstance(image, Image.Image):
            print("Error encoding image in SatCLIP: expected a PIL.Image input")
            return None
        try:
            image = image.convert("RGB")
            image = image.resize((224, 224))
            img_np = np.array(image).astype(np.float32) / 255.0
            # S2 bands: B01, B02(B), B03(G), B04(R), B05, B06, B07, B08, B8A,
            # B09, B10, B11, B12 -> channel indices 0..12.
            input_tensor = np.zeros((13, 224, 224), dtype=np.float32)
            input_tensor[1] = img_np[:, :, 2]  # Blue  -> B02
            input_tensor[2] = img_np[:, :, 1]  # Green -> B03
            input_tensor[3] = img_np[:, :, 0]  # Red   -> B04
            input_tensor = torch.from_numpy(input_tensor).unsqueeze(0).to(self.device)
            with torch.no_grad():
                img_feature = self.model.encode_image(input_tensor)
                img_feature = img_feature / img_feature.norm(dim=1, keepdim=True)
            return img_feature
        except Exception as e:
            print(f"Error encoding image in SatCLIP: {e}")
            return None
        # Bug fix: removed unreachable trailing `return None` (every path
        # above already returns).

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """
        Rank the cached image embeddings against a query feature vector.

        Args:
            query_features: (1, D) tensor (same D as the cached embeddings).
            top_k: number of best indices to return.
            top_percent: if given, overrides `threshold` so the top fraction
                of records passes the filter.
            threshold: minimum similarity for the filtered index set.

        Returns:
            (probs, filtered_indices, top_indices) as numpy arrays, or
            (None, None, None) when no embeddings are loaded.
        """
        if self.image_embeddings is None:
            return None, None, None
        query_features = query_features.float()
        # Embeddings are L2-normalized, so the dot product IS cosine similarity.
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()
        if top_percent is not None:
            # Derive the threshold so the top `top_percent` fraction passes.
            k = max(1, int(len(probs) * top_percent))
            threshold = np.partition(probs, -k)[-k]
        mask = probs >= threshold
        filtered_indices = np.where(mask)[0]
        top_indices = np.argsort(probs)[-top_k:][::-1]
        return probs, filtered_indices, top_indices