import torch
from transformers import AutoImageProcessor, AutoModel


class DINOv2_S2RGB_Embedder(torch.nn.Module):
    """
    Embedding wrapper for DINOv2 and Sentinel-2 data.

    This model uses the DINOv2 architecture to generate embeddings for Sentinel-2 RGB data. The input
    RGB bands are first scaled to true-color values and then passed through the DINOv2 model to obtain
    feature embeddings.

    Preprocessing:
        The input Sentinel-2 image is divided by 10,000, multiplied by 2.5, and clipped to the range
        [0, 1] to produce a true-color image. The result is then resized, center-cropped, and
        normalized by the DINOv2 image processor.

    Model:
        The DINOv2 model processes 224x224 RGB crops and produces one embedding per token (a [CLS]
        token plus one token per 14x14 patch). These token embeddings are averaged across the
        sequence dimension to obtain a fixed-size embedding vector.

    Model Components:
        - `AutoImageProcessor`: the DINOv2 image processor (resize, center crop, ImageNet mean/std
          normalization). It is a generic image pipeline with no Sentinel-2-specific logic.
        - `AutoModel`: DINOv2 transformer model used for feature extraction.

    Attributes:
        processor (AutoImageProcessor): The DINOv2 image processor used for preprocessing.
        model (AutoModel): The DINOv2 model used to generate embeddings from preprocessed images.
        bands (list): The Sentinel-2 bands used as RGB input (B04, B03, B02).
        size (tuple): The input size expected by the model (height, width), taken from the
            processor's crop size.
    """

    def __init__(self):
        """
        Initializes the DINOv2_S2RGB_Embedder by loading the pre-trained DINOv2 model and processor,
        and setting the expected input size for Sentinel-2 RGB data.

        This embedder uses the 'facebook/dinov2-base' model for feature extraction from Sentinel-2
        true-color (RGB) images.

        Attributes:
            processor (AutoImageProcessor): The DINOv2 image processor for preprocessing Sentinel-2 images.
            model (AutoModel): The pre-trained DINOv2 model for generating embeddings.
            bands (list): The Sentinel-2 bands used for RGB data (B04 - Red, B03 - Green, B02 - Blue).
            size (tuple): The expected input size of the image for the DINOv2 model (height, width).
        """
        super().__init__()

        # Load the DINOv2 processor and model from Hugging Face
        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.model = AutoModel.from_pretrained('facebook/dinov2-base')

        # Define the RGB bands for Sentinel-2 (B04, B03, B02)
        self.bands = ['B04', 'B03', 'B02']

        # Extract the input size from the processor settings
        self.size = self.processor.crop_size['height'], self.processor.crop_size['width']

    def normalize(self, input):
        """
        Normalizes Sentinel-2 RGB data to true-color values.

        The input image (raw Sentinel-2 digital numbers) is first divided by 10,000 to convert it to
        reflectance, which lies nominally in [0, 1]. The result is then multiplied by 2.5 to brighten
        it into a true-color range and clipped to [0, 1] before being passed to the DINOv2 model.

        Args:
            input (torch.Tensor): The raw Sentinel-2 image tensor to be normalized.

        Returns:
            torch.Tensor: The normalized true-color image, with values in [0, 1].
        """
        # Scale digital numbers to reflectance, brighten by 2.5, and clip to [0, 1]
        return (2.5 * (input / 1e4)).clip(0, 1)
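        # Worked example of the scaling above (illustrative raw values, not real data):
        #   raw 1200 -> 1200 / 10,000 * 2.5 = 0.30
        #   raw 3000 -> 3000 / 10,000 * 2.5 = 0.75
        #   raw 4000 -> 4000 / 10,000 * 2.5 = 1.00  (reflectance >= 0.4 saturates at 1.0)
        #   raw 6000 -> 6000 / 10,000 * 2.5 = 1.50 -> clipped to 1.00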

    def forward(self, input):
        """
        Forward pass through the model to generate embeddings for the input image.

        The input image is first normalized with the `normalize` method, then preprocessed by the
        DINOv2 image processor and passed through the DINOv2 model. The token embeddings in the model
        output (the [CLS] token and all patch tokens) are averaged across the sequence dimension to
        obtain a fixed-size embedding.

        Args:
            input (torch.Tensor): The input Sentinel-2 image tensor with shape [C, H, W], where C=3
                (RGB channels in B04, B03, B02 order).

        Returns:
            torch.Tensor: The embedding averaged over the sequence dimension, on the CPU, with shape
                [1, embedding_dim] (embedding_dim = 768 for 'facebook/dinov2-base').
        """
        # `normalize` already maps values to [0, 1], so the processor's default 1/255 rescaling
        # is disabled to avoid scaling the image a second time.
        model_input = self.processor(self.normalize(input), do_rescale=False, return_tensors="pt")
        outputs = self.model(model_input['pixel_values'].to(self.model.device))
        last_hidden_states = outputs.last_hidden_state  # [batch, tokens, hidden_dim]
        return last_hidden_states.mean(dim=1).cpu()
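

# ---------------------------------------------------------------------------
# Offline sketch (hypothetical helpers, not part of the original upload).
# It mirrors the arithmetic of `normalize` and the token averaging in `forward`
# on synthetic tensors, so it runs without downloading `facebook/dinov2-base`.
# Assumptions: 224x224 crops and DINOv2 patch size 14 (16 * 16 = 256 patch
# tokens plus one [CLS] token), with hidden size 768 as in dinov2-base.
# ---------------------------------------------------------------------------
import torch


def s2_to_true_color_sketch(raw: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for `normalize`: DN / 10,000 * 2.5, clipped to [0, 1]."""
    return (2.5 * (raw / 1e4)).clip(0, 1)


def mean_pool_tokens_sketch(last_hidden_state: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the pooling in `forward`: average over the token axis."""
    # [batch, tokens, hidden] -> [batch, hidden]
    return last_hidden_state.mean(dim=1)


if __name__ == "__main__":
    # One band, 2x2 pixels of synthetic raw digital numbers
    raw = torch.tensor([[[0.0, 1200.0], [4000.0, 6000.0]]])
    print(s2_to_true_color_sketch(raw))  # approximately 0.0, 0.3, 1.0, 1.0

    # Random stand-in for a DINOv2 last_hidden_state: [CLS] + 256 patch tokens
    fake_tokens = torch.randn(1, 1 + 16 * 16, 768)
    print(mean_pool_tokens_sketch(fake_tokens).shape)  # torch.Size([1, 768])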