|
|
import torch |
|
|
from transformers import AutoImageProcessor, AutoModel |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import pyarrow.parquet as pq |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
import os |
|
|
|
|
|
class DINOv2Model: |
|
|
""" |
|
|
DINOv2 model wrapper for Sentinel-2 RGB data embedding and search. |
|
|
|
|
|
This class provides a unified interface for: |
|
|
- Loading DINOv2 models from local checkpoint or HuggingFace |
|
|
- Encoding images into embeddings |
|
|
- Loading pre-computed embeddings |
|
|
- Searching similar images using cosine similarity |
|
|
|
|
|
The model processes Sentinel-2 RGB data by normalizing it to true-color values |
|
|
and generating feature embeddings using the DINOv2 architecture. |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
ckpt_path="./checkpoints/DINOv2", |
|
|
model_name="facebook/dinov2-large", |
|
|
embedding_path="./embedding_datasets/10percent_dinov2_encoded/all_dinov2_embeddings.parquet", |
|
|
device=None): |
|
|
""" |
|
|
Initialize the DINOv2Model. |
|
|
|
|
|
Args: |
|
|
ckpt_path (str): Path to local checkpoint directory or 'hf' for HuggingFace |
|
|
model_name (str): HuggingFace model name (used when ckpt_path='hf') |
|
|
embedding_path (str): Path to pre-computed embeddings parquet file |
|
|
device (str): Device to use ('cuda', 'cpu', or None for auto-detection) |
|
|
""" |
|
|
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") |
|
|
self.model_name = model_name |
|
|
self.ckpt_path = ckpt_path |
|
|
self.embedding_path = embedding_path |
|
|
|
|
|
self.model = None |
|
|
self.processor = None |
|
|
self.df_embed = None |
|
|
self.image_embeddings = None |
|
|
|
|
|
|
|
|
self.bands = ['B04', 'B03', 'B02'] |
|
|
self.size = None |
|
|
|
|
|
self.load_model() |
|
|
if self.embedding_path is not None: |
|
|
self.load_embeddings() |
|
|
|
|
|
def load_model(self): |
|
|
"""Load DINOv2 model and processor from local checkpoint or HuggingFace.""" |
|
|
print(f"Loading DINOv2 model from {self.ckpt_path}...") |
|
|
try: |
|
|
if self.ckpt_path == 'hf': |
|
|
|
|
|
print(f"Loading from HuggingFace: {self.model_name}") |
|
|
self.processor = AutoImageProcessor.from_pretrained(self.model_name) |
|
|
self.model = AutoModel.from_pretrained(self.model_name) |
|
|
elif self.ckpt_path.startswith('ms'): |
|
|
|
|
|
import modelscope |
|
|
self.processor = modelscope.AutoImageProcessor.from_pretrained(self.model_name) |
|
|
self.model = modelscope.AutoModel.from_pretrained(self.model_name) |
|
|
else: |
|
|
self.processor = AutoImageProcessor.from_pretrained(self.ckpt_path) |
|
|
self.model = AutoModel.from_pretrained(self.ckpt_path) |
|
|
|
|
|
self.model = self.model.to(self.device) |
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
if hasattr(self.processor, 'crop_size'): |
|
|
self.size = (self.processor.crop_size['height'], self.processor.crop_size['width']) |
|
|
elif hasattr(self.processor, 'size'): |
|
|
if isinstance(self.processor.size, dict): |
|
|
self.size = (self.processor.size.get('height', 224), self.processor.size.get('width', 224)) |
|
|
else: |
|
|
self.size = (self.processor.size, self.processor.size) |
|
|
else: |
|
|
self.size = (224, 224) |
|
|
|
|
|
print(f"DINOv2 model loaded on {self.device}, input size: {self.size}") |
|
|
except Exception as e: |
|
|
print(f"Error loading DINOv2 model: {e}") |
|
|
|
|
|
def load_embeddings(self): |
|
|
"""Load pre-computed embeddings from parquet file.""" |
|
|
print(f"Loading DINOv2 embeddings from {self.embedding_path}...") |
|
|
try: |
|
|
if not os.path.exists(self.embedding_path): |
|
|
print(f"Warning: Embedding file not found at {self.embedding_path}") |
|
|
return |
|
|
|
|
|
self.df_embed = pq.read_table(self.embedding_path).to_pandas() |
|
|
|
|
|
|
|
|
image_embeddings_np = np.stack(self.df_embed['embedding'].values) |
|
|
self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float() |
|
|
self.image_embeddings = F.normalize(self.image_embeddings, dim=-1) |
|
|
print(f"DINOv2 Data loaded: {len(self.df_embed)} records") |
|
|
except Exception as e: |
|
|
print(f"Error loading DINOv2 embeddings: {e}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode_image(self, image, is_sentinel2=False): |
|
|
""" |
|
|
Encode an image into a feature embedding. |
|
|
|
|
|
Args: |
|
|
image (PIL.Image, torch.Tensor, or np.ndarray): Input image |
|
|
- PIL.Image: RGB image |
|
|
- torch.Tensor: Image tensor with shape [C, H, W] (Sentinel-2) or [H, W, C] |
|
|
- np.ndarray: Image array with shape [H, W, C] |
|
|
is_sentinel2 (bool): Whether to apply Sentinel-2 normalization |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Normalized embedding vector with shape [embedding_dim] |
|
|
""" |
|
|
if self.model is None or self.processor is None: |
|
|
print("Model not loaded!") |
|
|
return None |
|
|
|
|
|
try: |
|
|
|
|
|
if isinstance(image, torch.Tensor): |
|
|
if is_sentinel2: |
|
|
|
|
|
image = self.normalize_s2(image) |
|
|
|
|
|
if image.shape[0] == 3: |
|
|
image = image.permute(1, 2, 0) |
|
|
image_np = (image.cpu().numpy() * 255).astype(np.uint8) |
|
|
image = Image.fromarray(image_np, mode='RGB') |
|
|
else: |
|
|
|
|
|
if image.shape[0] == 3: |
|
|
image = image.permute(1, 2, 0) |
|
|
image_np = (image.cpu().numpy() * 255).astype(np.uint8) |
|
|
image = Image.fromarray(image_np, mode='RGB') |
|
|
elif isinstance(image, np.ndarray): |
|
|
if is_sentinel2: |
|
|
image = self.normalize_s2(image) |
|
|
|
|
|
if image.max() <= 1.0: |
|
|
image = (image * 255).astype(np.uint8) |
|
|
else: |
|
|
image = image.astype(np.uint8) |
|
|
image = Image.fromarray(image, mode='RGB') |
|
|
elif isinstance(image, Image.Image): |
|
|
image = image.convert("RGB") |
|
|
else: |
|
|
raise ValueError(f"Unsupported image type: {type(image)}") |
|
|
|
|
|
|
|
|
inputs = self.processor(images=image, return_tensors="pt") |
|
|
pixel_values = inputs['pixel_values'].to(self.device) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
if self.device == "cuda": |
|
|
|
|
|
outputs = self.model(pixel_values) |
|
|
else: |
|
|
outputs = self.model(pixel_values) |
|
|
|
|
|
|
|
|
last_hidden_states = outputs.last_hidden_state |
|
|
image_features = last_hidden_states.mean(dim=1) |
|
|
|
|
|
|
|
|
image_features = F.normalize(image_features, dim=-1) |
|
|
|
|
|
return image_features |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error encoding image: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
return None |
|
|
|
|
|
def search(self, query_features, top_k=5, top_percent=None, threshold=0.0): |
|
|
""" |
|
|
Search for similar images using cosine similarity. |
|
|
|
|
|
Args: |
|
|
query_features (torch.Tensor): Query embedding vector |
|
|
top_k (int): Number of top results to return |
|
|
top_percent (float): If set, use top percentage instead of top_k |
|
|
threshold (float): Minimum similarity threshold |
|
|
|
|
|
Returns: |
|
|
tuple: (similarities, filtered_indices, top_indices) |
|
|
- similarities: Similarity scores for all images |
|
|
- filtered_indices: Indices of images above threshold |
|
|
- top_indices: Indices of top-k results |
|
|
""" |
|
|
if self.image_embeddings is None: |
|
|
print("Embeddings not loaded!") |
|
|
return None, None, None |
|
|
|
|
|
try: |
|
|
|
|
|
query_features = query_features.float().to(self.device) |
|
|
|
|
|
|
|
|
query_features = F.normalize(query_features, dim=-1) |
|
|
|
|
|
|
|
|
similarity = (self.image_embeddings @ query_features.T).squeeze() |
|
|
similarities = similarity.detach().cpu().numpy() |
|
|
|
|
|
|
|
|
if top_percent is not None: |
|
|
k = int(len(similarities) * top_percent) |
|
|
if k < 1: |
|
|
k = 1 |
|
|
threshold = np.partition(similarities, -k)[-k] |
|
|
|
|
|
|
|
|
mask = similarities >= threshold |
|
|
filtered_indices = np.where(mask)[0] |
|
|
|
|
|
|
|
|
top_indices = np.argsort(similarities)[-top_k:][::-1] |
|
|
|
|
|
return similarities, filtered_indices, top_indices |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error during search: {e}") |
|
|
return None, None, None |
|
|
|
|
|
|
|
|
|
|
|
class DINOv2_S2RGB_Embedder(torch.nn.Module): |
|
|
""" |
|
|
Legacy embedding wrapper for DINOv2 and Sentinel-2 data. |
|
|
|
|
|
This class is kept for backward compatibility with existing code. |
|
|
For new projects, please use DINOv2Model instead. |
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
"""Initialize the legacy DINOv2_S2RGB_Embedder.""" |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base') |
|
|
self.model = AutoModel.from_pretrained('facebook/dinov2-base') |
|
|
|
|
|
|
|
|
self.bands = ['B04', 'B03', 'B02'] |
|
|
|
|
|
|
|
|
self.size = self.processor.crop_size['height'], self.processor.crop_size['width'] |
|
|
|
|
|
def normalize(self, input): |
|
|
""" |
|
|
Normalize Sentinel-2 RGB data to true-color values. |
|
|
|
|
|
Args: |
|
|
input (torch.Tensor): Raw Sentinel-2 image tensor |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Normalized true-color image |
|
|
""" |
|
|
return (2.5 * (input / 1e4)).clip(0, 1) |
|
|
|
|
|
def forward(self, input): |
|
|
""" |
|
|
Forward pass through the model to generate embeddings. |
|
|
|
|
|
Args: |
|
|
input (torch.Tensor): Input Sentinel-2 image tensor with shape [C, H, W] |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Embedding vector with shape [embedding_dim] |
|
|
""" |
|
|
model_input = self.processor(self.normalize(input), return_tensors="pt") |
|
|
outputs = self.model(model_input['pixel_values'].to(self.model.device)) |
|
|
last_hidden_states = outputs.last_hidden_state |
|
|
return last_hidden_states.mean(dim=1).cpu() |
|
|
|