```python
import torch
from transformers import AutoImageProcessor, AutoModel


class DINOv2_S2RGB_Embedder(torch.nn.Module):
    """
    Embedding wrapper for DINOv2 and Sentinel-2 data.

    This model uses the DINOv2 architecture to generate embeddings for Sentinel-2 RGB data. The input data
    (RGB bands) is normalized and mapped to true-color values, then passed through the DINOv2 model to
    obtain feature embeddings.

    Preprocessing:
        The input Sentinel-2 image is divided by 10,000, multiplied by 2.5 and clipped to the range [0, 1]
        to produce a true-color image. This image is then resized, center-cropped and normalized by the
        DINOv2 image processor.

    Model:
        The DINOv2 model processes RGB input images of size 224 x 224 and produces one embedding per token
        (CLS token and patch tokens). These are averaged across the sequence dimension to obtain a
        fixed-size embedding vector.

    Model Components:
        - `AutoImageProcessor`: DINOv2 preprocessing pipeline applied to the true-color Sentinel-2 image.
        - `AutoModel`: DINOv2 transformer model used for feature extraction.

    Attributes:
        processor (AutoImageProcessor): The DINOv2 image processor used for preprocessing.
        model (AutoModel): The DINOv2 model used to generate embeddings from preprocessed images.
        bands (list): The Sentinel-2 bands used as RGB input (B04, B03, B02).
        size (tuple): The input size expected by the model (height, width).
    """

    def __init__(self):
        """
        Initializes the DINOv2_S2RGB_Embedder by loading the pre-trained DINOv2 model and processor
        and setting the expected input size for Sentinel-2 RGB data.

        This embedder uses the 'facebook/dinov2-base' model for feature extraction from Sentinel-2
        true-color (RGB) images.

        Attributes:
            processor (AutoImageProcessor): The DINOv2 image processor for preprocessing Sentinel-2 images.
            model (AutoModel): The pre-trained DINOv2 model for generating embeddings.
            bands (list): The Sentinel-2 bands used for RGB data (B04 - Red, B03 - Green, B02 - Blue).
            size (tuple): The expected input size of the image for the DINOv2 model (height, width).
        """
        super().__init__()

        # Load the DINOv2 processor and model from Hugging Face
        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.model = AutoModel.from_pretrained('facebook/dinov2-base')

        # Define the RGB bands for Sentinel-2 (B04, B03, B02)
        self.bands = ['B04', 'B03', 'B02']

        # Extract the input size (center-crop size) from the processor settings
        self.size = self.processor.crop_size['height'], self.processor.crop_size['width']

    def normalize(self, input):
        """
        Normalizes Sentinel-2 RGB data to true-color values.

        The input image (raw Sentinel-2 values) is first divided by 10,000 to convert it to reflectance
        values. The result is then multiplied by 2.5 to brighten it and clipped to [0, 1], giving
        true-color values suitable for input into the DINOv2 model.

        Args:
            input (torch.Tensor): The raw Sentinel-2 image tensor to be normalized.

        Returns:
            torch.Tensor: The normalized true-color image, with values in [0, 1].
        """
        return (2.5 * (input / 1e4)).clip(0, 1)

    def forward(self, input):
        """
        Forward pass through the model to generate embeddings for the input image.

        The input image is first normalized using the `normalize` method. It is then processed by the
        DINOv2 image processor and passed through the DINOv2 model. The resulting token embeddings
        (CLS token and patch tokens) are averaged across the sequence dimension to obtain a fixed-size
        embedding.

        Args:
            input (torch.Tensor): The raw Sentinel-2 image tensor with shape [C, H, W], where C=3
                (bands B04, B03, B02).

        Returns:
            torch.Tensor: The embedding averaged over the sequence dimension, with shape
            [1, embedding_dim] (embedding_dim = 768 for facebook/dinov2-base), on the CPU.
        """
        # The normalized image is already in [0, 1], so disable the processor's default 1/255 rescaling
        model_input = self.processor(self.normalize(input), do_rescale=False, return_tensors="pt")
        outputs = self.model(model_input['pixel_values'].to(self.model.device))
        last_hidden_states = outputs.last_hidden_state
        return last_hidden_states.mean(dim=1).cpu()
```
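You can check the arithmetic in `normalize` by applying the same formula to a few raw values. The sketch below uses hypothetical, made-up Sentinel-2 digital numbers for a tiny 3-band (B04, B03, B02) image of 1 x 4 pixels. It needs only `torch` and does not load the DINOv2 model.

```python
import torch

# Hypothetical raw Sentinel-2 values for a 3-band (B04, B03, B02) image of 1 x 4 pixels.
raw = torch.tensor([
    [[0.0, 1000.0, 4000.0, 12000.0]],   # B04 (red)
    [[200.0, 2000.0, 3000.0, 500.0]],   # B03 (green)
    [[100.0, 800.0, 6000.0, 10000.0]],  # B02 (blue)
])

# Same formula as DINOv2_S2RGB_Embedder.normalize:
# divide by 10,000 to get reflectance, brighten by 2.5, then clip to [0, 1].
true_color = (2.5 * (raw / 1e4)).clip(0, 1)

# A raw value of 4,000 (reflectance 0.4) maps exactly to 1.0; anything above it is clipped to 1.0.
print(true_color)
```

The factor 2.5 maps a reflectance of 0.4 to 1.0, so any band value of 4,000 or more saturates after clipping. This brightens the typically dark visible-band reflectances but loses detail in bright targets such as clouds or snow.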
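The pooling step at the end of `forward` can be illustrated with a stand-in tensor shaped like DINOv2's `last_hidden_state`. This sketch assumes the `facebook/dinov2-base` configuration: patch size 14, hidden size 768, and no register tokens. Under those assumptions, a 224 x 224 crop yields 16 x 16 = 256 patch tokens plus one CLS token. Random values stand in for real model outputs.

```python
import torch

# Assumed facebook/dinov2-base settings: 224 x 224 crop, 14 x 14 patches, hidden size 768.
crop, patch, hidden = 224, 14, 768
num_tokens = 1 + (crop // patch) ** 2  # one CLS token + 16 * 16 patch tokens = 257

# Stand-in for outputs.last_hidden_state for a single image: [batch, sequence, hidden].
last_hidden_states = torch.randn(1, num_tokens, hidden)

# Same pooling as forward(): average over the sequence dimension (CLS and patch tokens alike).
embedding = last_hidden_states.mean(dim=1).cpu()
print(embedding.shape)  # torch.Size([1, 768])
```

The image processor adds a batch dimension, and `forward` keeps it. As a result, a single [C, H, W] input yields an embedding of shape [1, 768] rather than [768]. Use `embedding[0]` or `embedding.squeeze(0)` if a flat vector is needed.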