import torch
from transformers import AutoImageProcessor, AutoModel


class DINOv2_S2RGB_Embedder(torch.nn.Module):
| | """ |
| | Embedding wrapper for DINOv2 and Sentinel-2 data. |
| | |
| | This model uses the DINOv2 architecture to generate embeddings for Sentinel-2 RGB data. The input data (RGB bands) |
| | is preprocessed by normalizing and mapping it to true-color values. Then, it is passed through the DINOv2 model |
| | to obtain feature embeddings. |
| | |
| | Preprocessing: |
| | The input Sentinel-2 image is divided by 10,000 and multiplied by 2.5 to map it to a true-color image |
| | (normalized to the range [0, 1]), followed by processing using the DINOv2 image processor. |
| | |
| | Model: |
| | The DINOv2 model processes RGB input images of shape [224, 224] and produces embeddings, which are then |
| | averaged across the sequence dimension to obtain a fixed-size embedding vector. |
| | |
| | Model Components: |
| | - `AutoImageProcessor`: Preprocessing pipeline for handling Sentinel-2 data. |
| | - `AutoModel`: DINOv2 transformer model used for feature extraction. |
| | |
| | Attributes: |
| | processor (AutoImageProcessor): The DINOv2 image processor to handle preprocessing. |
| | model (AutoModel): The DINOv2 model used to generate embeddings from preprocessed images. |
| | bands (list): List of the Sentinel-2 bands used for RGB input (B04, B03, B02). |
| | size (tuple): The input size expected by the model (height, width) for the RGB image. |
| | """ |
| |
|
| | def __init__(self): |
| | """ |
| | Initializes the DINOv2_S2RGB_Embedder by loading the pre-trained DINOv2 model and processor, |
| | and setting the expected input size for Sentinel-2 RGB data. |
| | |
| | This embedder uses the 'facebook/dinov2-base' model for feature extraction from Sentinel-2 |
| | true-color images (RGB). |
| | |
| | Attributes: |
| | processor (AutoImageProcessor): The DINOv2 image processor for preprocessing Sentinel-2 images. |
| | model (AutoModel): The pre-trained DINOv2 model for generating embeddings. |
| | bands (list): The Sentinel-2 bands used for RGB data (B04 - Red, B03 - Green, B02 - Blue). |
| | size (tuple): The expected input size of the image for the DINOv2 model (height, width). |
| | """ |
| | super().__init__() |
| |
|
| | |
| | self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base') |
| | self.model = AutoModel.from_pretrained('facebook/dinov2-base') |
| |
|
| | |
| | self.bands = ['B04', 'B03', 'B02'] |
| |
|
| | |
| | self.size = self.processor.crop_size['height'], self.processor.crop_size['width'] |
| |
|
| |
|
    def normalize(self, input):
        """
        Normalizes Sentinel-2 RGB data to true-color values.

        The raw Sentinel-2 reflectance values are divided by 10,000 to convert them to surface
        reflectance in [0, 1], multiplied by 2.5 to brighten them to true-color values, and finally
        clipped to the range [0, 1]. For example, a raw value of 2,000 maps to
        2.5 * (2000 / 10000) = 0.5, while values of 4,000 and above clip to 1.0.

        Args:
            input (torch.Tensor): The raw Sentinel-2 image tensor to be normalized.

        Returns:
            torch.Tensor: The normalized true-color image, with values in [0, 1].
        """
        return (2.5 * (input / 1e4)).clip(0, 1)

    def forward(self, input):
        """
        Forward pass through the model to generate an embedding for the input image.

        The input image is first normalized using the `normalize` method, then processed by the
        DINOv2 image processor, and finally passed through the DINOv2 model. The output token
        sequence is averaged across the sequence dimension to obtain a fixed-size embedding.

        Args:
            input (torch.Tensor): The input Sentinel-2 image tensor with shape [C, H, W], where C=3 (RGB channels).

        Returns:
            torch.Tensor: The embedding, mean-pooled over the sequence dimension, with shape
                [1, embedding_dim] (the image processor adds a batch dimension).
        """
        model_input = self.processor(self.normalize(input), return_tensors="pt")
        outputs = self.model(model_input['pixel_values'].to(self.model.device))
        last_hidden_states = outputs.last_hidden_state
        # Mean-pool over the token (sequence) dimension and move the result to the CPU.
        return last_hidden_states.mean(dim=1).cpu()
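

# Illustrative usage (a minimal sketch, not part of the original module): the dummy tensor
# below stands in for a single Sentinel-2 RGB chip with raw reflectance values stacked in
# B04/B03/B02 order; its shape and value range are assumptions made for the demonstration.
# Running this downloads the 'facebook/dinov2-base' weights from the Hugging Face Hub.
if __name__ == "__main__":
    embedder = DINOv2_S2RGB_Embedder()
    embedder.eval()

    # Fake raw reflectance chip: 3 bands (RGB), 224 x 224 pixels, values in [0, 10000).
    dummy_chip = torch.randint(0, 10000, (3, 224, 224)).float()

    with torch.no_grad():
        embedding = embedder(dummy_chip)

    # One mean-pooled embedding per image; expected shape [1, 768] for dinov2-base.
    print(embedding.shape)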