Spaces:
Runtime error
Runtime error
import sys
import os
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download
import warnings

# Guard the project-local SatCLIP import: load_model() below checks
# `get_satclip is None`, so a missing/broken SatCLIP package must degrade
# gracefully instead of crashing this module at import time.
try:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        from models.SatCLIP.satclip.load import get_satclip
    print("Successfully imported models.SatCLIP.satclip.load.get_satclip.")
except ImportError as e:
    print(f"Warning: failed to import get_satclip ({e}); SatCLIP disabled.")
    get_satclip = None
class SatCLIPModel:
    """Wrapper around a SatCLIP checkpoint for encoding and similarity search.

    Provides:
      * (lat, lon) -> embedding via the SatCLIP location encoder,
      * RGB image -> embedding via the SatCLIP visual encoder (RGB channels
        are mapped into the 13-channel Sentinel-2 input layout),
      * cosine-similarity search over pre-computed image embeddings loaded
        from a parquet file.
    """

    def __init__(self,
                 ckpt_path='./checkpoints/SatCLIP/satclip-vit16-l40.ckpt',
                 embedding_path='./embedding_datasets/10percent_satclip_encoded/all_satclip_embeddings.parquet',  # Path to pre-computed embeddings if available
                 device=None):
        """Initialize the model and, when a path is given, the embedding table.

        Args:
            ckpt_path: Local checkpoint path. If the string contains 'hf',
                the checkpoint is downloaded from the Hugging Face Hub instead.
            embedding_path: Parquet file with an 'embedding' column of
                pre-computed image embeddings; falsy value skips loading.
            device: Torch device string; defaults to CUDA when available.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # 'hf' in the path signals a Hub download rather than a local file.
        if 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path
        self.model = None             # SatCLIP model (visual + location encoders), set by load_model()
        self.df_embed = None          # DataFrame backing the embedding table, set by load_embeddings()
        self.image_embeddings = None  # (N, D) unit-normalized tensor on self.device
        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """Load the SatCLIP checkpoint onto self.device (best-effort; logs errors)."""
        if get_satclip is None:
            print("Error: SatCLIP functionality is not available.")
            return
        print(f"Loading SatCLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                return
            # Load model using get_satclip
            # return_all=True to get both visual and location encoders
            self.model = get_satclip(self.ckpt_path, self.device, return_all=True)
            self.model.eval()
            print(f"SatCLIP model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading SatCLIP model: {e}")

    def load_embeddings(self):
        """Load pre-computed image embeddings from parquet and unit-normalize them."""
        # Assuming embeddings are stored similarly to SigLIP
        print(f"Loading SatCLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return
            self.df_embed = pq.read_table(self.embedding_path).to_pandas()
            # Pre-compute image embeddings tensor
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            # Normalize so the dot products in search() are cosine similarities.
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SatCLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SatCLIP embeddings: {e}")

    def encode_location(self, lat, lon):
        """
        Encode a (latitude, longitude) pair into a unit-norm vector.

        Returns a (1, D) float tensor, or None if the model is not loaded.
        """
        if self.model is None:
            return None
        # SatCLIP expects input shape (N, 2) -> (lon, lat)
        # Note: SatCLIP usually uses (lon, lat) order.
        # Use double precision as per notebook reference
        coords = torch.tensor([[lon, lat]], dtype=torch.double).to(self.device)
        with torch.no_grad():
            # Use model.encode_location instead of model.location_encoder
            # And normalize as per notebook: x / x.norm()
            loc_features = self.model.encode_location(coords).float()
            loc_features = loc_features / loc_features.norm(dim=1, keepdim=True)
        return loc_features

    def encode_image(self, image):
        """
        Encode an RGB image into a unit-norm vector using the SatCLIP visual encoder.

        Adapts RGB (3 channels) to the SatCLIP 13-channel Sentinel-2 input.
        Returns a (1, D) tensor, or None when the model is not loaded or
        encoding fails (the error is logged, not raised).
        """
        if self.model is None:
            return None
        try:
            # Handle PIL Image (RGB). resize((224, 224)) is the PIL signature,
            # so it must stay inside the PIL branch (ndarray.resize mutates
            # in place and returns None).
            if isinstance(image, Image.Image):
                image = image.convert("RGB")
                image = image.resize((224, 224))
            img_np = np.array(image).astype(np.float32) / 255.0
            # Construct 13 channels
            # S2 bands: B01, B02(B), B03(G), B04(R), B05, B06, B07, B08, B8A, B09, B10, B11, B12
            # Indices: 0=B01, 1=B02, 2=B03, 3=B04 ...
            input_tensor = np.zeros((13, 224, 224), dtype=np.float32)
            input_tensor[1] = img_np[:, :, 2]  # Blue
            input_tensor[2] = img_np[:, :, 1]  # Green
            input_tensor[3] = img_np[:, :, 0]  # Red
            input_tensor = torch.from_numpy(input_tensor).unsqueeze(0).to(self.device)
            with torch.no_grad():
                img_feature = self.model.encode_image(input_tensor)
                img_feature = img_feature / img_feature.norm(dim=1, keepdim=True)
            return img_feature
        except Exception as e:
            print(f"Error encoding image in SatCLIP: {e}")
            return None
        # Fixed: removed an unreachable trailing `return None` (both the try
        # and except paths above already return).

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """
        Rank the pre-computed image embeddings by cosine similarity.

        Args:
            query_features: query embedding, assumed unit-norm — presumably
                shape (1, D) as produced by encode_location/encode_image.
            top_k: number of best matches returned in `top_indices`.
            top_percent: if given, overrides `threshold` with the k-th
                largest similarity so roughly this fraction passes the filter.
            threshold: minimum similarity for inclusion in `filtered_indices`.

        Returns:
            (probs, filtered_indices, top_indices), or (None, None, None)
            when no embeddings are loaded.
        """
        if self.image_embeddings is None:
            return None, None, None
        query_features = query_features.float()
        # Similarity calculation (Cosine similarity)
        # SatCLIP embeddings are normalized, so dot product is cosine similarity
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()
        if top_percent is not None:
            # k-th largest similarity becomes the effective threshold.
            k = max(1, int(len(probs) * top_percent))
            threshold = np.partition(probs, -k)[-k]
        # Filter by threshold
        mask = probs >= threshold
        filtered_indices = np.where(mask)[0]
        # Get top k
        top_indices = np.argsort(probs)[-top_k:][::-1]
        return probs, filtered_indices, top_indices