# EarthEmbeddingExplorer/models/dinov2_model.py
# Uploaded by ML4RS-Anonymous ("Upload all files", commit eb1aec4, verified)
import torch
from transformers import AutoImageProcessor, AutoModel
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
from PIL import Image
import os
class DINOv2Model:
    """
    DINOv2 model wrapper for Sentinel-2 RGB data embedding and search.

    This class provides a unified interface for:
    - Loading DINOv2 models from a local checkpoint, HuggingFace, or ModelScope
    - Encoding images into L2-normalized embeddings
    - Loading pre-computed embeddings from a parquet file
    - Searching similar images using cosine similarity

    The model processes Sentinel-2 RGB data by normalizing it to true-color
    values and generating feature embeddings using the DINOv2 architecture.
    """

    def __init__(self,
                 ckpt_path="./checkpoints/DINOv2",
                 model_name="facebook/dinov2-large",
                 embedding_path="./embedding_datasets/10percent_dinov2_encoded/all_dinov2_embeddings.parquet",
                 device=None):
        """
        Initialize the DINOv2Model.

        Args:
            ckpt_path (str): Path to local checkpoint directory, 'hf' for
                HuggingFace, or a value starting with 'ms' for ModelScope.
            model_name (str): HuggingFace/ModelScope model name (used when
                ckpt_path is 'hf' or starts with 'ms').
            embedding_path (str): Path to pre-computed embeddings parquet file,
                or None to skip loading embeddings.
            device (str): Device to use ('cuda', 'cpu', or None for auto-detection).
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path
        self.model = None          # set by load_model()
        self.processor = None      # set by load_model()
        self.df_embed = None       # set by load_embeddings()
        self.image_embeddings = None  # [N, D] L2-normalized tensor, set by load_embeddings()
        # Define the RGB bands for Sentinel-2 (B04=red, B03=green, B02=blue)
        self.bands = ['B04', 'B03', 'B02']
        self.size = None           # (height, width) expected by the processor
        self.load_model()
        if self.embedding_path is not None:
            self.load_embeddings()

    def load_model(self):
        """Load DINOv2 model and processor from local checkpoint, HuggingFace, or ModelScope.

        Errors are reported via print and leave self.model/self.processor as None
        (best-effort loading; callers check for None).
        """
        print(f"Loading DINOv2 model from {self.ckpt_path}...")
        try:
            if self.ckpt_path == 'hf':
                # Load from HuggingFace
                print(f"Loading from HuggingFace: {self.model_name}")
                self.processor = AutoImageProcessor.from_pretrained(self.model_name)
                self.model = AutoModel.from_pretrained(self.model_name)
            elif self.ckpt_path.startswith('ms'):
                # Load from ModelScope (imported lazily so it is only required here)
                import modelscope
                self.processor = modelscope.AutoImageProcessor.from_pretrained(self.model_name)
                self.model = modelscope.AutoModel.from_pretrained(self.model_name)
            else:
                # Load from a local checkpoint directory
                self.processor = AutoImageProcessor.from_pretrained(self.ckpt_path)
                self.model = AutoModel.from_pretrained(self.ckpt_path)
            self.model = self.model.to(self.device)
            self.model.eval()
            # Extract the input size from the processor settings; processors
            # expose either crop_size or size, with varying formats.
            if hasattr(self.processor, 'crop_size'):
                self.size = (self.processor.crop_size['height'], self.processor.crop_size['width'])
            elif hasattr(self.processor, 'size'):
                if isinstance(self.processor.size, dict):
                    self.size = (self.processor.size.get('height', 224), self.processor.size.get('width', 224))
                else:
                    self.size = (self.processor.size, self.processor.size)
            else:
                self.size = (224, 224)  # DINOv2 default input resolution
            print(f"DINOv2 model loaded on {self.device}, input size: {self.size}")
        except Exception as e:
            print(f"Error loading DINOv2 model: {e}")

    def load_embeddings(self):
        """Load pre-computed embeddings from a parquet file into self.df_embed
        and self.image_embeddings (L2-normalized, float32, on self.device).

        Errors are reported via print and leave the attributes as None.
        """
        print(f"Loading DINOv2 embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return
            self.df_embed = pq.read_table(self.embedding_path).to_pandas()
            # Pre-compute the [N, D] image embeddings tensor once so that
            # search() is a single matrix multiply.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"DINOv2 Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading DINOv2 embeddings: {e}")

    def normalize_s2(self, input_data):
        """
        Normalize Sentinel-2 RGB data to true-color values.

        Converts raw Sentinel-2 reflectance values (0..10000 scale) to
        normalized true-color values suitable for the DINOv2 model.

        NOTE: restored — encode_image(is_sentinel2=True) calls this method,
        but it had been commented out, causing an AttributeError.

        Args:
            input_data (torch.Tensor or np.ndarray): Raw Sentinel-2 image data

        Returns:
            torch.Tensor or np.ndarray: Normalized true-color image in range [0, 1]
        """
        # 2.5x gain is the conventional Sentinel-2 true-color brightening.
        return (2.5 * (input_data / 1e4)).clip(0, 1)

    def encode_image(self, image, is_sentinel2=False):
        """
        Encode an image into a feature embedding.

        Args:
            image (PIL.Image, torch.Tensor, or np.ndarray): Input image
                - PIL.Image: RGB image
                - torch.Tensor: Image tensor with shape [C, H, W] or [H, W, C]
                - np.ndarray: Image array with shape [H, W, C]
            is_sentinel2 (bool): Whether to apply Sentinel-2 normalization

        Returns:
            torch.Tensor: L2-normalized embedding with shape [1, embedding_dim],
                or None on error.
        """
        if self.model is None or self.processor is None:
            print("Model not loaded!")
            return None
        try:
            # Convert to PIL Image if needed
            if isinstance(image, torch.Tensor):
                if is_sentinel2:
                    # Raw Sentinel-2 reflectance -> true-color [0, 1]
                    image = self.normalize_s2(image)
                # Accept [C, H, W] or [H, W, C]; PIL needs channels-last uint8.
                if image.shape[0] == 3:  # [C, H, W]
                    image = image.permute(1, 2, 0)
                image_np = (image.cpu().numpy() * 255).astype(np.uint8)
                image = Image.fromarray(image_np, mode='RGB')
            elif isinstance(image, np.ndarray):
                if is_sentinel2:
                    image = self.normalize_s2(image)
                # Assume [H, W, C] format; rescale floats in [0, 1] to uint8.
                if image.max() <= 1.0:
                    image = (image * 255).astype(np.uint8)
                else:
                    image = image.astype(np.uint8)
                image = Image.fromarray(image, mode='RGB')
            elif isinstance(image, Image.Image):
                image = image.convert("RGB")
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")
            # Process image
            inputs = self.processor(images=image, return_tensors="pt")
            pixel_values = inputs['pixel_values'].to(self.device)
            # Generate embeddings. AMP is intentionally NOT used even on CUDA:
            # the official pre-computed embeddings are float32, and mixing
            # precisions would skew cosine similarities.
            with torch.no_grad():
                outputs = self.model(pixel_values)
            # Mean-pool token features across the sequence dimension.
            last_hidden_states = outputs.last_hidden_state
            image_features = last_hidden_states.mean(dim=1)
            # Normalize so that dot products are cosine similarities.
            image_features = F.normalize(image_features, dim=-1)
            return image_features
        except Exception as e:
            print(f"Error encoding image: {e}")
            import traceback
            traceback.print_exc()
            return None

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """
        Search for similar images using cosine similarity.

        Args:
            query_features (torch.Tensor): Query embedding vector
            top_k (int): Number of top results to return
            top_percent (float): If set, use top percentage instead of top_k
            threshold (float): Minimum similarity threshold

        Returns:
            tuple: (similarities, filtered_indices, top_indices)
                - similarities: Similarity scores for all images
                - filtered_indices: Indices of images above threshold
                - top_indices: Indices of top-k results
            All three are None on error or if embeddings are not loaded.
        """
        if self.image_embeddings is None:
            print("Embeddings not loaded!")
            return None, None, None
        try:
            # Ensure query_features is float32 and on the correct device;
            # normalize so that dot products are cosine similarities.
            query_features = query_features.float().to(self.device)
            query_features = F.normalize(query_features, dim=-1)
            # Cosine similarity against all pre-computed embeddings.
            similarity = (self.image_embeddings @ query_features.T).squeeze()
            # atleast_1d: squeeze() yields a 0-d array when only one
            # embedding is loaded, which would break indexing below.
            similarities = np.atleast_1d(similarity.detach().cpu().numpy())
            # top_percent overrides threshold with the value of the k-th
            # best similarity so that the mask keeps the top fraction.
            if top_percent is not None:
                k = max(1, int(len(similarities) * top_percent))
                threshold = np.partition(similarities, -k)[-k]
            # Filter by threshold
            mask = similarities >= threshold
            filtered_indices = np.where(mask)[0]
            # Top-k indices, best first.
            top_indices = np.argsort(similarities)[-top_k:][::-1]
            return similarities, filtered_indices, top_indices
        except Exception as e:
            print(f"Error during search: {e}")
            return None, None, None
# Legacy class for backward compatibility
class DINOv2_S2RGB_Embedder(torch.nn.Module):
    """
    Legacy embedding wrapper around DINOv2 for Sentinel-2 RGB imagery.

    Kept for backward compatibility with existing code; new projects should
    prefer DINOv2Model instead.
    """

    def __init__(self):
        """Load the DINOv2-base processor/model and record the input size."""
        super().__init__()
        # Pull the pretrained backbone and its preprocessing pipeline from HF.
        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.model = AutoModel.from_pretrained('facebook/dinov2-base')
        # Sentinel-2 true-color band order: red, green, blue.
        self.bands = ['B04', 'B03', 'B02']
        # Input resolution expected by the processor as (height, width).
        self.size = self.processor.crop_size['height'], self.processor.crop_size['width']

    def normalize(self, input):
        """
        Scale raw Sentinel-2 reflectance into the [0, 1] true-color range.

        Args:
            input (torch.Tensor): Raw Sentinel-2 image tensor

        Returns:
            torch.Tensor: Normalized true-color image
        """
        reflectance = input / 1e4
        # 2.5x gain is the conventional Sentinel-2 true-color brightening.
        return (2.5 * reflectance).clip(0, 1)

    def forward(self, input):
        """
        Embed a Sentinel-2 image by mean-pooling DINOv2 token features.

        Args:
            input (torch.Tensor): Sentinel-2 image tensor with shape [C, H, W]

        Returns:
            torch.Tensor: Embedding tensor, mean-pooled over tokens, on CPU
        """
        batch = self.processor(self.normalize(input), return_tensors="pt")
        pixels = batch['pixel_values'].to(self.model.device)
        tokens = self.model(pixels).last_hidden_state
        return tokens.mean(dim=1).cpu()