ML4RS-Anonymous's picture
Upload all files
eb1aec4 verified
from open_clip import create_model_from_pretrained, get_tokenizer
import torch
class SigLIP_S2RGB_Embedder(torch.nn.Module):
"""
Embedding wrapper for SigLIP and Sentinel-2 data.
This model processes Sentinel-2 RGB data and embeds it into a feature space using the DINOv@ transformer model.
The preprocessing includes normalizing Sentinel-2 values to create a True-Colour image before passing it through
the model. The final output is a high-dimensional feature vector representing the input image.
Preprocessing:
- Sentinel-2 bands are divided by 10,000 to scale the reflectance values.
- Then, the values are multiplied by 2.5 to map them into the [0, 1] range for True-Colour images.
- The model input is further processed using the DINOv@ preprocessor.
Model:
- Takes an RGB input of shape 384x384 pixels and produces an embedding vector.
"""
def __init__(self):
super().__init__()
# load model
self.model, self.preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
# Sentinel-2 RGB bands (B04 - Red, B03 - Green, B02 - Blue)
self.bands = ['B04', 'B03', 'B02']
self.size = self.preprocess.transforms[0].size
def normalize(self, input):
"""
Normalizes Sentinel-2 image data to create a True-Colour image.
Sentinel-2 images are scaled to reflectance values in the range [0, 1]. This function:
- Divides the input by 10,000 to scale Sentinel-2 values.
- Multiplies the result by 2.5 to map the values into the True-Colour image range.
Args:
input (torch.Tensor or np.ndarray): Input image with Sentinel-2 reflectance values.
Returns:
torch.Tensor: Normalized True-Colour image, clipped to the range [0, 1].
"""
return (2.5 * (input / 1e4)).clip(0,1)
def forward(self, input):
"""
Forward pass through the SigLIP model.
This method normalizes the input Sentinel-2 image to a True-Colour representation and processes it through
the model to obtain an embedding.
Args:
input (torch.Tensor): A Sentinel-2 image, typically of shape (C, H, W), where C=3 (RGB),
H=384, and W=384.
Returns:
torch.Tensor: The image embedding produced by the model.
"""
preprocess_input = self.normalize(input)
# normalization only
model_input = self.preprocess.transforms[-1](preprocess_input)
return self.model.encode_image(model_input)