File size: 6,082 Bytes
eb1aec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import sys
import os
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download

import warnings


# Import the SatCLIP loader while suppressing the FutureWarnings it emits on
# import. If the package is not installed, degrade gracefully by binding
# get_satclip = None so SatCLIPModel.load_model() can detect the missing
# dependency (it already checks `get_satclip is None`) instead of the whole
# module crashing with an ImportError.
try:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        from models.SatCLIP.satclip.load import get_satclip
    print("Successfully imported models.SatCLIP.satclip.load.get_satclip.")
except ImportError as e:
    get_satclip = None
    print(f"Warning: SatCLIP loader unavailable ({e}); SatCLIP features disabled.")

class SatCLIPModel:
    """Wrapper around the SatCLIP location/image encoders with optional
    similarity search over pre-computed image embeddings.

    Loads a SatCLIP checkpoint (from disk or the Hugging Face Hub) and, when
    an embedding parquet file is supplied, a table of pre-computed image
    embeddings used by :meth:`search`.
    """

    def __init__(self, 
                 ckpt_path='./checkpoints/SatCLIP/satclip-vit16-l40.ckpt',
                 embedding_path='./embedding_datasets/10percent_satclip_encoded/all_satclip_embeddings.parquet', # Path to pre-computed embeddings if available
                 device=None):
        """Initialize the model and (optionally) the embedding table.

        Args:
            ckpt_path: Local checkpoint path. If the string contains 'hf',
                the official checkpoint is downloaded from the Hub instead.
            embedding_path: Optional parquet file with an 'embedding' column
                of pre-computed image embeddings; falsy value skips loading.
            device: Torch device string; defaults to CUDA when available.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # Any path containing 'hf' is treated as a request to fetch the
        # published checkpoint from the Hugging Face Hub.
        if 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path
        
        self.model = None             # SatCLIP model exposing encode_location / encode_image
        self.df_embed = None          # pandas DataFrame backing the embedding table
        self.image_embeddings = None  # (N, D) L2-normalized float tensor on self.device
        
        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """Load the SatCLIP model onto ``self.device``.

        Best-effort: on missing dependency, missing checkpoint, or load
        failure, a message is printed and ``self.model`` stays ``None``.
        """
        if get_satclip is None:
            print("Error: SatCLIP functionality is not available.")
            return

        print(f"Loading SatCLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                return

            # return_all=True yields both the visual and location encoders.
            self.model = get_satclip(self.ckpt_path, self.device, return_all=True)
            self.model.eval()
            print(f"SatCLIP model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading SatCLIP model: {e}")

    def load_embeddings(self):
        """Load pre-computed image embeddings from the parquet file.

        Populates ``self.df_embed`` and ``self.image_embeddings`` (an
        L2-normalized (N, D) tensor on ``self.device``). Best-effort: on
        failure a message is printed and the attributes stay ``None``.
        """
        # Assumes the same parquet layout as the SigLIP pipeline:
        # an 'embedding' column holding one vector per row.
        print(f"Loading SatCLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return

            self.df_embed = pq.read_table(self.embedding_path).to_pandas()
            
            # Stack per-row vectors into one (N, D) tensor and normalize so
            # that dot products in search() are cosine similarities.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SatCLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SatCLIP embeddings: {e}")

    def encode_location(self, lat, lon):
        """Encode a (latitude, longitude) pair into a unit-norm vector.

        Args:
            lat: Latitude in degrees.
            lon: Longitude in degrees.

        Returns:
            A (1, D) L2-normalized float tensor, or ``None`` if the model
            is not loaded.
        """
        if self.model is None:
            return None
        
        # SatCLIP expects (N, 2) coordinates in (lon, lat) order; double
        # precision matches the reference notebook.
        coords = torch.tensor([[lon, lat]], dtype=torch.double).to(self.device)
        
        with torch.no_grad():
            # Normalize as in the reference: x / ||x||.
            loc_features = self.model.encode_location(coords).float()
            loc_features = loc_features / loc_features.norm(dim=1, keepdim=True)
            
        return loc_features

    def encode_image(self, image):
        """Encode an RGB PIL image with the SatCLIP visual encoder.

        The RGB channels are placed into the Sentinel-2 band slots of a
        13-channel input (B02=Blue, B03=Green, B04=Red); the remaining
        bands are zero-filled.

        Args:
            image: A ``PIL.Image.Image``. Any other type yields ``None``
                (quiet-failure contract, matching the other methods).

        Returns:
            A (1, D) L2-normalized float tensor, or ``None`` on failure.
        """
        if self.model is None:
            return None
        
        try:
            if not isinstance(image, Image.Image):
                # Unsupported input type: keep the quiet-failure contract.
                return None

            rgb = image.convert("RGB").resize((224, 224))
            img_np = np.array(rgb).astype(np.float32) / 255.0
            
            # S2 bands: B01, B02(B), B03(G), B04(R), B05..B12.
            # Channel indices: 0=B01, 1=B02, 2=B03, 3=B04, ...
            input_tensor = np.zeros((13, 224, 224), dtype=np.float32)
            input_tensor[1] = img_np[:, :, 2] # Blue
            input_tensor[2] = img_np[:, :, 1] # Green
            input_tensor[3] = img_np[:, :, 0] # Red
            
            batch = torch.from_numpy(input_tensor).unsqueeze(0).to(self.device)
            
            with torch.no_grad():
                img_feature = self.model.encode_image(batch)
                img_feature = img_feature / img_feature.norm(dim=1, keepdim=True)
                
            return img_feature
        except Exception as e:
            print(f"Error encoding image in SatCLIP: {e}")
            return None

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Rank the pre-computed image embeddings against a query vector.

        Args:
            query_features: (1, D) float tensor; embeddings are normalized,
                so the dot product is cosine similarity (query is assumed
                normalized by the encode_* methods).
            top_k: Number of best-matching indices to return.
            top_percent: If given (fraction in (0, 1]), overrides
                ``threshold`` with the similarity cutoff that keeps the top
                fraction of records.
            threshold: Minimum similarity for ``filtered_indices``.

        Returns:
            Tuple ``(probs, filtered_indices, top_indices)``, or
            ``(None, None, None)`` when no embeddings are loaded. Note that
            ``top_indices`` is taken from the full ranking and is NOT
            restricted by the threshold.
        """
        if self.image_embeddings is None:
            return None, None, None

        query_features = query_features.float()
        
        # Embeddings are L2-normalized, so dot product == cosine similarity.
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()
        
        if top_percent is not None:
            # Derive the threshold that keeps the top `top_percent` fraction
            # (at least one record).
            k = max(1, int(len(probs) * top_percent))
            threshold = np.partition(probs, -k)[-k]

        filtered_indices = np.where(probs >= threshold)[0]
        
        # Top-k over the full ranking, descending, independent of threshold.
        top_indices = np.argsort(probs)[-top_k:][::-1]
        
        return probs, filtered_indices, top_indices