import os

import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel


class DINOv2Model:
    """
    DINOv2 model wrapper for Sentinel-2 RGB data embedding and search.

    This class provides a unified interface for:
    - Loading DINOv2 models from local checkpoint or HuggingFace
    - Encoding images into embeddings
    - Loading pre-computed embeddings
    - Searching similar images using cosine similarity

    The model processes Sentinel-2 RGB data by normalizing it to true-color
    values and generating feature embeddings using the DINOv2 architecture.
    """

    def __init__(self,
                 ckpt_path="./checkpoints/DINOv2",
                 model_name="facebook/dinov2-large",
                 embedding_path="./embedding_datasets/10percent_dinov2_encoded/all_dinov2_embeddings.parquet",
                 device=None):
        """
        Initialize the DINOv2Model.

        Args:
            ckpt_path (str): Path to local checkpoint directory, 'hf' for
                HuggingFace, or a string starting with 'ms' for ModelScope.
            model_name (str): HuggingFace/ModelScope model name (used when
                ckpt_path is 'hf' or 'ms...').
            embedding_path (str): Path to pre-computed embeddings parquet
                file, or None to skip loading embeddings.
            device (str): Device to use ('cuda', 'cpu', or None for
                auto-detection).
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path
        self.model = None
        self.processor = None
        self.df_embed = None           # DataFrame of embedding records
        self.image_embeddings = None   # L2-normalized embedding matrix on self.device
        # Define the RGB bands for Sentinel-2 (B04, B03, B02)
        self.bands = ['B04', 'B03', 'B02']
        self.size = None               # (height, width) expected by the processor

        self.load_model()
        if self.embedding_path is not None:
            self.load_embeddings()

    def load_model(self):
        """Load DINOv2 model and processor from local checkpoint or HuggingFace."""
        print(f"Loading DINOv2 model from {self.ckpt_path}...")
        try:
            if self.ckpt_path == 'hf':
                # Load from HuggingFace
                print(f"Loading from HuggingFace: {self.model_name}")
                self.processor = AutoImageProcessor.from_pretrained(self.model_name)
                self.model = AutoModel.from_pretrained(self.model_name)
            elif self.ckpt_path.startswith('ms'):
                # Load from ModelScope (imported lazily so the dependency is optional)
                import modelscope
                self.processor = modelscope.AutoImageProcessor.from_pretrained(self.model_name)
                self.model = modelscope.AutoModel.from_pretrained(self.model_name)
            else:
                # Load from a local checkpoint directory
                self.processor = AutoImageProcessor.from_pretrained(self.ckpt_path)
                self.model = AutoModel.from_pretrained(self.ckpt_path)

            self.model = self.model.to(self.device)
            self.model.eval()

            # Extract the input size from the processor settings; processors
            # expose either `crop_size` or `size`, and `size` may be a dict
            # or a single int depending on the transformers version.
            if hasattr(self.processor, 'crop_size'):
                self.size = (self.processor.crop_size['height'],
                             self.processor.crop_size['width'])
            elif hasattr(self.processor, 'size'):
                if isinstance(self.processor.size, dict):
                    self.size = (self.processor.size.get('height', 224),
                                 self.processor.size.get('width', 224))
                else:
                    self.size = (self.processor.size, self.processor.size)
            else:
                self.size = (224, 224)

            print(f"DINOv2 model loaded on {self.device}, input size: {self.size}")
        except Exception as e:
            # Best-effort: leave model/processor as None so callers get a
            # clear "Model not loaded!" message instead of a crash here.
            print(f"Error loading DINOv2 model: {e}")

    def load_embeddings(self):
        """Load pre-computed embeddings from parquet file."""
        print(f"Loading DINOv2 embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return
            self.df_embed = pq.read_table(self.embedding_path).to_pandas()
            # Pre-compute the normalized embedding matrix once so search()
            # reduces to a single matrix multiply.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"DINOv2 Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading DINOv2 embeddings: {e}")

    def normalize_s2(self, input_data):
        """
        Normalize Sentinel-2 RGB data to true-color values.

        Converts raw Sentinel-2 reflectance values to normalized true-color
        values suitable for the DINOv2 model. NOTE: this method was
        previously commented out while encode_image() still called it,
        which made the is_sentinel2=True path raise AttributeError.

        Args:
            input_data (torch.Tensor or np.ndarray): Raw Sentinel-2 image data

        Returns:
            torch.Tensor or np.ndarray: Normalized true-color image in range [0, 1]
        """
        return (2.5 * (input_data / 1e4)).clip(0, 1)

    def encode_image(self, image, is_sentinel2=False):
        """
        Encode an image into a feature embedding.

        Args:
            image (PIL.Image, torch.Tensor, or np.ndarray): Input image
                - PIL.Image: RGB image
                - torch.Tensor: Image tensor with shape [C, H, W] (Sentinel-2)
                  or [H, W, C]
                - np.ndarray: Image array with shape [H, W, C]
            is_sentinel2 (bool): Whether to apply Sentinel-2 normalization

        Returns:
            torch.Tensor: L2-normalized embedding with shape
                [1, embedding_dim], or None on failure.
        """
        if self.model is None or self.processor is None:
            print("Model not loaded!")
            return None

        try:
            # Convert to PIL Image if needed
            if isinstance(image, torch.Tensor):
                if is_sentinel2:
                    # Sentinel-2 data: normalize reflectance to [0, 1] first
                    image = self.normalize_s2(image)
                # Convert [C, H, W] -> [H, W, C], then to uint8 for PIL
                if image.shape[0] == 3:  # [C, H, W]
                    image = image.permute(1, 2, 0)
                image_np = (image.cpu().numpy() * 255).astype(np.uint8)
                image = Image.fromarray(image_np, mode='RGB')
            elif isinstance(image, np.ndarray):
                if is_sentinel2:
                    image = self.normalize_s2(image)
                # Assume [H, W, C] format; rescale to uint8 only when the
                # array looks like it is in [0, 1].
                if image.max() <= 1.0:
                    image = (image * 255).astype(np.uint8)
                else:
                    image = image.astype(np.uint8)
                image = Image.fromarray(image, mode='RGB')
            elif isinstance(image, Image.Image):
                image = image.convert("RGB")
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            # Process image
            inputs = self.processor(images=image, return_tensors="pt")
            pixel_values = inputs['pixel_values'].to(self.device)

            # Generate embeddings. AMP is deliberately disabled even on CUDA
            # because the official pre-computed embeddings are float32.
            with torch.no_grad():
                outputs = self.model(pixel_values)

            # Get embeddings: average across sequence dimension
            last_hidden_states = outputs.last_hidden_state
            image_features = last_hidden_states.mean(dim=1)
            # Normalize so cosine similarity is a plain dot product
            image_features = F.normalize(image_features, dim=-1)
            return image_features
        except Exception as e:
            print(f"Error encoding image: {e}")
            import traceback
            traceback.print_exc()
            return None

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """
        Search for similar images using cosine similarity.

        Args:
            query_features (torch.Tensor): Query embedding vector
            top_k (int): Number of top results to return
            top_percent (float): If set, the threshold is replaced by the
                similarity of the top `top_percent` fraction of images
            threshold (float): Minimum similarity threshold

        Returns:
            tuple: (similarities, filtered_indices, top_indices)
                - similarities: Similarity scores for all images
                - filtered_indices: Indices of images above threshold
                - top_indices: Indices of top-k results (descending score)
        """
        if self.image_embeddings is None:
            print("Embeddings not loaded!")
            return None, None, None

        try:
            # Ensure query_features is float32 and on correct device
            query_features = query_features.float().to(self.device)
            # Normalize query features
            query_features = F.normalize(query_features, dim=-1)
            # Cosine similarity: both sides are unit-normalized, so a
            # matrix product gives cosine scores directly.
            similarity = (self.image_embeddings @ query_features.T).squeeze()
            similarities = similarity.detach().cpu().numpy()

            # Handle top_percent: derive the threshold from the k-th largest
            # score instead of using the caller-supplied threshold.
            if top_percent is not None:
                k = int(len(similarities) * top_percent)
                if k < 1:
                    k = 1
                threshold = np.partition(similarities, -k)[-k]

            # Filter by threshold
            mask = similarities >= threshold
            filtered_indices = np.where(mask)[0]

            # Get top k, highest score first
            top_indices = np.argsort(similarities)[-top_k:][::-1]
            return similarities, filtered_indices, top_indices
        except Exception as e:
            print(f"Error during search: {e}")
            return None, None, None


# Legacy class for backward compatibility
class DINOv2_S2RGB_Embedder(torch.nn.Module):
    """
    Legacy embedding wrapper for DINOv2 and Sentinel-2 data.

    This class is kept for backward compatibility with existing code.
    For new projects, please use DINOv2Model instead.
    """

    def __init__(self):
        """Initialize the legacy DINOv2_S2RGB_Embedder."""
        super().__init__()
        # Load the DINOv2 processor and model from Hugging Face
        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.model = AutoModel.from_pretrained('facebook/dinov2-base')
        # Define the RGB bands for Sentinel-2 (B04, B03, B02)
        self.bands = ['B04', 'B03', 'B02']
        # Extract the input size from the processor settings
        self.size = self.processor.crop_size['height'], self.processor.crop_size['width']

    def normalize(self, input):
        """
        Normalize Sentinel-2 RGB data to true-color values.

        Args:
            input (torch.Tensor): Raw Sentinel-2 image tensor

        Returns:
            torch.Tensor: Normalized true-color image
        """
        return (2.5 * (input / 1e4)).clip(0, 1)

    def forward(self, input):
        """
        Forward pass through the model to generate embeddings.

        Args:
            input (torch.Tensor): Input Sentinel-2 image tensor with shape [C, H, W]

        Returns:
            torch.Tensor: Embedding vector averaged across the sequence
                dimension, moved to CPU.
        """
        model_input = self.processor(self.normalize(input), return_tensors="pt")
        outputs = self.model(model_input['pixel_values'].to(self.model.device))
        last_hidden_states = outputs.last_hidden_state
        return last_hidden_states.mean(dim=1).cpu()