Spaces:
Running
Running
Commit ·
29fab93
0
Parent(s):
sync from hf
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +54 -0
- .gitignore +10 -0
- MajorTOM/MajorTOMDataset.py +64 -0
- MajorTOM/__init__.py +5 -0
- MajorTOM/embedder/MajorTOM_Embedder.py +191 -0
- MajorTOM/embedder/__init__.py +2 -0
- MajorTOM/embedder/grid_cell_fragment.py +164 -0
- MajorTOM/embedder/models/DINOv2_S2RGB.py +91 -0
- MajorTOM/embedder/models/SSL4EO_S1RTC.py +125 -0
- MajorTOM/embedder/models/SSL4EO_S2L1C.py +97 -0
- MajorTOM/embedder/models/SigLIP_S2RGB.py +65 -0
- MajorTOM/embedder/models/__init__.py +4 -0
- MajorTOM/extras/coverage-example.png +3 -0
- MajorTOM/extras/coverage_vis.py +149 -0
- MajorTOM/extras/extract-sample-from-raw-S2.ipynb +0 -0
- MajorTOM/extras/thumbnail_dem.py +77 -0
- MajorTOM/extras/thumbnail_s1rtc.py +80 -0
- MajorTOM/extras/thumbnail_s2.py +68 -0
- MajorTOM/grid.py +284 -0
- MajorTOM/metadata_helpers.py +159 -0
- MajorTOM/sample_helpers.py +20 -0
- README.md +28 -0
- Tutorial.md +162 -0
- Tutorial_zh.md +157 -0
- app.py +792 -0
- configs/huggingface.yaml +12 -0
- countries.geo.json +0 -0
- data_utils.py +223 -0
- embedding_datasets/grid_sample_center_22k_FarSLIP_384x384.parquet +3 -0
- embedding_datasets/grid_sample_center_22k_SatCLIP_384x384.parquet +3 -0
- embedding_datasets/grid_sample_center_22k_SigLIP_384x384.parquet +3 -0
- embedding_datasets/grid_sample_metadata.parquet +3 -0
- embedding_datasets/zhejiang_sample_center_2k_FarSLIP_384x384.parquet +3 -0
- embedding_datasets/zhejiang_sample_center_2k_SatCLIP_384x384.parquet +3 -0
- embedding_datasets/zhejiang_sample_center_2k_SigLIP_384x384.parquet +3 -0
- embedding_datasets/zhejiang_sample_metadata.parquet +3 -0
- examples/example1.png +3 -0
- examples/example2.png +3 -0
- examples/example3.png +3 -0
- images/CLIP.png +3 -0
- images/Image_Search_Amazon.jpg +3 -0
- images/Image_Search_Middle_East.jpg +3 -0
- images/Location_Search_Amazon.jpg +3 -0
- images/Location_Search_Hangzhou.jpg +3 -0
- images/Text_Search.jpg +3 -0
- images/embedding.png +3 -0
- images/framework_en.png +3 -0
- images/framework_zh.png +3 -0
- images/samples.png +3 -0
- models/FarSLIP/.gitignore +160 -0
.gitattributes
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.db* filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.ark* filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.gguf* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.ggml filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.llamafile* filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.pt2 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
center_bbx_22k.parquet filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
embedding_datasets/center_bbx_22k_SigLIP_384x384.parquet filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
embedding_datasets/center_bbx_22k_FarSLIP_384x384.parquet filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*$py.class
|
| 4 |
+
.gradio/
|
| 5 |
+
.vscode/
|
| 6 |
+
.DS_Store
|
| 7 |
+
checkpoints/
|
| 8 |
+
models/FarSLIP/assets
|
| 9 |
+
models/SatCLIP/figures
|
| 10 |
+
configs/local.yaml
|
MajorTOM/MajorTOMDataset.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import rasterio as rio
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torchvision.transforms as transforms
|
| 9 |
+
|
| 10 |
+
class MajorTOM(Dataset):
    """MajorTOM Dataset (https://huggingface.co/Major-TOM)

    Reads per-band .tif and .png files from a local copy of the dataset laid
    out as <local_dir>/<row>/<grid_cell>/<product_id>/<band>.<ext>.

    Args:
        df ((geo)pandas.DataFrame): Metadata dataframe
        local_dir (string or Path): Root directory of the local dataset version
        tif_bands (list or str): tif file names (without extension) to be read
        png_bands (list or str): png file names (without extension) to be read
        tif_transforms (list): Transforms composed and applied to each tif band
        png_transforms (list): Transforms composed and applied to each png band
    """

    def __init__(self,
                 df,
                 local_dir=None,
                 tif_bands=['B04', 'B03', 'B02'],
                 png_bands=['thumbnail'],
                 tif_transforms=[transforms.ToTensor()],
                 png_transforms=[transforms.ToTensor()]
                 ):
        super().__init__()
        self.df = df
        # accept both str and Path for convenience
        self.local_dir = Path(local_dir) if isinstance(local_dir, str) else local_dir
        # allow a single band name to be passed as a plain string
        self.tif_bands = tif_bands if not isinstance(tif_bands, str) else [tif_bands]
        self.png_bands = png_bands if not isinstance(png_bands, str) else [png_bands]
        self.tif_transforms = transforms.Compose(tif_transforms) if tif_transforms is not None else None
        self.png_transforms = transforms.Compose(png_transforms) if png_transforms is not None else None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with the row metadata under 'meta' plus one entry per band."""
        meta = self.df.iloc[idx]

        product_id = meta.product_id
        grid_cell = meta.grid_cell
        # the top-level directory is the row component of the grid cell id
        row = grid_cell.split('_')[0]

        path = self.local_dir / Path("{}/{}/{}".format(row, grid_cell, product_id))
        out_dict = {'meta': meta}

        for band in self.tif_bands:
            with rio.open(path / '{}.tif'.format(band)) as f:
                out = f.read()
            if self.tif_transforms is not None:
                out = self.tif_transforms(out)
            out_dict[band] = out

        for band in self.png_bands:
            # use a context manager so the underlying file handle is closed;
            # PIL opens images lazily and would otherwise leak the descriptor
            with Image.open(path / '{}.png'.format(band)) as im:
                out = self.png_transforms(im) if self.png_transforms is not None else im.copy()
            out_dict[band] = out

        return out_dict
|
MajorTOM/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .sample_helpers import *
|
| 2 |
+
from .metadata_helpers import *
|
| 3 |
+
from .MajorTOMDataset import *
|
| 4 |
+
from .grid import *
|
| 5 |
+
from .embedder import *
|
MajorTOM/embedder/MajorTOM_Embedder.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import geopandas as gpd
|
| 3 |
+
import hashlib
|
| 4 |
+
from rasterio.io import MemoryFile
|
| 5 |
+
|
| 6 |
+
from .grid_cell_fragment import *
|
| 7 |
+
from .models import *
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
class MajorTOM_Embedder(torch.nn.Module):
    """
    MajorTOM Embedder class that applies a model to geospatial image fragments,
    computes embeddings, and returns metadata for each fragment.

    This class is designed to work with raster data, where the image is fragmented
    into smaller tiles, and embeddings are computed for each tile using the provided
    embedder model. The output is a GeoDataFrame containing spatial metadata and
    the corresponding embeddings for each tile.

    Attributes:
        embedder: A model that generates embeddings for image fragments.
        frag_params: Dictionary containing fragmentation parameters such as the
                     target overlap and border shift.
        column_types: Dictionary specifying data types for the output GeoDataFrame columns.
    """

    def __init__(self, embedder, target_overlap=0.1, border_shift=True):
        """
        Initializes the MajorTOM Embedder with the given parameters.

        Args:
            embedder (torch.nn.Module): A model that generates embeddings for image fragments.
            target_overlap (float): The target overlap between image fragments. Default is 0.1.
            border_shift (bool): Whether to shift the borders of fragments to avoid edge artifacts. Default is True.
        """
        super().__init__()

        # Model
        self.embedder = embedder

        # Fragmentation Settings
        self.frag_params = {
            'fragment_size' : self.embedder.size[0],
            'target_overlap' : target_overlap,
            'border_shift' : border_shift
        }

        # Data types for the output dataframe (commented columns need no conversion)
        self.column_types = {
            #'unique_id' :,
            #'embedding' : ,
            #'timestamp' : ,
            #'product_id' : ,
            #'grid_cell' : ,
            'grid_row_u' : 'int16',
            'grid_col_r' : 'int16',
            'centre_lat' : 'float32',
            'centre_lon' : 'float32',
            #'utm_footprint' : ,
            #'utm_crs' : ,
            #'pixel_bbox' : ,
        }

    def bands(self):
        """
        Returns the set of input bands in the correct order.

        Returns:
            list: List of input bands used by the embedder.
        """
        return self.embedder.bands

    def size(self):
        """
        Returns the input image size.

        Returns:
            tuple: Tuple representing the image size (height, width).
        """
        return self.embedder.size

    def calculate_checksum(self, geometry, timestamp, product_id, embedding):
        """
        Calculates a checksum for the given geometry, timestamp, product ID, and embedding.

        Args:
            geometry (shapely.geometry): The geometry object representing the fragment's footprint.
            timestamp (str): Timestamp of the data.
            product_id (str): Product identifier.
            embedding (np.ndarray): The embedding of the image fragment.

        Returns:
            str: A SHA256 checksum of the concatenated input parameters.
        """
        # NOTE(review): str() of a large numpy array is truncated ("..."), so the
        # checksum only weakly depends on the embedding values — confirm intended
        combined = f"{geometry}_{timestamp}_{product_id}_{embedding}"
        checksum = hashlib.sha256(combined.encode()).hexdigest()
        return checksum

    def _read_image(self, row):
        """
        Reads and processes the image bands for a given row, performs optional upsampling
        if the resolution is mismatched, and returns the image data, footprint, and CRS.

        Args:
            row (pandas.Series): The input row containing the image bands.

        Returns:
            torch.Tensor: A tensor containing the stacked image bands (H, W, C).
            shapely.geometry: The footprint of the image.
            rasterio.crs.CRS: The CRS of the image.
        """

        # Read the file
        img = []
        for band in self.embedder.bands:
            with MemoryFile(row[band][0].as_py()) as mem_f:
                with mem_f.open(driver='GTiff') as f:
                    crs = f.crs
                    footprint = box(*f.bounds)
                    img.append(f.read()[0])

        # optional upsampling to the largest band resolution
        shapes = [layer.shape for layer in img]
        if any([el != shapes[0] for el in shapes]):  # if any resolution mismatch
            h, w = max([el[0] for el in shapes]), max([el[1] for el in shapes])  # maximum size
            for layer_idx, layer in enumerate(img):
                if layer.shape != (h, w):
                    # cv2.resize expects dsize as (width, height); passing (h, w)
                    # would transpose the target size for non-square rasters
                    img[layer_idx] = cv2.resize(layer, (w, h), interpolation=cv2.INTER_NEAREST)
        img = torch.from_numpy(np.stack(img, -1).astype(np.float32))

        return img, footprint, crs

    def forward(self, row, row_meta, device='cuda'):
        """
        Forward pass of the model: Reads the image, fragments it, computes embeddings
        for each fragment, and returns a GeoDataFrame with the spatial metadata and
        embeddings.

        Args:
            row (pandas.Series): The input row containing the image data.
            row_meta (pandas.Series): Metadata associated with the row (e.g., timestamp, product_id).
            device (str): The device to run the model on ('cpu' or 'cuda'). Default is 'cuda'.

        Returns:
            geopandas.GeoDataFrame: A GeoDataFrame containing metadata and embeddings for each fragment.
        """
        # Read file
        img, footprint, crs = self._read_image(row)

        # Fragment the sample
        fragments, xys = fragment_fn(img, **self.frag_params, return_indices=True, verbose=False)

        nrows, ncols, c, h, w = fragments.shape
        # Apply the model
        with torch.no_grad():
            embeddings = self.embedder(fragments.reshape(-1, c, h, w).to(device)).view(nrows, ncols, -1)

        df_rows = []

        # Pack rows for geoparquet
        for r_idx in range(nrows):
            for c_idx in range(ncols):
                embedding = embeddings[r_idx, c_idx].cpu().numpy()
                # spatial features per fragment
                x_offset, y_offset = xys[r_idx, c_idx].int().tolist()
                pixel_bbox = [x_offset, y_offset, x_offset + h, y_offset + w]  # in pixels
                utm_footprint = crop_footprint(footprint, *img.shape[:2], pixel_bbox)
                # main footprint is in WGS84 (needs to be consistent across parquet)
                transformer = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True)
                geometry = transform(transformer.transform, utm_footprint)  # WGS84
                centre_lon, centre_lat = geometry.centroid.coords[0]

                row_dict = {
                    'unique_id' : self.calculate_checksum(geometry, row_meta.timestamp.item(), row_meta.product_id.item(), embedding),
                    'embedding' : embedding,
                    'timestamp' : row_meta.timestamp.item(),
                    'product_id' : row_meta.product_id.item(),
                    'grid_cell' : row_meta.grid_cell.item(),
                    'grid_row_u' : row_meta.grid_row_u.item(),
                    'grid_col_r' : row_meta.grid_col_r.item(),
                    'geometry' : geometry,
                    'centre_lat' : centre_lat,
                    'centre_lon' : centre_lon,
                    'utm_footprint' : utm_footprint.wkt,
                    'utm_crs' : crs.to_string(),
                    'pixel_bbox' : pixel_bbox,
                }
                df_rows.append(row_dict)

        return gpd.GeoDataFrame(df_rows).astype(self.column_types)
|
MajorTOM/embedder/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .MajorTOM_Embedder import *
|
| 2 |
+
from .grid_cell_fragment import *
|
MajorTOM/embedder/grid_cell_fragment.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from shapely.ops import transform
|
| 5 |
+
from pyproj import CRS, Transformer
|
| 6 |
+
import geopandas as gpd
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
from shapely.geometry import Polygon, box
|
| 10 |
+
from rasterio.transform import from_bounds, xy
|
| 11 |
+
#from rasterio.windows import Window, from_bounds
|
| 12 |
+
import rasterio as rio
|
| 13 |
+
|
| 14 |
+
def crop_footprint(footprint, height, width, crop_bbox):
    """
    Crops the given footprint to the specified bounding box.

    Args:
        footprint (shapely.geometry.Polygon): The original footprint of the image or area.
        height (int): Height of the image (in pixels).
        width (int): Width of the image (in pixels).
        crop_bbox (list): The bounding box to crop the footprint, given as
            [col_start, row_start, col_end, row_end] in pixel coordinates.

    Returns:
        shapely.geometry.Polygon: The cropped bounding box in the same CRS as the footprint.
    """

    # Affine transform mapping pixel (col, row) -> spatial coordinates.
    # Named 'affine' to avoid shadowing the shapely 'transform' import.
    affine = from_bounds(*footprint.bounds, width, height)

    col_start, row_start, col_end, row_end = crop_bbox

    # Project both corners of the pixel bbox into the footprint's CRS
    min_x, min_y = affine * (col_start, row_start)
    max_x, max_y = affine * (col_end, row_end)

    # Shapely polygon for the crop's bounding box
    return box(min_x, min_y, max_x, max_y)
|
| 40 |
+
|
| 41 |
+
def fragment_unfold(image, fragment_size, overlap):
    """
    Extract overlapping patches from an image via torch unfold.

    Args:
        image (torch.Tensor or np.ndarray): Input image. Numpy arrays are
            expected channel-last (H, W, C); tensors channel-first.
        fragment_size (int or list): Patch size; an int yields square patches.
        overlap (int or list): Overlap between adjacent patches, in pixels.

    Returns:
        torch.Tensor: Patches of shape (num_patches, channels, hf, wf).
    """

    # Accept numpy (H, W, C) input by converting to a (C, H, W) tensor
    if not torch.is_tensor(image):
        image = torch.from_numpy(image).permute(2, 0, 1)
    if image.dim() < 4:
        image = image.unsqueeze(0)  # add batch dimension

    b, c, h, w = image.shape

    # Normalise scalar arguments to (height, width) pairs
    if isinstance(fragment_size, int):
        fragment_size = [fragment_size] * 2
    if isinstance(overlap, int):
        overlap = [overlap] * 2

    # Stride is the patch extent minus the requested overlap
    stride = [size - ov for size, ov in zip(fragment_size, overlap)]

    unfolded = torch.nn.functional.unfold(image, fragment_size, dilation=1, padding=0, stride=stride)

    # (b, c*hf*wf, L) -> (b, L, c, hf, wf), then drop the batch dimension
    patches = unfolded.view(b, c, *fragment_size, -1).permute(0, 4, 1, 2, 3)
    return patches[0]
|
| 78 |
+
|
| 79 |
+
def fragment_fn(img,
                fragment_size,
                target_overlap,
                border_shift=True, # determines whether the outer border is shifted to ensure full coverage
                return_indices=False,
                verbose=False
               ):
    """
    Fragment an image into smaller patches with a specified fragment size and overlap.

    This function handles different scenarios based on image size, fragment size, and overlap,
    and creates fragments from the input image accordingly. It also supports shifting the outer
    border of fragments to ensure full coverage of the image.

    Args:
        img (np.ndarray or torch.Tensor): The input image to be fragmented (height, width, channels).
        fragment_size (int or list): The size of the fragments. Can be a single integer (square) or a list of two integers (non-square).
        target_overlap (float): The target overlap between adjacent fragments, in pixels.
        border_shift (bool): Whether to shift the border of fragments to ensure full coverage of the image. Default is True.
        return_indices (bool): If True, the function will also return the indices (offsets) for each fragment. Default is False.
        verbose (bool): If True, the function will print additional details about the overlap. Default is False.

    Returns:
        torch.Tensor or tuple:
            - If `return_indices` is False, a tensor containing the image fragments.
            - If `return_indices` is True, a tuple of the image fragments and their offsets.
    """

    # expects channel-last layout (H, W, C)
    h,w,c=img.shape

    assert h==w # SQUARE IMAGES SUPPORT ONLY

    # hf/wf: fragment extent; ho/wo: requested (fractional) overlap in pixels
    hf, wf = fragment_size, fragment_size
    ho, wo = target_overlap*hf, target_overlap*wf

    assert h >= hf and w >= wf # reject Scenario 1

    # Scenario 2: image exactly matches the fragment size -> single fragment
    if h == hf or w == wf:
        if not torch.is_tensor(img):
            img=torch.from_numpy(img).permute(2,0,1)
        # NOTE(review): for a torch input of shape (h, w, c) this view reshapes
        # without a permute, unlike the numpy branch above — confirm torch
        # callers pass channel-last data and that channel scrambling is intended
        return img.view(1,1,c,h,w)

    # Scenario 3 & 4

    # determine number of segments between the centers of outermost fragments
    h_n = max(1, int(np.round((h-hf)/(hf-ho))))
    w_n = max(1, int(np.round((w-wf)/(wf-wo))))

    # adjust practical overlap (divide the distance between the centers of outermost fragments by the true number of segments)
    aho = int(np.ceil(hf-(h-hf)/(h_n)))
    awo = int(np.ceil(wf-(w-wf)/(w_n)))

    # compute fragments (might not exactly fill the outermost border)
    topleft = fragment_unfold(img.permute(2,0,1),fragment_size=(hf,wf), overlap=(aho,awo)).view(1+h_n, 1+w_n, c, hf, wf)

    full = topleft

    if border_shift:

        # only needed when the regular grid falls short of the bottom/right edge
        if h > hf+h_n*(hf-aho) or w > wf+w_n*(wf-awo):
            #print('Outers...')
            # re-fragment strips anchored at the bottom and right image borders
            bottomleft = fragment_unfold(img[-hf:,:,:],fragment_size=(hf,wf), overlap=(aho,awo)).view(1,1+w_n,c,hf,wf)
            topright = fragment_unfold(img[:,-wf:,:],fragment_size=(hf,wf), overlap=(aho,awo)).view(1+h_n,1,c,hf,wf)

            # Shift last row and col to the border of the original
            full[:,-1,None] = topright
            full[-1] = bottomleft

    if verbose:
        print('Target Overlap: {} pixels. Feasible Overlap: {} pixels.'.format(ho,aho))

    if not return_indices:
        return full
    else:
        # offsets initialised to -1 so unfilled entries are detectable
        offset=-1*torch.ones(*full.shape[:2],2)
        for ridx in range(full.shape[0]):
            for cidx in range(full.shape[1]):
                # NOTE(review): the column offset (index 1) uses the height
                # stride (hf-aho) and vice versa; this is only safe because
                # h == w is asserted above — verify before lifting that assert
                offset[ridx,cidx,1] = cidx * (hf-aho)
                offset[ridx,cidx,0] = ridx * (wf-awo)

                if border_shift:
                    # pin the outermost row/col offsets to the image border
                    offset[ridx,-1,1] = h-hf
                    offset[-1,cidx,0] = w-wf

        return full,offset
|
MajorTOM/embedder/models/DINOv2_S2RGB.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 3 |
+
|
| 4 |
+
class DINOv2_S2RGB_Embedder(torch.nn.Module):
|
| 5 |
+
"""
|
| 6 |
+
Embedding wrapper for DINOv2 and Sentinel-2 data.
|
| 7 |
+
|
| 8 |
+
This model uses the DINOv2 architecture to generate embeddings for Sentinel-2 RGB data. The input data (RGB bands)
|
| 9 |
+
is preprocessed by normalizing and mapping it to true-color values. Then, it is passed through the DINOv2 model
|
| 10 |
+
to obtain feature embeddings.
|
| 11 |
+
|
| 12 |
+
Preprocessing:
|
| 13 |
+
The input Sentinel-2 image is divided by 10,000 and multiplied by 2.5 to map it to a true-color image
|
| 14 |
+
(normalized to the range [0, 1]), followed by processing using the DINOv2 image processor.
|
| 15 |
+
|
| 16 |
+
Model:
|
| 17 |
+
The DINOv2 model processes RGB input images of shape [224, 224] and produces embeddings, which are then
|
| 18 |
+
averaged across the sequence dimension to obtain a fixed-size embedding vector.
|
| 19 |
+
|
| 20 |
+
Model Components:
|
| 21 |
+
- `AutoImageProcessor`: Preprocessing pipeline for handling Sentinel-2 data.
|
| 22 |
+
- `AutoModel`: DINOv2 transformer model used for feature extraction.
|
| 23 |
+
|
| 24 |
+
Attributes:
|
| 25 |
+
processor (AutoImageProcessor): The DINOv2 image processor to handle preprocessing.
|
| 26 |
+
model (AutoModel): The DINOv2 model used to generate embeddings from preprocessed images.
|
| 27 |
+
bands (list): List of the Sentinel-2 bands used for RGB input (B04, B03, B02).
|
| 28 |
+
size (tuple): The input size expected by the model (height, width) for the RGB image.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self):
|
| 32 |
+
"""
|
| 33 |
+
Initializes the DINOv2_S2RGB_Embedder by loading the pre-trained DINOv2 model and processor,
|
| 34 |
+
and setting the expected input size for Sentinel-2 RGB data.
|
| 35 |
+
|
| 36 |
+
This embedder uses the 'facebook/dinov2-base' model for feature extraction from Sentinel-2
|
| 37 |
+
true-color images (RGB).
|
| 38 |
+
|
| 39 |
+
Attributes:
|
| 40 |
+
processor (AutoImageProcessor): The DINOv2 image processor for preprocessing Sentinel-2 images.
|
| 41 |
+
model (AutoModel): The pre-trained DINOv2 model for generating embeddings.
|
| 42 |
+
bands (list): The Sentinel-2 bands used for RGB data (B04 - Red, B03 - Green, B02 - Blue).
|
| 43 |
+
size (tuple): The expected input size of the image for the DINOv2 model (height, width).
|
| 44 |
+
"""
|
| 45 |
+
super().__init__()
|
| 46 |
+
|
| 47 |
+
# Load the DINOv2 processor and model from Hugging Face
|
| 48 |
+
self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
|
| 49 |
+
self.model = AutoModel.from_pretrained('facebook/dinov2-base')
|
| 50 |
+
|
| 51 |
+
# Define the RGB bands for Sentinel-2 (B04, B03, B02)
|
| 52 |
+
self.bands = ['B04', 'B03', 'B02']
|
| 53 |
+
|
| 54 |
+
# Extract the input size from the processor settings
|
| 55 |
+
self.size = self.processor.crop_size['height'], self.processor.crop_size['width']
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def normalize(self, input):
|
| 59 |
+
"""
|
| 60 |
+
Normalizes Sentinel-2 RGB data to true-color values.
|
| 61 |
+
|
| 62 |
+
The input image (in raw Sentinel-2 reflectance values) is first divided by 10,000 to convert it
|
| 63 |
+
to reflectance values in the range [0, 1]. Then, the result is multiplied by 2.5 to obtain true-color
|
| 64 |
+
values that are suitable for input into the DINOv2 model.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
input (torch.Tensor): The raw Sentinel-2 image tensor to be normalized.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
torch.Tensor: The normalized true-color image.
|
| 71 |
+
"""
|
| 72 |
+
return (2.5 * (input / 1e4)).clip(0,1)
|
| 73 |
+
|
| 74 |
+
def forward(self, input):
    """Embed a Sentinel-2 RGB image with DINOv2.

    The raw image is normalised to true colour, run through the DINOv2
    image processor, and forwarded through the backbone; the token
    sequence of the last hidden layer is mean-pooled into a single
    embedding, returned on the CPU.

    Args:
        input (torch.Tensor): Raw Sentinel-2 image tensor of shape [C, H, W]
            with C=3 (RGB channels).

    Returns:
        torch.Tensor: Mean-pooled embedding of shape [embedding_dim], on CPU.
    """
    pixels = self.processor(self.normalize(input), return_tensors="pt")['pixel_values']
    hidden = self.model(pixels.to(self.model.device)).last_hidden_state
    return hidden.mean(dim=1).cpu()
|
MajorTOM/embedder/models/SSL4EO_S1RTC.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchgeo.models import ResNet50_Weights
|
| 3 |
+
import timm
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
class SSL4EO_S1RTC_Embedder(torch.nn.Module):
    """
    SSL4EO Embedder for Sentinel-1 data using a pre-trained model.

    This model is based on the SSL4EO (Self-Supervised Learning for Earth Observation) approach,
    using a pre-trained ResNet50 model for Sentinel-1 radar data (SAR). The model can be used
    directly for feature extraction.

    Project Code:
        https://github.com/zhu-xlab/SSL4EO-S12

    Publication:
        https://arxiv.org/abs/2211.07044
    """

    def __init__(self, s1_mean=[-12.54847273, -20.19237134], s1_std=[5.25697717,5.91150917]):
        """
        Initializes the SSL4EO_S1RTC_Embedder.

        Loads a ResNet50 backbone with SSL4EO Sentinel-1 weights (via `torchgeo`) and stores
        per-channel normalization statistics (in dB) for the VV/VH polarizations.

        Args:
            s1_mean (list, optional): Per-channel mean backscatter in dB. Defaults to SSL4EO's values.
            s1_std (list, optional): Per-channel std of backscatter in dB. Defaults to SSL4EO's values.

        Attributes:
            s1_mean (torch.FloatTensor): Mean values for normalization.
            s1_std (torch.FloatTensor): Standard deviation values for normalization.
            model (torch.nn.Module): ResNet50 initialized with pre-trained weights.
            bands (list): Sentinel-1 polarizations used as input channels (VV, VH).
            size (tuple): Input size expected by the model (224, 224).
        """
        super().__init__()

        self.s1_mean = torch.FloatTensor(s1_mean)
        self.s1_std = torch.FloatTensor(s1_std)

        # load model
        self.model = self.init_model()
        self.bands = ['vv','vh']
        self.size = 224,224

    def init_model(self):
        """
        Initializes the ResNet50 model with SSL4EO Sentinel-1 (MOCO) weights.

        The fully connected head is replaced by an identity so the network
        outputs the pooled convolutional features directly.

        Returns:
            torch.nn.Module: The initialized ResNet50 model.
        """
        weights = ResNet50_Weights.SENTINEL1_ALL_MOCO
        model = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
        model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
        model.fc = torch.nn.Identity()

        return model

    def normalize(self, img, scale=1.0):
        """
        Normalizes dB-scale Sentinel-1 data to the [0, scale] range.

        Values are mapped linearly so that (mean - 2*std) -> 0 and
        (mean + 2*std) -> scale, then clipped.

        Args:
            img (torch.Tensor): dB-scale image of shape [2, H, W] (VV, VH).
            scale (float, optional): Upper bound of the output range. Default is 1.0.

        Returns:
            torch.Tensor: The normalized and clipped float image.
        """
        min_value = (self.s1_mean - 2 * self.s1_std).to(img.device)
        max_value = (self.s1_mean + 2 * self.s1_std).to(img.device)
        img = (img - min_value[:,None,None]) / (max_value - min_value)[:,None,None] * scale
        img = img.clip(0, scale).float()

        return img

    def preprocess(self, input):
        """
        Converts linear-scale SAR backscatter to decibels and normalizes it.

        Args:
            input (torch.Tensor): Linear-scale Sentinel-1 image (e.g. shape [2, H, W]).

        Returns:
            torch.Tensor: The preprocessed and normalized image in dB scale.
        """
        # BUGFIX: was `input.log10(input.clip(min=1e-10))`, which raises a
        # TypeError because Tensor.log10 takes no arguments. Clip first to
        # prevent log(0), then convert to decibels.
        dB_input = 10 * torch.log10(input.clip(min=1e-10))

        # Normalize the dB-scaled image
        return self.normalize(dB_input)

    def forward(self, input):
        """
        Forward pass: preprocess the input, then run the ResNet50 backbone.

        Args:
            input (torch.Tensor): Linear-scale Sentinel-1 image (e.g. shape [2, H, W]).

        Returns:
            torch.Tensor: The output embedding from the model.
        """
        return self.model(self.preprocess(input))
|
MajorTOM/embedder/models/SSL4EO_S2L1C.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchgeo.models import ResNet50_Weights
|
| 3 |
+
import timm
|
| 4 |
+
|
| 5 |
+
class SSL4EO_S2L1C_Embedder(torch.nn.Module):
    """
    SSL4EO embedder for Sentinel-2 L1C data.

    Wraps a ResNet50 backbone carrying SSL4EO (Self-Supervised Learning for
    Earth Observation) pre-trained weights for Sentinel-2 and exposes it as
    a feature extractor.

    Project Code:
        https://github.com/zhu-xlab/SSL4EO-S12

    Publication:
        https://arxiv.org/abs/2211.07044
    """

    def __init__(self):
        """
        Sets up the pre-trained backbone, the expected band ordering and the
        expected input resolution.

        Attributes:
            model (torch.nn.Module): ResNet50 with SSL4EO Sentinel-2 weights.
            bands (list): Sentinel-2 L1C band names, in input-channel order.
            size (tuple): Expected spatial input size, (224, 224).
        """
        super().__init__()

        # Pre-trained SSL4EO ResNet50 backbone
        self.model = self.init_model()

        # All 13 Sentinel-2 L1C bands, in channel order
        self.bands = [
            'B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07',
            'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'
        ]

        # Spatial size the backbone expects
        self.size = (224, 224)

    def init_model(self):
        """
        Builds a ResNet50 and loads the SSL4EO Sentinel-2 (DINO) weights.

        The classification head is swapped for an identity module so calls
        return pooled convolutional features instead of class logits.

        Returns:
            torch.nn.Module: The ready-to-use backbone.
        """
        pretrained = ResNet50_Weights.SENTINEL2_ALL_DINO
        backbone = timm.create_model('resnet50', in_chans=pretrained.meta['in_chans'])
        backbone.load_state_dict(pretrained.get_state_dict(progress=True), strict=False)
        backbone.fc = torch.nn.Identity()
        return backbone

    def preprocess(self, input):
        """
        Scales raw Sentinel-2 digital numbers into reflectance.

        Args:
            input (torch.Tensor): Raw Sentinel-2 image (e.g. shape [C, H, W]).

        Returns:
            torch.Tensor: The input divided by 10,000.
        """
        return input / 1e4

    def forward(self, input):
        """
        Embeds a Sentinel-2 image: scale the input, then run the backbone.

        Args:
            input (torch.Tensor): Raw Sentinel-2 image (e.g. shape [C, H, W]).

        Returns:
            torch.Tensor: The embedding produced by the backbone.
        """
        return self.model(self.preprocess(input))
|
MajorTOM/embedder/models/SigLIP_S2RGB.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from open_clip import create_model_from_pretrained, get_tokenizer
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class SigLIP_S2RGB_Embedder(torch.nn.Module):
    """
    Embedding wrapper around SigLIP for Sentinel-2 RGB data.

    Sentinel-2 true-colour inputs are produced by scaling raw reflectance
    values into [0, 1], after which the SigLIP image encoder
    (ViT-SO400M-14-SigLIP-384, loaded via open_clip) turns them into a
    single feature vector.

    Preprocessing:
        - Raw Sentinel-2 bands are divided by 10,000 (reflectance scaling).
        - Values are multiplied by 2.5 and clipped to [0, 1] (true colour).
        - The tensor is then passed through the final transform of the
          open_clip preprocessing pipeline.

    Model:
        - Expects a 384x384 RGB input and produces an embedding vector.
    """

    def __init__(self):
        super().__init__()

        # SigLIP image encoder + its preprocessing pipeline
        self.model, self.preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
        # Sentinel-2 RGB bands (B04 - Red, B03 - Green, B02 - Blue)
        self.bands = ['B04', 'B03', 'B02']
        # Input resolution, taken from the first transform of the pipeline
        self.size = self.preprocess.transforms[0].size

    def normalize(self, input):
        """
        Turns raw Sentinel-2 reflectance values into a True-Colour image.

        Divides by 10,000 to reach reflectance in [0, 1], brightens by a
        factor of 2.5, and clips back into [0, 1].

        Args:
            input (torch.Tensor or np.ndarray): Raw Sentinel-2 reflectance values.

        Returns:
            Normalized True-Colour image, clipped to the range [0, 1].
        """
        scaled = input / 1e4
        return (scaled * 2.5).clip(0, 1)

    def forward(self, input):
        """
        Embeds a Sentinel-2 RGB image with the SigLIP image encoder.

        Args:
            input (torch.Tensor): A Sentinel-2 image, typically of shape
                (C, H, W), where C=3 (RGB), H=384, and W=384.

        Returns:
            torch.Tensor: The image embedding produced by the model.
        """
        true_colour = self.normalize(input)

        # apply only the final transform of the pipeline (normalization only)
        model_input = self.preprocess.transforms[-1](true_colour)

        return self.model.encode_image(model_input)
|
MajorTOM/embedder/models/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .SigLIP_S2RGB import *
|
| 2 |
+
from .DINOv2_S2RGB import *
|
| 3 |
+
from .SSL4EO_S2L1C import *
|
| 4 |
+
from .SSL4EO_S1RTC import *
|
MajorTOM/extras/coverage-example.png
ADDED
|
Git LFS Details
|
MajorTOM/extras/coverage_vis.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from mpl_toolkits.basemap import Basemap
|
| 5 |
+
import PIL
|
| 6 |
+
|
| 7 |
+
def get_mask(df):
    """Rasterise a Major TOM metadata dataframe into a global availability mask.

    Builds a 2004x4008 uint8 image of the Major TOM grid where a pixel is
    255 if the corresponding grid cell is present in *df* (and not flagged
    as nodata) and 0 otherwise.

    Args:
        df (pd.DataFrame): Major TOM metadata with 'grid_row_u',
            'grid_col_r' and 'nodata' columns.

    Returns:
        PIL.Image.Image: Single-channel (uint8) coverage mask.
    """
    grid_shape = (2004, 4008)
    row_offset, col_offset = -1002, -2004

    canvas = np.zeros(grid_shape, dtype=np.uint8)

    # Cells with a nodata fraction above 0.5 are treated as unavailable
    invalid = df['nodata'].values > 0.5

    rows = grid_shape[0] - (np.array(df['grid_row_u']) - row_offset) - 1
    cols = np.array(df['grid_col_r']) - col_offset

    canvas[rows[~invalid], cols[~invalid]] = 255

    return PIL.Image.fromarray(canvas)
|
| 27 |
+
|
| 28 |
+
def fig2img(fig):
    """Render a Matplotlib figure into an in-memory PIL Image and return it."""
    import io
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    return PIL.Image.open(buffer)
|
| 36 |
+
|
| 37 |
+
def light_basemap():
    """
    Bright coloured contours

    Renders a sinusoidal-projection world map with light green continents,
    pale blue water and grey country/coastline outlines, returned as a
    PIL Image (via fig2img).
    """

    with plt.ioff():
        # Large, high-DPI canvas so the coverage mask can be applied at full resolution
        fig, ax = plt.subplots(figsize=(48,24), dpi=167)

        # Fill continents and the map background first, then draw the line work
        m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
        m.fillcontinents(color="#9eba9b", lake_color='#CCDDFF')
        m.drawmapboundary(fill_color="#CCDDFF")
        m.drawcountries(color="#666666", linewidth=1)
        m.drawcoastlines(color="#666666", linewidth=1)

        # Strip axes and margins so the saved figure is exactly the map area
        plt.gca().set_axis_off()
        plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
                    hspace = 0, wspace = 0)
        plt.margins(0,0)

        return fig2img(fig)
|
| 57 |
+
|
| 58 |
+
def dark_basemap():
    """
    Dark contours

    Same layout as light_basemap, but rendered in near-black tones; used
    as the background layer for uncovered areas. Returned as a PIL Image.
    """

    with plt.ioff():
        # Canvas dimensions must match light_basemap so the two renders align
        fig, ax = plt.subplots(figsize=(48,24), dpi=167)

        # Dark fills with black line work
        m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
        m.fillcontinents(color="#242424", lake_color='#242424')
        m.drawmapboundary(fill_color="#242424")
        m.drawcountries(color="#000000", linewidth=1)
        m.drawcoastlines(color="#000000", linewidth=1)

        # Strip axes and margins so the saved figure is exactly the map area
        plt.gca().set_axis_off()
        plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
                    hspace = 0, wspace = 0)
        plt.margins(0,0)

        return fig2img(fig)
|
| 78 |
+
|
| 79 |
+
def get_coveragemap(input, input2=None):
    """
    Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation

    Optionally, input2 can be provided and then, the map plots a map with extra colours indicating cells available only in input (green) or only input2 (blue)

    Args:
        input: Major TOM metadata DataFrame (or a precomputed mask image).
        input2: Optional second metadata DataFrame (or mask) to compare against.

    Returns:
        PIL.Image.Image: The rendered coverage map.
    """

    if input2 is None:
        return single_coveragemap(input)
    else:
        cmap1 = single_coveragemap(input)
        cmap2 = single_coveragemap(input2)

        # arrays for mixing (drop any alpha channel, keep RGB)
        inp1_arr = np.array(cmap1)[...,:3]
        inp2_arr = np.array(cmap2)[...,:3]

        # Where the two renders agree (equal channel sums), keep only the red channel
        common_arr = inp1_arr*(inp1_arr.sum(-1) == inp2_arr.sum(-1))[:,:,None]
        common_arr[:,:,(1,2)] = 0
        inp1_arr[:,:,(0,2)] = 0 # Green - indicates presence of S2 only
        inp2_arr[:,:,(0,1)] = 0 # Blue - indicates presence of DEM only

        return PIL.Image.fromarray(((common_arr + inp1_arr + inp2_arr)).astype(np.uint8))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def single_coveragemap(input):
    """
    Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation

    Args:
        input: Major TOM metadata DataFrame, or an already-computed PIL mask
            image (as returned by get_mask).

    Returns:
        PIL.Image.Image: Dark basemap with the covered cells shown in colour.
    """

    # compute mask if df is provided
    if isinstance(input, pd.DataFrame):
        mask = get_mask(input)
    else:
        mask = input

    basemap = light_basemap()
    basemap_d = dark_basemap()

    # Pixels with red channel == 255 in the light render - presumably the
    # white figure background outside the projected globe (TODO confirm)
    outside_earth = np.array(basemap.convert('RGBA'))[:, :, 0] == 255
    outside_earth = PIL.Image.fromarray(outside_earth)

    # Upscale the cell mask to the basemap resolution without interpolation
    mask = mask.resize(basemap.size, PIL.Image.NEAREST)

    # Use the coverage mask as the alpha channel of the light basemap
    basemap.putalpha(mask)

    # Mask outside of earth
    basemap.paste(outside_earth, (0,0), outside_earth)

    # Composite the masked light map over the dark base
    basemap_d.paste(basemap, (0,0), basemap)

    return basemap_d
|
| 131 |
+
|
| 132 |
+
if __name__ == '__main__':
    # Example: render the coverage map of a Major TOM dataset from its
    # metadata parquet file hosted on HuggingFace.
    DATASET_NAME = 'Major-TOM/Core-S2L2A'
    meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
    df = pd.read_parquet(meta_path)

    # This is how you make a coverage figure!
    coverage_img = get_coveragemap(df)

    coverage_img.save('coverage-example.png', format='PNG')

    # and this is how you can create an overlap figure for 2 datasets!
    DATASET_NAME = 'Major-TOM/Core-DEM'
    meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
    dem_df = pd.read_parquet(meta_path)

    coverage_img = get_coveragemap(df,dem_df)

    coverage_img.save('overlap-coverage-example.png', format='PNG')
|
MajorTOM/extras/extract-sample-from-raw-S2.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
MajorTOM/extras/thumbnail_dem.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
|
| 3 |
+
|
| 4 |
+
Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from rasterio.io import MemoryFile
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import rasterio as rio
|
| 13 |
+
from matplotlib.colors import LightSource
|
| 14 |
+
|
| 15 |
+
def get_grayscale(x):
    """Render an array as an 8-bit grayscale image via min-max scaling.

    The input is shifted and scaled so its minimum maps to 0 and its
    maximum maps to 255.

    Args:
        x (np.ndarray): Input array (e.g. a DEM tile).

    Returns:
        np.ndarray: uint8 array of the same shape with values in [0, 255].
    """
    shifted = x - x.min()
    scaled = shifted / shifted.max()
    return np.uint8(scaled * 255)
|
| 25 |
+
|
| 26 |
+
def get_hillshade(x, azdeg=315, altdeg=45, ve=1):
    """Render a DEM as an 8-bit hillshade image.

    Illuminates the terrain with a matplotlib LightSource placed at the
    given azimuth/altitude and scales the result to [0, 255].

    Args:
        x (np.ndarray): DEM array.
        azdeg (float): Azimuth of the light source, in degrees.
        altdeg (float): Altitude of the light source, in degrees.
        ve (float): Vertical exaggeration applied before shading.

    Returns:
        np.ndarray: uint8 hillshade array.
    """
    light = LightSource(azdeg=azdeg, altdeg=altdeg)
    shaded = light.hillshade(x, vert_exag=ve)
    return np.uint8(255 * shaded)
|
| 33 |
+
|
| 34 |
+
def dem_thumbnail(dem, dem_NODATA = -32768.0, hillshade=True):
    """
    Renders a DEM array as a thumbnail, either hillshaded or min-max grayscale.

    Args:
        dem (np.ndarray): DEM array.
        dem_NODATA (float): NODATA value of the DEM (default is -32768.0).
            NOTE(review): currently unused - NODATA pixels are not masked
            before visualisation.
        hillshade (bool): If True, return a hillshade rendering; otherwise
            a normalized grayscale rendering.

    Returns:
        np.ndarray: uint8 thumbnail array.
    """
    if hillshade:
        return get_hillshade(dem)
    else:
        return get_grayscale(dem)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def dem_thumbnail_from_datarow(datarow):
    """
    Builds a DEM thumbnail from a raw datarow of a Major TOM parquet file.

    Decodes the embedded 'DEM' GeoTIFF blob (bytes stored as a pyarrow
    value) and renders it with `dem_thumbnail`.

    Returns:
        PIL.Image.Image: Single-channel ('L' mode) thumbnail.
    """

    # Decode the DEM GeoTIFF from the in-memory bytes
    with MemoryFile(datarow['DEM'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            dem=f.read().squeeze()
            dem_NODATA = f.nodata

    img = dem_thumbnail(dem, dem_NODATA)

    return Image.fromarray(img,'L')
|
| 61 |
+
|
| 62 |
+
if __name__ == '__main__':
    # Example: fetch one row group from a Core-DEM parquet on HuggingFace
    # and render its thumbnail.
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-DEM/resolve/main/images/part_01001.parquet"
    with open_parquet_file(url) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1)

    print('[example run] computing the thumbnail...')
    thumbnail = dem_thumbnail_from_datarow(first_row_group)

    thumbnail_fname = 'example_thumbnail.png'
    thumbnail.save(thumbnail_fname, format = 'PNG')
    print('[example run] saved as "{}"'.format(thumbnail_fname))
|
MajorTOM/extras/thumbnail_s1rtc.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
|
| 3 |
+
|
| 4 |
+
Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from rasterio.io import MemoryFile
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
def s1rtc_thumbnail(vv, vh, vv_NODATA = -32768.0, vh_NODATA = -32768.0):
    """
    Builds a false-colour RGB thumbnail from Sentinel-1 RTC backscatter.

    Both polarizations are converted to dB, min-max scaled to [0, 255]
    using statistics of valid (non-NODATA) pixels only, and combined into
    an RGB composite (R = VV, G = scaled VV+VH, B = VH). NODATA pixels are
    rendered as 0.

    Args:
        vv (np.ndarray): VV polarization backscatter (linear scale).
        vh (np.ndarray): VH polarization backscatter (linear scale).
        vv_NODATA (float): NODATA value in `vv` (default is -32768.0).
        vh_NODATA (float): NODATA value in `vh` (default is -32768.0).

    Returns:
        np.ndarray: uint8 RGB thumbnail of shape (H, W, 3).
    """

    # valid data masks
    vv_mask = vv != vv_NODATA
    vh_mask = vh != vh_NODATA

    # BUGFIX: work on copies - the original mutated the caller's arrays
    # in place when replacing negative values below.
    vv = vv.copy()
    vh = vh.copy()

    # remove invalid values before log op
    vv[vv<0] = vv[vv>=0].min()
    vh[vh<0] = vh[vh>=0].min()

    # apply log op
    vv_dB = 10*np.log10(vv)
    vh_dB = 10*np.log10(vh)

    # scale to 0-255 (statistics computed over valid pixels only)
    vv_dB = (vv_dB - vv_dB[vv_mask].min()) / (vv_dB[vv_mask].max() - vv_dB[vv_mask].min()) * 255
    vh_dB = (vh_dB - vh_dB[vh_mask].min()) / (vh_dB[vh_mask].max() - vh_dB[vh_mask].min()) * 255

    # represent nodata as 0
    vv_dB[vv_mask==0] = 0
    vh_dB[vh_mask==0] = 0

    # false colour composite
    return np.stack([vv_dB,
                     255*(vv_dB+vh_dB)/np.max(vv_dB+vh_dB),
                     vh_dB
                    ],-1).astype(np.uint8)
|
| 43 |
+
|
| 44 |
+
def s1rtc_thumbnail_from_datarow(datarow):
    """
    Builds a thumbnail from a raw datarow of a Major TOM parquet file.

    Decodes the embedded 'vv' and 'vh' GeoTIFF blobs (together with their
    NODATA values) and renders them with `s1rtc_thumbnail`.

    Returns:
        PIL.Image.Image: The false-colour thumbnail.
    """

    # Decode the VV polarization GeoTIFF from the in-memory bytes
    with MemoryFile(datarow['vv'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            vv=f.read().squeeze()
            vv_NODATA = f.nodata

    # Decode the VH polarization GeoTIFF
    with MemoryFile(datarow['vh'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            vh=f.read().squeeze()
            vh_NODATA = f.nodata

    img = s1rtc_thumbnail(vv, vh, vv_NODATA=vv_NODATA, vh_NODATA=vh_NODATA)

    return Image.fromarray(img)
|
| 64 |
+
|
| 65 |
+
if __name__ == '__main__':
    # Example: fetch one row group from a Core-S1RTC parquet on HuggingFace
    # and render its thumbnail.
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-S1RTC/resolve/main/images/part_00001.parquet"
    with open_parquet_file(url) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1)

    print('[example run] computing the thumbnail...')
    thumbnail = s1rtc_thumbnail_from_datarow(first_row_group)

    thumbnail_fname = 'example_thumbnail.png'
    thumbnail.save(thumbnail_fname, format = 'PNG')
    print('[example run] saved as "{}"'.format(thumbnail_fname))
|
MajorTOM/extras/thumbnail_s2.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
|
| 3 |
+
|
| 4 |
+
Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from rasterio.io import MemoryFile
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
def s2l2a_thumbnail(B04, B03, B02, gain=1.3, gamma=0.6):
    """Build an RGB thumbnail from Sentinel-2 L2A B04/B03/B02 reflectance arrays.

    Reflectances are scaled by 1/10,000, gamma-corrected, gain-boosted,
    clipped to [0, 1] and quantised to 8 bits.

    Args:
        B04 (np.ndarray): Red band array.
        B03 (np.ndarray): Green band array.
        B02 (np.ndarray): Blue band array.
        gain (float): Brightness multiplier applied after gamma correction.
        gamma (float): Gamma exponent applied to the scaled reflectance.

    Returns:
        np.ndarray: uint8 RGB array of shape (H, W, 3).
    """
    rgb = np.stack([B04, B03, B02], -1)
    rgb = gain * ((rgb / 10_000) ** gamma)
    return (rgb.clip(0, 1) * 255).astype(np.uint8)
|
| 25 |
+
|
| 26 |
+
def s2l2a_thumbnail_from_datarow(datarow):
    """
    Builds an RGB thumbnail from a raw datarow of a Major TOM parquet file.

    Decodes the embedded 'B04', 'B03' and 'B02' GeoTIFF blobs and renders
    them with `s2l2a_thumbnail`.

    Returns:
        PIL.Image.Image: The RGB thumbnail.
    """

    # red
    with MemoryFile(datarow['B04'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            B04=f.read().squeeze()
            B04_NODATA = f.nodata  # NOTE(review): read but unused downstream

    # green
    with MemoryFile(datarow['B03'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            B03=f.read().squeeze()
            B03_NODATA = f.nodata  # NOTE(review): read but unused downstream

    # blue
    with MemoryFile(datarow['B02'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            B02=f.read().squeeze()
            B02_NODATA = f.nodata  # NOTE(review): read but unused downstream

    img = s2l2a_thumbnail(B04,B03,B02)

    return Image.fromarray(img)
|
| 54 |
+
|
| 55 |
+
if __name__ == '__main__':
    # Example: fetch one row group (RGB columns only) from a Core-S2L2A
    # parquet on HuggingFace and render its thumbnail.
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-S2L2A/resolve/main/images/part_01000.parquet"
    with open_parquet_file(url, columns = ["B04", "B03", "B02"]) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1, columns = ["B04", "B03", "B02"])

    print('[example run] computing the thumbnail...')
    thumbnail = s2l2a_thumbnail_from_datarow(first_row_group)

    thumbnail.save('example_thumbnail.png', format = 'PNG')
|
MajorTOM/grid.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import math
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import geopandas as gpd
|
| 5 |
+
from shapely.geometry import LineString, Polygon
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Grid():
    """A global grid of approximately equidistant points (Major TOM grid).

    Rows are latitudes spaced ~`dist` km apart; each row is subdivided into
    columns so neighbouring points along the row are also ~`dist` km apart.
    Rows are named 0U..NU northwards from the equator and 1D..ND southwards;
    columns 0R..NR eastwards from the prime meridian and 1L..NL westwards.
    Each grid point is the bottom-left corner of its cell.
    """

    RADIUS_EQUATOR = 6378.137 # km

    def __init__(self,dist,latitude_range=(-85,85),longitude_range=(-180,180),utm_definition='bottomleft'):
        # dist: target spacing between neighbouring grid points, in km
        self.dist = dist
        # inclusive bounds applied to the generated rows/columns
        self.latitude_range = latitude_range
        self.longitude_range = longitude_range
        # whether the UTM zone of a cell is taken at its bottom-left corner
        # or (approximately) at its centre
        self.utm_definition = utm_definition
        # row names and their latitudes, south to north
        self.rows,self.lats = self.get_rows()
        # flat GeoDataFrame of all points + a per-row list of GeoDataFrames
        self.points, self.points_by_row = self.get_points()

    def get_rows(self):
        """Return (row_names, latitudes) for rows ~dist km apart, ordered south→north."""
        # Define set of latitudes to use, based on the grid distance
        arc_pole_to_pole = math.pi * self.RADIUS_EQUATOR
        num_divisions_in_hemisphere = math.ceil(arc_pole_to_pole / self.dist)

        latitudes = np.linspace(-90, 90, num_divisions_in_hemisphere+1)[:-1]
        # wrap into [-90, 90) so the endpoint duplication is avoided
        latitudes = np.mod(latitudes, 180) - 90

        # order should be from south to north
        latitudes = np.sort(latitudes)

        # index of the first non-negative latitude (the equator row)
        zeroth_row = np.searchsorted(latitudes,0)

        # From 0U-NU and 1D-ND
        rows = [None] * len(latitudes)
        rows[zeroth_row:] = [f'{i}U' for i in range(len(latitudes)-zeroth_row)]
        rows[:zeroth_row] = [f'{abs(i-zeroth_row)}D' for i in range(zeroth_row)]

        # bound to range
        idxs = (latitudes>=self.latitude_range[0]) * (latitudes<=self.latitude_range[1])
        rows,latitudes = np.array(rows), np.array(latitudes)
        rows,latitudes = rows[idxs],latitudes[idxs]

        return rows,latitudes

    def get_circumference_at_latitude(self,lat):
        """Return the circumference (km) of the Earth's cross-section at `lat` degrees."""
        # Circumference of the cross-section of a sphere at a given latitude
        radius_at_lat = self.RADIUS_EQUATOR * math.cos(lat * math.pi / 180)
        circumference = 2 * math.pi * radius_at_lat

        return circumference

    def subdivide_circumference(self,lat,return_cols=False):
        """Return longitudes subdividing the parallel at `lat` into ~dist-km steps.

        If return_cols is True, also return the Major TOM column names
        (…2L,1L,0R,1R…) for those longitudes.
        """
        # Provide a list of longitudes that subdivide the circumference of the earth at a given latitude
        # into equal parts as close as possible to dist
        circumference = self.get_circumference_at_latitude(lat)
        num_divisions = math.ceil(circumference / self.dist)
        longitudes = np.linspace(-180,180, num_divisions+1)[:-1]
        # wrap into [-180, 180) to drop the duplicated endpoint
        longitudes = np.mod(longitudes, 360) - 180
        longitudes = np.sort(longitudes)

        if return_cols:
            cols = [None] * len(longitudes)
            # 0 degrees is always one of the linspace samples, so this lookup is safe
            zeroth_idx = np.where(longitudes==0)[0][0]
            cols[zeroth_idx:] = [f'{i}R' for i in range(len(longitudes)-zeroth_idx)]
            cols[:zeroth_idx] = [f'{abs(i-zeroth_idx)}L' for i in range(zeroth_idx)]
            return np.array(cols),np.array(longitudes)

        return np.array(longitudes)

    def get_points(self):
        """Build every grid point.

        Returns:
            points (GeoDataFrame): all points with name/row/col/indices/UTM info
            points_by_row (list[GeoDataFrame]): the same points grouped per row
        """
        r_idx = 0
        points_by_row = [None]*len(self.rows)
        for r,lat in zip(self.rows,self.lats):
            point_names,grid_row_names,grid_col_names,grid_row_idx,grid_col_idx,grid_lats,grid_lons,utm_zones,epsgs = [],[],[],[],[],[],[],[],[]
            cols,lons = self.subdivide_circumference(lat,return_cols=True)

            cols,lons = self.filter_longitude(cols,lons)
            c_idx = 0
            for c,lon in zip(cols,lons):
                point_names.append(f'{r}_{c}')
                grid_row_names.append(r)
                grid_col_names.append(c)
                grid_row_idx.append(r_idx)
                grid_col_idx.append(c_idx)
                grid_lats.append(lat)
                grid_lons.append(lon)
                if self.utm_definition == 'bottomleft':
                    utm_zones.append(get_utm_zone_from_latlng([lat,lon]))
                elif self.utm_definition == 'center':
                    # shift by half a cell; 111_120 m ≈ one degree of latitude
                    center_lat = lat + (1000*self.dist/2)/111_120
                    center_lon = lon + (1000*self.dist/2)/(111_120*math.cos(center_lat*math.pi/180))
                    utm_zones.append(get_utm_zone_from_latlng([center_lat,center_lon]))
                else:
                    raise ValueError(f'Invalid utm_definition {self.utm_definition}')
                epsgs.append(f'EPSG:{utm_zones[-1]}')

                c_idx += 1
            points_by_row[r_idx] = gpd.GeoDataFrame({
                'name':point_names,
                'row':grid_row_names,
                'col':grid_col_names,
                'row_idx':grid_row_idx,
                'col_idx':grid_col_idx,
                'utm_zone':utm_zones,
                'epsg':epsgs
            },geometry=gpd.points_from_xy(grid_lons,grid_lats))
            r_idx += 1
        points = gpd.GeoDataFrame(pd.concat(points_by_row))
        # points.reset_index(inplace=True,drop=True)
        return points, points_by_row

    def group_points_by_row(self):
        """Return a list of GeoDataFrames, one per grid row (filtered from self.points)."""
        # Make list of different gdfs for each row
        points_by_row = [None]*len(self.rows)
        for i,row in enumerate(self.rows):
            points_by_row[i] = self.points[self.points.row==row]
        return points_by_row

    def filter_longitude(self,cols,lons):
        """Keep only (col, lon) pairs whose longitude lies within self.longitude_range."""
        idxs = (lons>=self.longitude_range[0]) * (lons<=self.longitude_range[1])
        cols,lons = cols[idxs],lons[idxs]
        return cols,lons

    def latlon2rowcol(self,lats,lons,return_idx=False,integer=False):
        """
        Convert latitude and longitude to row and column number from the grid

        Args:
            lats, lons: sequences of coordinates to map to grid cells
            return_idx (bool): also return the index of each cell in self.points
            integer (bool): return rows/cols as signed ints (U/R positive, D/L negative)
                instead of name strings
        """
        # Always take bottom left corner of grid cell
        rows = np.searchsorted(self.lats,lats)-1

        # Get the possible points of the grid cells at the given latitude
        possible_points = [self.points_by_row[row] for row in rows]

        # For each point, find the rightmost point that is still to the left of the given longitude
        cols = [poss_points.iloc[np.searchsorted(poss_points.geometry.x,lon)-1].col for poss_points,lon in zip(possible_points,lons)]
        rows = self.rows[rows].tolist()

        outputs = [rows, cols]
        if return_idx:
            # Get the table index for self.points with each row,col pair in rows, cols
            idx = [self.points[(self.points.row==row) & (self.points.col==col)].index.values[0] for row,col in zip(rows,cols)]
            outputs.append(idx)

        # return raw numbers
        if integer:
            outputs[0] = [int(el[:-1]) if el[-1] == 'U' else -int(el[:-1]) for el in outputs[0]]
            outputs[1] = [int(el[:-1]) if el[-1] == 'R' else -int(el[:-1]) for el in outputs[1]]

        return outputs

    def rowcol2latlon(self,rows,cols):
        """Map (row, col) name pairs back to the (lat, lon) of each grid point."""
        point_geoms = [self.points.loc[(self.points.row==row) & (self.points.col==col),'geometry'].values[0] for row,col in zip(rows,cols)]
        lats = [point.y for point in point_geoms]
        lons = [point.x for point in point_geoms]
        return lats,lons

    def get_bounded_footprint(self,point,buffer_ratio=0):
        """Return the (optionally buffered) bounding Polygon of a grid point's cell.

        The point is the bottom-left corner of its cell; the top/right edges are
        the neighbouring grid lines.  buffer_ratio expands the box by that
        fraction of the cell's width/height on every side.
        """
        # Gets the polygon footprint of the grid cell for a given point, bounded by the other grid points' cells.
        # Grid point defined as bottom-left corner of polygon. Buffer ratio is the ratio of the grid cell's width/height to buffer by.
        bottom,left = point.geometry.y,point.geometry.x
        row_idx = point.row_idx
        col_idx = point.col_idx
        next_row_idx = row_idx+1
        next_col_idx = col_idx+1

        if next_row_idx >= len(self.lats): # If at top row, use difference between top and second-to-top row for height
            height = (self.lats[row_idx] - self.lats[row_idx-1])
            top = self.lats[row_idx] + height
        else:
            top = self.lats[next_row_idx]

        max_col = len(self.points_by_row[row_idx].col_idx)-1
        if next_col_idx > max_col: # If at rightmost column, use difference between rightmost and second-to-rightmost column for width
            width = (self.points_by_row[row_idx].iloc[col_idx].geometry.x - self.points_by_row[row_idx].iloc[col_idx-1].geometry.x)
            right = self.points_by_row[row_idx].iloc[col_idx].geometry.x + width
        else:
            right = self.points_by_row[row_idx].iloc[next_col_idx].geometry.x

        # Buffer the polygon by the ratio of the grid cell's width/height
        width = right - left
        height = top - bottom

        buffer_horizontal = width * buffer_ratio
        buffer_vertical = height * buffer_ratio

        new_left = left - buffer_horizontal
        new_right = right + buffer_horizontal

        new_bottom = bottom - buffer_vertical
        new_top = top + buffer_vertical

        bbox = Polygon([(new_left,new_bottom),(new_left,new_top),(new_right,new_top),(new_right,new_bottom)])

        return bbox
|
| 205 |
+
|
| 206 |
+
def get_utm_zone_from_latlng(latlng):
    """Return the EPSG code of the UTM zone containing a point.

    Parameters
    ----------
    latlng : List[Union[int, float]]
        The point as [latitude, longitude].

    Returns
    -------
    str
        The EPSG code for the UTM zone, e.g. "32633".
    """
    assert isinstance(latlng, (list, tuple)), "latlng must be in the form of a list or tuple."

    latitude, longitude = latlng[0], latlng[1]

    # Standard 6-degree UTM zones, numbered 1..60 starting at 180W
    zone_number = (math.floor((longitude + 180) / 6)) % 60 + 1

    # Exceptions for southern Norway and Svalbard
    if 56.0 <= latitude < 64.0 and 3.0 <= longitude < 12.0:
        zone_number = 32
    elif 72.0 <= latitude < 84.0:
        if 0.0 <= longitude < 9.0:
            zone_number = 31
        elif 9.0 <= longitude < 21.0:
            zone_number = 33
        elif 21.0 <= longitude < 33.0:
            zone_number = 35
        elif 33.0 <= longitude < 42.0:
            zone_number = 37

    # 326xx codes are the northern hemisphere, 327xx the southern
    hemisphere_prefix = "327" if latitude < 0 else "326"
    epsg_code = f"{hemisphere_prefix}{zone_number:02d}"

    # Sanity check: the zone suffix must be 01..60
    if not re.match(r"32[6-7](0[1-9]|[1-5][0-9]|60)", epsg_code):
        print(f"latlng: {latlng}, epsg_code: {epsg_code}")
        raise ValueError(f"out of bound latlng resulted in incorrect EPSG code for the point")

    return epsg_code
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
if __name__ == '__main__':

    # Self-test: known UTM zone assignments, covering both hemispheres and
    # the Norway/Svalbard exception zones.
    assert get_utm_zone_from_latlng([-1,-174.34]) == "32701"
    assert get_utm_zone_from_latlng([48,-4]) == "32630"
    assert get_utm_zone_from_latlng([78,13]) == "32633"
    assert get_utm_zone_from_latlng([-34,19.7]) == "32734"
    assert get_utm_zone_from_latlng([-36,175.7]) == "32760"

    # Build a 100 km grid and round-trip random coordinates through
    # latlon2rowcol / rowcol2latlon.
    dist = 100
    grid = Grid(dist)

    np.random.seed(0)
    test_lons = np.random.uniform(-20,20,size=(1000)) % 180 # Checks edge-case of crossing 180th meridian
    test_lats = np.random.uniform(-20,68,size=(1000))

    test_rows,test_cols = grid.latlon2rowcol(test_lats,test_lons)
    test_lats2,test_lons2 = grid.rowcol2latlon(test_rows,test_cols)

    print(test_lons[:10])
    print(test_lats[:10])
    print(test_rows[:10])
    print(test_cols[:10])

    # Make line segments from the points to their corresponding grid points
    # (each line connects a query coordinate to its cell's bottom-left corner)
    lines = []
    for i in range(len(test_lats)):
        lines.append([(test_lons[i],test_lats[i]),(test_lons2[i],test_lats2[i])])

    lines = gpd.GeoDataFrame(geometry=gpd.GeoSeries([LineString(line) for line in lines]))

    # Write both layers to GeoJSON for visual inspection in a GIS tool
    lines.to_file(f'testlines_{dist}km.geojson',driver='GeoJSON')
    grid.points.to_file(f'testgrid_{dist}km.geojson',driver='GeoJSON')
|
MajorTOM/metadata_helpers.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pyarrow.parquet as pq
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import geopandas as gpd
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import urllib.request
|
| 6 |
+
import fsspec
|
| 7 |
+
from fsspec.parquet import open_parquet_file
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from rasterio.io import MemoryFile
|
| 11 |
+
from tqdm.notebook import tqdm
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
from .sample_helpers import *
|
| 15 |
+
|
| 16 |
+
def metadata_from_url(access_url, local_url):
    """Download a Major-TOM metadata parquet and load it as a GeoDataFrame.

    Args:
        access_url (str): remote URL of the metadata parquet file
        local_url (str or Path): local path the file is downloaded to

    Returns:
        geopandas.GeoDataFrame: metadata with point geometry built from
        the centre_lon/centre_lat columns
    """
    # fetch the parquet file to disk, then load it into pandas
    downloaded_path, _response = urllib.request.urlretrieve(access_url, local_url)
    frame = pq.read_table(downloaded_path).to_pandas()

    # parse timestamps into proper datetime objects
    frame['timestamp'] = pd.to_datetime(frame.timestamp)

    # attach point geometry at each sample centre
    # NOTE(review): crs is taken from the first row — assumes one CRS for the whole table
    centre_points = gpd.points_from_xy(frame.centre_lon, frame.centre_lat)
    return gpd.GeoDataFrame(frame, geometry=centre_points, crs=frame.crs.iloc[0])
|
| 24 |
+
|
| 25 |
+
def filter_metadata(df,
                    region=None,
                    daterange=None,
                    cloud_cover=(0,100),
                    nodata=(0, 1.0)
                    ):
    """Filters the Major-TOM dataframe based on several parameters

    Args:
        df (geopandas dataframe): Parent dataframe
        region (shapely geometry object) : Region of interest
        daterange (tuple) : Inclusive range of dates (example format: '2020-01-01')
        cloud_cover (tuple) : Inclusive percentage range (0-100) of cloud cover
        nodata (tuple) : Inclusive fraction (0.0-1.0) of no data allowed in a sample

    Returns:
        df: a filtered dataframe
    """
    # temporal filtering (inclusive bounds)
    if daterange is not None:
        assert isinstance(daterange, (list, tuple)) and len(daterange) == 2, \
            "daterange must be a (start, end) pair"
        df = df[df.timestamp >= daterange[0]]
        df = df[df.timestamp <= daterange[1]]

    # spatial filtering via the spatial index of the GeoDataFrame
    if region is not None:
        idxs = df.sindex.query(region)
        df = df.take(idxs)

    # cloud-cover filtering (inclusive percentage bounds)
    if cloud_cover is not None:
        df = df[df.cloud_cover >= cloud_cover[0]]
        df = df[df.cloud_cover <= cloud_cover[1]]

    # nodata filtering (inclusive fraction bounds)
    # (this section was previously mislabelled as "spatial filtering")
    if nodata is not None:
        df = df[df.nodata >= nodata[0]]
        df = df[df.nodata <= nodata[1]]

    return df
|
| 64 |
+
|
| 65 |
+
def read_row(row, columns=["thumbnail"]):
    """Reads a row from a Major-TOM dataframe

    Args:
        row (row from geopandas dataframe): The row of metadata
        columns (list): columns to be read from the file

    Returns:
        data (dict): dictionary with returned data from requested columns
        (or a PIL Image directly when only the thumbnail is requested)
    """
    # Fetch only the requested columns of the row group holding this sample;
    # the large footer_sample_size avoids an extra round-trip for the footer.
    with open_parquet_file(row.parquet_url, columns=columns, footer_sample_size=2000000) as f:
        with pq.ParquetFile(f) as pf:
            row_group = pf.read_row_group(row.parquet_row, columns=columns)

    # Shortcut: a thumbnail-only request returns the PIL image directly
    if columns == ["thumbnail"]:
        return Image.open(BytesIO(row_group['thumbnail'][0].as_py()))

    # Otherwise decode each column: thumbnails as PIL images, bands as arrays
    row_output = {}
    for col in columns:
        payload = row_group[col][0].as_py()
        if col == 'thumbnail':
            row_output[col] = Image.open(BytesIO(payload))
        else:
            row_output[col] = read_tif_bytes(payload)
    return row_output
|
| 94 |
+
|
| 95 |
+
def filter_download(df, local_dir, source_name, by_row = False, verbose = False, tif_columns=None):
    """Downloads and unpacks the data of Major-TOM based on a metadata dataframe

    Args:
        df (geopandas dataframe): Metadata dataframe
        local_dir (str or Path) : Path to the where the data is to be stored locally
        source_name (str) : Name alias of the resulting dataset
        by_row (bool): If True, it will access individual rows of parquet via http - otherwise entire parquets are downloaded temporarily
        verbose (bool) : option for potential internal state printing
        tif_columns (list of str) : Optionally specified columns to be downloaded as .tifs, e.g. ['B04', 'B03', 'B02']

    Returns:
        None
    """
    if isinstance(local_dir, str):
        local_dir = Path(local_dir)

    temp_file = local_dir / 'temp.parquet'

    # identify all parquets that need to be downloaded (group them)
    urls = df.parquet_url.unique()
    if verbose:
        print('Starting download of {} parquet files.'.format(len(urls)))

    for url in tqdm(urls, desc='Downloading and unpacking...', disable=not verbose):
        # identify all relevant rows
        rows = df[df.parquet_url == url].parquet_row.unique()

        remote_file = None
        if not by_row: # (downloads entire parquet)
            # download a temporary file
            temp_path, http_resp = urllib.request.urlretrieve(url, temp_file)
        else:
            # open the remote parquet lazily; only required bytes are fetched
            remote_file = fsspec.open(url)
            temp_path = remote_file.open()

        # populate the bands
        with pq.ParquetFile(temp_path) as pf:
            for row_idx in rows:
                table = pf.read_row_group(row_idx)

                product_id = table['product_id'][0].as_py()
                grid_cell = table['grid_cell'][0].as_py()
                row = grid_cell.split('_')[0]

                dest = local_dir / Path("{}/{}/{}/{}".format(source_name, row, grid_cell, product_id))
                dest.mkdir(exist_ok=True, parents=True)

                # default to every band column (B*) plus the cloud mask
                columns = [col for col in table.column_names if col[0] == 'B'] + ['cloud_mask'] if tif_columns is None else tif_columns

                # tifs
                # BUGFIX: the band file handle used to be named `f`, shadowing
                # the fsspec handle, so the `by_row` cleanup closed the wrong
                # (already-closed) object and leaked the remote handle.
                for col in columns:
                    with open(dest / "{}.tif".format(col), "wb") as out_f:
                        # Write bytes to file
                        out_f.write(table[col][0].as_py())

                # thumbnail (png)
                col = 'thumbnail'
                with open(dest / "{}.png".format(col), "wb") as out_f:
                    # Write bytes to file
                    out_f.write(table[col][0].as_py())

        if not by_row:
            # remove downloaded file
            os.remove(temp_path)
        else:
            # close the remote file object and its fsspec wrapper
            temp_path.close()
            remote_file.close()
|
MajorTOM/sample_helpers.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rasterio.io import MemoryFile
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
|
| 7 |
+
def plot(sample, bands = ['B04', 'B03', 'B02'], scaling=2e3):
    """Display a sample's bands as an RGB image via matplotlib.

    Args:
        sample (dict): mapping from band name to GeoTIFF-encoded bytes
        bands (list): band names used as the R, G, B channels
        scaling (float): divisor used to normalise raw band values for display

    Returns:
        None (draws onto the current matplotlib axes)
    """
    img = [read_tif_bytes(sample[b]) for b in bands]
    # BUGFIX: the `scaling` parameter was previously ignored (hardcoded 2e3);
    # the default value keeps the original behaviour.
    plt.imshow(np.stack(img, -1) / scaling)
|
| 12 |
+
|
| 13 |
+
def read_tif_bytes(tif_bytes):
    """Decode in-memory GeoTIFF bytes into a numpy array (singleton dims squeezed)."""
    with MemoryFile(tif_bytes) as memory_file:
        with memory_file.open(driver='GTiff') as dataset:
            return dataset.read().squeeze()
|
| 17 |
+
|
| 18 |
+
def read_png_bytes(png_bytes):
    """Decode in-memory PNG bytes into a PIL Image (lazily, via a BytesIO stream)."""
    return Image.open(BytesIO(png_bytes))
|
README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: EarthExplorer
|
| 3 |
+
emoji: 🌍
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.9.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# EarthExplorer
|
| 14 |
+
|
| 15 |
+
A tool for searching satellite images of Earth using natural language descriptions, images, geolocations, or a simple click on the map.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- Text-based satellite image search
|
| 20 |
+
- Image-based similarity search
|
| 21 |
+
- Location-based search
|
| 22 |
+
- Interactive map interface
|
| 23 |
+
|
| 24 |
+
## Clone
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
git clone https://huggingface.co/spaces/ML4Sustain/EarthExplorer
|
| 28 |
+
```
|
Tutorial.md
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tutorial: EarthEmbeddingExplorer
|
| 2 |
+
|
| 3 |
+
## Background
|
| 4 |
+
|
| 5 |
+
### What is this project about?
|
| 6 |
+
EarthEmbeddingExplorer is a tool that lets you search satellite imagery using **natural language**, **images**, or **geographic locations**. In simple terms, you can enter prompts like “a satellite image of a glacier” or “a satellite image of a city with a coastline”, and the system will find places on Earth that match your description and visualize them on a map.
|
| 7 |
+
|
| 8 |
+
EarthEmbeddingExplorer enables users to explore the Earth in multiple ways without leaving their desk, and it can be useful for many geoscience tasks. For example, geologists can quickly locate glacier regions; biologists can rapidly map forest cover; and architects can study urban patterns across different parts of the world.
|
| 9 |
+
|
| 10 |
+
## How does it work? (Core ideas)
|
| 11 |
+
|
| 12 |
+
### Satellite imagery dataset
|
| 13 |
+
We use **MajorTOM** (Major TOM: Expandable Datasets for Earth Observation) released by the European Space Agency (ESA) [1]. Specifically, we use the [Core-S2L2A](https://modelscope.cn/datasets/Major-TOM/Core-S2L2A) subset.
|
| 14 |
+
|
| 15 |
+
| Dataset | Imagery source | Number of samples | Sensor type |
|
| 16 |
+
| :--- | :--- | :--- | :--- |
|
| 17 |
+
| MajorTOM-Core-S2L2A | Sentinel-2 Level 2A | 2,245,886 | Multispectral |
|
| 18 |
+
|
| 19 |
+
MajorTOM Core-S2L2A provides global Sentinel-2 multispectral imagery (10 m resolution). We convert the RGB bands into embeddings using CLIP-like models (e.g., SigLIP), which saves substantial time because we do not need to preprocess raw imagery ourselves. In addition, embeddings (vectors) are much smaller than raw imagery, and they are significantly faster to search.
|
| 20 |
+
|
| 21 |
+
To keep EarthEmbeddingExplorer responsive, we build a smaller but representative version of the dataset.
|
| 22 |
+
|
| 23 |
+
The original tiles in Core-S2L2A are large (1068×1068 pixels), but most AI models expect smaller inputs (384×384 or 224×224 pixels).
|
| 24 |
+
1. **Cropping**: for simplicity, from each original tile we only take the **center** 384×384 (or 224×224) crop to generate an embedding.
|
| 25 |
+
2. **Uniform sampling**: using MajorTOM’s grid coding system, we sample **1%** of the data (about 22,000 images). This preserves global coverage while keeping search fast.
|
| 26 |
+
|
| 27 |
+
<div align="center">
|
| 28 |
+
<img src="images/samples.png" width="50%" />
|
| 29 |
+
<br>
|
| 30 |
+
<em>Figure 1: Geographic distribution of our sampled satellite image embeddings.</em>
|
| 31 |
+
</div>
|
| 32 |
+
|
| 33 |
+
### Retrieval models
|
| 34 |
+
The core of image retrieval is a family of models known as **CLIP (Contrastive Language-Image Pre-training)** [2]. We use its improved variants such as **SigLIP (Sigmoid Language-Image Pre-training)** [3], **FarSLIP (Fine-grained Aligned Remote Sensing Language Image Pretraining)** [4], and **SatCLIP (Satellite Location-Image Pretraining)** [5].
|
| 35 |
+
|
| 36 |
+
An analogy: when teaching a child, you show a picture of a glacier and say “glacier”. After seeing many examples, the child learns to associate the visual concept with the word.
|
| 37 |
+
|
| 38 |
+
CLIP-like models learn in a similar way, but at much larger scale.
|
| 39 |
+
- An image encoder turns an **image** into an **embedding** (a vector of numbers).
|
| 40 |
+
- A text (or location) encoder turns **text** (or **latitude/longitude**) into an embedding.
|
| 41 |
+
|
| 42 |
+
The key property is: if an image matches a text description (or location), their embeddings will be close; otherwise they will be far apart.
|
| 43 |
+
|
| 44 |
+
<div align="center">
|
| 45 |
+
<img src="images/CLIP.png" width="40%" />
|
| 46 |
+
<br>
|
| 47 |
+
<em>Figure 2: How CLIP-like models connect images and text.</em>
|
| 48 |
+
</div>
|
| 49 |
+
|
| 50 |
+
The three models we use differ in their encoders and training data:
|
| 51 |
+
|
| 52 |
+
| Model | Encoder type | Training data |
|
| 53 |
+
| :--- | :--- | :--- |
|
| 54 |
+
| SigLIP | image encoder + text encoder | natural image–text pairs from the web |
|
| 55 |
+
| FarSLIP | image encoder + text encoder | satellite image–text pairs |
|
| 56 |
+
| SatCLIP | image encoder + location encoder | satellite image–location pairs |
|
| 57 |
+
|
| 58 |
+
<div align="center">
|
| 59 |
+
<img src="images/embedding.png" width="30%" />
|
| 60 |
+
<br>
|
| 61 |
+
<em>Figure 3: Converting satellite images into embedding vectors.</em>
|
| 62 |
+
</div>
|
| 63 |
+
|
| 64 |
+
In EarthEmbeddingExplorer:
|
| 65 |
+
1. We precompute embeddings for ~22k globally distributed satellite images using SigLIP, FarSLIP, and SatCLIP.
|
| 66 |
+
2. When you provide a query (text like “a satellite image of glacier”, an image, or a location such as (-89, 120)), we encode the query into an embedding using the corresponding encoder.
|
| 67 |
+
3. We compare the query embedding with all image embeddings, visualize similarities on a map, and show the top-5 most similar images.
|
| 68 |
+
|
| 69 |
+
## System architecture
|
| 70 |
+
|
| 71 |
+
<div align="center">
|
| 72 |
+
<img src="images/framework_en.png" width="70%" />
|
| 73 |
+
<br>
|
| 74 |
+
<em>Figure 4: EarthEmbeddingExplorer system architecture on ModelScope.</em>
|
| 75 |
+
</div>
|
| 76 |
+
|
| 77 |
+
We deploy EarthEmbeddingExplorer on ModelScope: the models, embedding datasets, and raw imagery datasets are all hosted on the platform. The app runs on [xGPU](https://www.modelscope.cn/brand/view/xGPU), allowing flexible access to GPU resources and faster retrieval.
|
| 78 |
+
|
| 79 |
+
### How is the raw imagery stored?
|
| 80 |
+
|
| 81 |
+
MajorTOM Core-S2L2A is large (about 23 TB), so we do not download the full dataset. Instead, the raw imagery is stored as **Parquet shards**:
|
| 82 |
+
|
| 83 |
+
- **Shard storage**: the dataset is split into many remote Parquet files (shards), each containing a subset of the samples.
|
| 84 |
+
- **Columnar storage**: different fields/bands (e.g., B04/B03/B02, thumbnail) are stored as separate columns; we only read what we need.
|
| 85 |
+
- **Metadata index**: we maintain a small index table mapping `product_id → (parquet_url, parquet_row)` so the system can locate “which shard and which position” contains a given image.
|
| 86 |
+
|
| 87 |
+
With this design, when a user only needs a small number of images from the retrieval results, the system can use **HTTP Range requests** to download only a small byte range from a Parquet file (the target row/row group and the requested columns), rather than downloading the full 23 TB dataset—enabling near real-time retrieval of raw images.
|
| 88 |
+
|
| 89 |
+
### What happens when you use the app?
|
| 90 |
+
|
| 91 |
+
1. **Enter a query**: you can enter text, upload an image, or input a latitude/longitude. You can also click on the map to use the clicked location as a query.
|
| 92 |
+
2. **Compute similarity**: the app encodes your query into an embedding vector and computes similarity scores against all satellite image embeddings.
|
| 93 |
+
3. **Show results**: the system filters out low-similarity results and shows the highest-scoring locations (and scores) on the map. You can adjust the threshold using a slider.
|
| 94 |
+
4. **Download raw images on demand**: for the top-5 most similar images, the system looks up their `parquet_url` and row position via the metadata index, then uses HTTP Range to fetch only the required data (RGB bands) and displays the images quickly in the UI.
|
| 95 |
+
|
| 96 |
+
## Examples
|
| 97 |
+
<div align="center">
|
| 98 |
+
<img src="images/Text_Search.jpg" width="99%" />
|
| 99 |
+
<br>
|
| 100 |
+
<em>Figure 5: Search by text.</em>
|
| 101 |
+
</div>
|
| 102 |
+
<br>
|
| 103 |
+
|
| 104 |
+
<div align="center">
|
| 105 |
+
<img src="images/Image_Search_Amazon.jpg" width="99%" />
|
| 106 |
+
<br>
|
| 107 |
+
<em>Figure 6: Search by image.</em>
|
| 108 |
+
</div>
|
| 109 |
+
<br>
|
| 110 |
+
|
| 111 |
+
<div align="center">
|
| 112 |
+
<img src="images/Location_Search_Amazon.jpg" width="99%" />
|
| 113 |
+
<br>
|
| 114 |
+
<em>Figure 7: Search by location.</em>
|
| 115 |
+
</div>
|
| 116 |
+
|
| 117 |
+
## Limitations
|
| 118 |
+
|
| 119 |
+
While EarthEmbeddingExplorer has strong potential, it also has limitations. SigLIP is primarily trained on “natural images” from the internet (people, pets, cars, everyday objects) rather than satellite imagery. This domain gap can make it harder for the model to understand certain scientific terms or distinctive geographic patterns that are uncommon in typical web photos.
|
| 120 |
+
|
| 121 |
+
FarSLIP may perform poorly on non-remote-sensing concepts described in text, such as the query “an image of a face”.
|
| 122 |
+
|
| 123 |
+
## Acknowledgements
|
| 124 |
+
|
| 125 |
+
We thank the following open-source projects and datasets that made EarthEmbeddingExplorer possible:
|
| 126 |
+
|
| 127 |
+
**Models:**
|
| 128 |
+
- [SigLIP](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384) - Vision Transformer model for image-text alignment
|
| 129 |
+
- [FarSLIP](https://github.com/NJU-LHRS/FarSLIP) - Fine-grained satellite image-text pretraining model
|
| 130 |
+
- [SatCLIP](https://github.com/microsoft/satclip) - Satellite location-image pretraining model
|
| 131 |
+
|
| 132 |
+
**Datasets:**
|
| 133 |
+
- [MajorTOM](https://github.com/ESA-PhiLab/MajorTOM) - Expandable datasets for Earth observation by ESA
|
| 134 |
+
|
| 135 |
+
We are grateful to the research communities and organizations that developed and shared these resources.
|
| 136 |
+
|
| 137 |
+
## Contributors
|
| 138 |
+
- [Yijie Zheng](https://voyagerxvoyagerx.github.io/)
|
| 139 |
+
- [Weijie Wu](https://github.com/go-bananas-wwj)
|
| 140 |
+
- [Bingyue Wu](https://brynn-wu.github.io/Brynn-Wu)
|
| 141 |
+
|
| 142 |
+
## Roadmap
|
| 143 |
+
- [ ] Increase the geographical coverage (sample rate) to 1.2% of the Earth's land surface. (coming by 16 Jan!)
|
| 144 |
+
- [ ] Support DINOv2 Embedding model and embedding datasets.
|
| 145 |
+
- [ ] Support FAISS for faster similarity search.
|
| 146 |
+
- [ ] What features do you want? Leave an issue [here](https://huggingface.co/spaces/ML4Sustain/EarthExplorer/discussions)!
|
| 147 |
+
|
| 148 |
+
We warmly welcome new contributors!
|
| 149 |
+
|
| 150 |
+
## References
|
| 151 |
+
|
| 152 |
+
[1] Francis, A., & Czerkawski, M. (2024). Major TOM: Expandable Datasets for Earth Observation. IGARSS 2024.
|
| 153 |
+
|
| 154 |
+
[2] Radford, A., et al. (2021). Learning Transferable Visual Models From Natural Language Supervision. ICML 2021.
|
| 155 |
+
|
| 156 |
+
[3] Zhai, X., et al. (2023). Sigmoid Loss for Language-Image Pre-Training. ICCV 2023.
|
| 157 |
+
|
| 158 |
+
[4] Li, Z., et al. (2025). FarSLIP: Discovering Effective CLIP Adaptation for Fine-Grained Remote Sensing Understanding. arXiv 2025.
|
| 159 |
+
|
| 160 |
+
[5] Klemmer, K. et al. (2025). SatCLIP: Global, General-Purpose Location Embeddings with Satellite Imagery. AAAI 2025.
|
| 161 |
+
|
| 162 |
+
[6] Czerkawski, M., Kluczek, M., & Bojanowski, J. S. (2024). Global and Dense Embeddings of Earth: Major TOM Floating in the Latent Space. arXiv preprint arXiv:2412.05600.
|
Tutorial_zh.md
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 教程:EarthExplorer 地球探索者
|
| 2 |
+
|
| 3 |
+
## 背景介绍
|
| 4 |
+
|
| 5 |
+
### 这个项目是做什么的?
|
| 6 |
+
EarthExplorer 是一个可以通过**自然语言**,**图像**,或**地理位置**搜索卫星图像的工具。简单来说,你可以输入像“a satellite image of glacier”或“a satellite image of city with a coastline”这样的描述,系统就会在地球上找到符合你描述的地点,并将它们在地图上展示出来。EarthExplorer 可以让用户足不出户地,以多种方式探索地球上的每一个角落,在地理科学领域有广泛的应用价值。例如,地质学家们可以用这个工具来快速寻找冰川的分布;生物学家可以快速进行森林覆盖的制图,建筑学家们可以研究世界不同地区的城市发展结构。
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
## 它是如何工作的?(核心原理)
|
| 10 |
+
|
| 11 |
+
### 卫星影像数据集
|
| 12 |
+
我们使用了欧空局(ESA)发布的 **MajorTOM** (Major TOM: Expandable Datasets for Earth Observation) 数据集 [1]。具体来说,我们使用的是 [Core-S2L2A](https://modelscope.cn/datasets/Major-TOM/Core-S2L2A) 这个子集。
|
| 13 |
+
|
| 14 |
+
| 数据集 | 影像来源 | 嵌入数量 | 传感器类型 |
|
| 15 |
+
| :--- | :--- | :--- | :--- |
|
| 16 |
+
| MajorTOM-Core-S2L2A | Sentinel-2 Level 2A | 2,245,886 | 多光谱 |
|
| 17 |
+
|
| 18 |
+
MajorTOM Core-S2L2A 包含了全球覆盖的 Sentinel-2 多光谱影像(10m 分辨率);我们则利用 SigLIP 模型将这个数据集的 RGB 波段处理成了嵌入。这为我们节省了大量时间,因为我们不需要自己去处理这些原始图像!此外,图像嵌入(一串数字)的存储空间远小于原始图像,计算效率也更高!
|
| 19 |
+
|
| 20 |
+
为了让 EarthExplorer 响应迅速,我们创建了一个更小、更有代表性的数据集版本。
|
| 21 |
+
|
| 22 |
+
Core-S2L2A 中的原始卫星图像尺寸很大(1068x1068 像素),但 AI 模型需要较小的输入尺寸(384x384 或 224x224 像素)。
|
| 23 |
+
1. **裁剪**:为了简化,对每个原尺寸图像,我们仅选取大图正中心的 384x384 或 224x224 像素区域所生成的嵌入。
|
| 24 |
+
2. **随机采样**:我们根据 MajorTOM 的网格编码系统,均匀采样了 **1%** 的数据(约 22000 张图像)。这样既能保证全球覆盖,又可以在很短的时间内检索出结果。
|
| 25 |
+
|
| 26 |
+
<div align="center">
|
| 27 |
+
<img src="images/samples.png" width="50%" />
|
| 28 |
+
<br>
|
| 29 |
+
<em>图 1:我们采样的卫星图像嵌入的地理分布。</em>
|
| 30 |
+
</div>
|
| 31 |
+
|
| 32 |
+
### 检索模型
|
| 33 |
+
图像检索核心技术是一种叫做 **CLIP (Contrastive Language-Image Pre-training)** [2] 的人工智能模型,我们使用的是它的改进版本 **SigLIP (Sigmoid Language-Image Pre-training)** [3], **FarSLIP (Fine-grained Aligned Remote Sensing Language Image Pretraining)** [4], 和 **SatCLIP (Satellite Location-Image Pretraining)** [5]。
|
| 34 |
+
|
| 35 |
+
想象一下教小孩子识物。你给他们看一张冰川的照片,并说“冰川”。在看了很多冰川的照片并听到这个词后,孩子就学会了将冰川的样子和“冰川”这个词联系起来。
|
| 36 |
+
|
| 37 |
+
SigLIP/FarSLIP/SatCLIP 的工作原理类似,但规模要大得多。它通过学习数百万个图片-文字对或图片-地理位置对,理解了图像和文本/地理位置之间的关系。
|
| 38 |
+
- 它使用图片编码器将**图像**转换成一种数学表示(一串数字),我们称之为**嵌入 (Embedding)**。
|
| 39 |
+
- 它也使用文本/地理位置编码器将**文本**或**地理位置(经纬度坐标)**转换成类似的数学表示(嵌入)。
|
| 40 |
+
|
| 41 |
+
神奇之处在于,如果一张图片和一段文字描述或经纬度是匹配的,它们转换后的数学表示就会非常接近。如果不匹配,它们就会相距很远。
|
| 42 |
+
|
| 43 |
+
<div align="center">
|
| 44 |
+
<img src="images/CLIP.png" width="40%" />
|
| 45 |
+
<br>
|
| 46 |
+
<em>图 2:CLIP 类模型如何连接图像和文本/位置。</em>
|
| 47 |
+
</div>
|
| 48 |
+
|
| 49 |
+
我们用到的三个模型的模型结构和训练数据是:
|
| 50 |
+
| 模型 | 编码器类型 | 训练数据来源 |
|
| 51 |
+
| :--- | :--- | :--- |
|
| 52 |
+
| SigLIP | 图像编码器+文本编码器 | 互联网上的自然图像-文本对 |
|
| 53 |
+
| FarSLIP | 图像编码器+文本编码器 | 卫星图像-文本对 |
|
| 54 |
+
| SatCLIP | 图像编码器+位置编码器 | 卫星图像-地理位置对 |
|
| 55 |
+
|
| 56 |
+
<div align="center">
|
| 57 |
+
<img src="images/embedding.png" width="30%" />
|
| 58 |
+
<br>
|
| 59 |
+
<em>图 3:将卫星图像转换成嵌入向量。</em>
|
| 60 |
+
</div>
|
| 61 |
+
|
| 62 |
+
在 EarthExplorer 中:
|
| 63 |
+
1. 我们已将全球均匀采样的两万多张卫星图像,分别使用 SigLIP、FarSLIP 和 SatCLIP 的图像编码器,转换成这种数学“嵌入”。
|
| 64 |
+
2. 当你输入一个查询,这个查询可以是文本(例如“a satellite image of glacier”),图像(一张冰川的图像),或地理位置(-89, 120),我们将你的查询也使用对应的编码器转换成嵌入。
|
| 65 |
+
3. 然后,我们将你的查询嵌入与所有卫星图像的嵌入进行比较,将相似度在地图上可视化,并展示最相似的5张图像。
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
## 系统架构
|
| 69 |
+
|
| 70 |
+
<div align="center">
|
| 71 |
+
<img src="images/framework_zh.png" width="70%" />
|
| 72 |
+
<br>
|
| 73 |
+
<em>图 4:基于魔搭创空间的 EarthExplorer 系统架构。</em>
|
| 74 |
+
</div>
|
| 75 |
+
|
| 76 |
+
我们基于魔搭平台进行部署:模型、嵌入数据集、以及原始影像数据集都托管在魔搭上。我们将 APP 部署在 [xGPU](https://www.modelscope.cn/brand/view/xGPU) 环境下,使得用户可以获得灵活调度的免费 GPU 资源,加快检索速度。
|
| 77 |
+
|
| 78 |
+
### 原始影像是如何存的?
|
| 79 |
+
|
| 80 |
+
MajorTOM Core-S2L2A 的原始影像体量很大(约 23TB),以 **Parquet 分片(shard)** 的方式存储:
|
| 81 |
+
|
| 82 |
+
- **分片存储**:数据被拆成很多个远端 Parquet 文件(分片),每个分片只包含一部分影像样本。
|
| 83 |
+
- **列式存储**:每个影像的不同字段/波段(例如 B04/B03/B02、thumbnail)存成不同的列,需要什么就读什么。
|
| 84 |
+
- **元数据索引**:我们额外维护一份很小的索引表,把 `product_id → (parquet_url, parquet_row)` 对应起来,告诉系统“这个 id 的影像在哪个分片、在分片里的哪个位置”。
|
| 85 |
+
|
| 86 |
+
这样,当用户只需要查看检索结果的少量影像时,系统可以通过 **HTTP Range 请求**只下载 Parquet 文件中“那一小段字节”(对应目标行/行组 + 指定列的数据),而不是下载整个 23TB 数据集,从而实现秒级取图。
|
| 87 |
+
|
| 88 |
+
### 当你使用这个 App 时
|
| 89 |
+
|
| 90 |
+
1. **输入查询**:你可以输入文字、上传图片、输入经纬度;也可以在地图上点击一个位置,直接把该点经纬度作为查询。
|
| 91 |
+
2. **计算相似度**:App 将你的查询编码成一个“嵌入向量”,并与嵌入数据集中每一张卫星图像的嵌入计算相似度分数。
|
| 92 |
+
3. **展示检索结果**:系统过滤掉相似度较低的结果,把相似度最高的地点(以及分数)显示在地图上;你可以用滑动条调整阈值。
|
| 93 |
+
4. **按需下载原图**:对最相似的前 5 张影像,系统用 `product_id` 查询元数据索引定位到远端 `parquet_url` 和行位置,然后通过 HTTP Range 只拉取对应缩略图数据,在前端快速展示原始影像。
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
## 示例
|
| 97 |
+
<div align="center">
|
| 98 |
+
<img src="images/Text_Search.jpg" width="99%" />
|
| 99 |
+
<br>
|
| 100 |
+
<em>图 5:以文搜图示例。</em>
|
| 101 |
+
</div>
|
| 102 |
+
<br>
|
| 103 |
+
|
| 104 |
+
<div align="center">
|
| 105 |
+
<img src="images/Image_Search_Amazon.jpg" width="99%" />
|
| 106 |
+
<br>
|
| 107 |
+
<em>图 6:以图搜图示例。</em>
|
| 108 |
+
</div>
|
| 109 |
+
|
| 110 |
+
<br>
|
| 111 |
+
<div align="center">
|
| 112 |
+
<img src="images/Location_Search_Amazon.jpg" width="99%" />
|
| 113 |
+
<br>
|
| 114 |
+
<em>图 7:以点搜图示例。</em>
|
| 115 |
+
</div>
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
## 局限性
|
| 119 |
+
|
| 120 |
+
虽然 EarthExplorer 有很大的应用潜力,但它也有一些局限性。SigLIP 模型主要是通过互联网上的“自然图像”(如人物、猫狗、汽车、日常用品的照片)训练的,而不是专门针对卫星图像训练的。这种训练数据和应用时数据的偏差,使得模型可能难以理解特定的科学术语或在普通网络照片中不常见的独特地理特征。而 FarSLIP 模型对非典型遥感地物的语言描述,例如 'an image of face' 的检索效果不佳。
|
| 121 |
+
|
| 122 |
+
未来的工作可以使用其他专门针对地球观测数据训练的 AI 模型来提高检索的准确性。
|
| 123 |
+
|
| 124 |
+
## 未来工作
|
| 125 |
+
- 结合时间序列影像,实现全球变化监测
|
| 126 |
+
- 添加不同地球基础模型,对比不同模型的检索性能
|
| 127 |
+
|
| 128 |
+
## 致谢
|
| 129 |
+
我们感谢以下开源项目和数据集,它们使 EarthExplorer 得以实现:
|
| 130 |
+
|
| 131 |
+
**模型:**
|
| 132 |
+
- [SigLIP](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384) - 用于图像-文本对齐的视觉Transformer模型
|
| 133 |
+
- [FarSLIP](https://github.com/NJU-LHRS/FarSLIP) - 细粒度卫星图像-文本预训练模型
|
| 134 |
+
- [SatCLIP](https://github.com/microsoft/satclip) - 卫星位置-图像预训练模型
|
| 135 |
+
|
| 136 |
+
**数据集:**
|
| 137 |
+
- [MajorTOM](https://github.com/ESA-PhiLab/MajorTOM) - 欧洲航天局(ESA)的可扩展地球观测数据集
|
| 138 |
+
|
| 139 |
+
我们感谢开发和分享这些资源的研究社区和组织。
|
| 140 |
+
|
| 141 |
+
## 贡献者
|
| 142 |
+
- [郑祎杰](https://voyagerxvoyagerx.github.io/)
|
| 143 |
+
- [伍炜杰](https://github.com/go-bananas-wwj)
|
| 144 |
+
- [吴冰玥](https://brynn-wu.github.io/Brynn-Wu)
|
| 145 |
+
|
| 146 |
+
## 引用
|
| 147 |
+
[1] Francis, A., & Czerkawski, M. (2024). Major TOM: Expandable Datasets for Earth Observation. IGARSS 2024.
|
| 148 |
+
|
| 149 |
+
[2] Radford, A., et al. (2021). Learning Transferable Visual Models From Natural Language Supervision. ICML 2021.
|
| 150 |
+
|
| 151 |
+
[3] Zhai, X., et al. (2023). Sigmoid Loss for Language-Image Pre-Training. ICCV 2023.
|
| 152 |
+
|
| 153 |
+
[4] Li, Z., et al. (2025). FarSLIP: Discovering Effective CLIP Adaptation for Fine-Grained Remote Sensing Understanding. arXiv 2025.
|
| 154 |
+
|
| 155 |
+
[5] Klemmer, K. et al. (2025). SatCLIP: Global, General-Purpose Location Embeddings with Satellite Imagery. AAAI 2025.
|
| 156 |
+
|
| 157 |
+
[6] Czerkawski, M., Kluczek, M., & Bojanowski, J. S. (2024). Global and Dense Embeddings of Earth: Major TOM Floating in the Latent Space. arXiv preprint arXiv:2412.05600.
|
app.py
ADDED
|
@@ -0,0 +1,792 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import time
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
import zipfile
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 10 |
+
|
| 11 |
+
# Import custom modules
|
| 12 |
+
from models.siglip_model import SigLIPModel
|
| 13 |
+
from models.satclip_model import SatCLIPModel
|
| 14 |
+
from models.farslip_model import FarSLIPModel
|
| 15 |
+
from models.load_config import load_and_process_config
|
| 16 |
+
from visualize import format_results_for_gallery, plot_top5_overview, plot_location_distribution, plot_global_map_static, plot_geographic_distribution
|
| 17 |
+
from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image
|
| 18 |
+
from PIL import Image as PILImage
|
| 19 |
+
from PIL import ImageDraw, ImageFont
|
| 20 |
+
|
| 21 |
+
# Configuration: prefer the GPU when one is available so that query
# encoding and similarity search run faster.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")

# Load and process configuration (checkpoint paths, tokenizer paths and
# precomputed-embedding dataset paths for each model).
config = load_and_process_config()

# Initialize Models.
# Each backbone is loaded inside its own try/except so that a failure to
# load one model (missing checkpoint, bad config entry) does not prevent
# the remaining models from being available in the app.
print("Initializing models...")
models = {}

# SigLIP: image + text encoder trained on web image-text pairs.
try:
    if config and 'siglip' in config:
        models['SigLIP'] = SigLIPModel(
            ckpt_path=config['siglip'].get('ckpt_path'),
            tokenizer_path=config['siglip'].get('tokenizer_path'),
            embedding_path=config['siglip'].get('embedding_path'),
            device=device
        )
    else:
        # No config entry: fall back to the model class's built-in defaults.
        models['SigLIP'] = SigLIPModel(device=device)
except Exception as e:
    print(f"Failed to load SigLIP: {e}")

# SatCLIP: location encoder paired with satellite-image embeddings.
try:
    if config and 'satclip' in config:
        models['SatCLIP'] = SatCLIPModel(
            ckpt_path=config['satclip'].get('ckpt_path'),
            embedding_path=config['satclip'].get('embedding_path'),
            device=device
        )
    else:
        models['SatCLIP'] = SatCLIPModel(device=device)
except Exception as e:
    print(f"Failed to load SatCLIP: {e}")

# FarSLIP: remote-sensing-specific image + text encoder.
try:
    if config and 'farslip' in config:
        models['FarSLIP'] = FarSLIPModel(
            ckpt_path=config['farslip'].get('ckpt_path'),
            model_name=config['farslip'].get('model_name'),
            embedding_path=config['farslip'].get('embedding_path'),
            device=device
        )
    else:
        models['FarSLIP'] = FarSLIPModel(device=device)
except Exception as e:
    print(f"Failed to load FarSLIP: {e}")
|
| 72 |
+
|
| 73 |
+
def get_active_model(model_name):
    """Look up a loaded model by display name.

    Returns a ``(model, error)`` pair: ``(model, None)`` when the model was
    successfully initialized at startup, or ``(None, message)`` when it is
    not available.
    """
    try:
        return models[model_name], None
    except KeyError:
        return None, f"Model {model_name} not loaded."
|
| 77 |
+
|
| 78 |
+
def combine_images(img1, img2):
    """Stack two PIL images vertically into a single RGB image.

    Both images are rescaled to the wider of the two widths (keeping
    their aspect ratios) before stacking on a white canvas.  If either
    input is ``None``, the other one is returned unchanged.
    """
    if img1 is None:
        return img2
    if img2 is None:
        return img1

    target_w = max(img1.size[0], img2.size[0])

    def _fit_width(img):
        # Scale to target_w while preserving the aspect ratio.
        w, h = img.size
        return img.resize((target_w, int(h * target_w / w)))

    top = _fit_width(img1)
    bottom = _fit_width(img2)

    canvas = PILImage.new(
        'RGB', (target_w, top.size[1] + bottom.size[1]), (255, 255, 255)
    )
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, top.size[1]))
    return canvas
|
| 97 |
+
|
| 98 |
+
def create_text_image(text, size=(384, 384)):
    """Render a query string onto a plain grey canvas.

    Used as a stand-in thumbnail when the query is text rather than an
    image.  The text is split on commas, one fragment per line, and a
    blue "Text Query" label is drawn below it.

    Args:
        text: Query string to render; comma-separated parts become lines.
        size: (width, height) of the generated image in pixels.

    Returns:
        A new RGB ``PIL.Image`` of the requested size.
    """
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    d = ImageDraw.Draw(img)

    # Prefer a scalable TrueType font; Pillow's bitmap default is tiny.
    # ImageFont.truetype raises OSError when the font file cannot be found
    # or read - catch only that, instead of a bare except that would also
    # hide unrelated bugs (e.g. a typo in the arguments).
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except OSError:
        font = ImageFont.load_default()

    # Naive line wrapping: one comma-separated fragment per line.
    margin = 20
    offset = 100
    for line in text.split(','):
        d.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50

    d.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img
|
| 118 |
+
|
| 119 |
+
def fetch_top_k_images(top_indices, probs, df_embed, query_text=None):
    """
    Fetches top-k images using actual dataset download (ModelScope) via download_and_process_image.

    Args:
        top_indices: iterable of row positions into ``df_embed``, ranked
            by similarity.
        probs: similarity scores indexed by the same row positions.
        df_embed: metadata DataFrame with 'product_id', 'centre_lat' and
            'centre_lon' columns.
        query_text: optional query string forwarded to the Esri fallback
            renderer for labelling.

    Returns:
        List of result dicts (image_384, image_full, score, lat, lon, id),
        sorted by score in descending order.
    """
    results = []

    # Downloads are network-bound, so fetch up to 5 images concurrently.
    with ThreadPoolExecutor(max_workers=5) as executor:
        future_to_idx = {}
        for i, idx in enumerate(top_indices):
            row = df_embed.iloc[idx]
            pid = row['product_id']

            # Use download_and_process_image to get real data
            future = executor.submit(download_and_process_image, pid, df_source=df_embed, verbose=False)
            future_to_idx[future] = idx

        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                img_384, img_full = future.result()

                if img_384 is None:
                    # Fallback to Esri if download fails.
                    # NOTE(review): rank is hard-coded to 0 here, so every
                    # fallback tile is labelled as the top result - confirm
                    # whether the item's true rank should be passed instead.
                    print(f"Download failed for idx {idx}, falling back to Esri...")
                    row = df_embed.iloc[idx]
                    img_384 = get_esri_satellite_image(row['centre_lat'], row['centre_lon'], score=probs[idx], rank=0, query=query_text)
                    img_full = img_384

                row = df_embed.iloc[idx]
                results.append({
                    'image_384': img_384,
                    'image_full': img_full,
                    'score': probs[idx],
                    'lat': row['centre_lat'],
                    'lon': row['centre_lon'],
                    'id': row['product_id']
                })
            except Exception as e:
                # Best-effort: a single failed download should not abort the
                # whole batch; the failed image is simply omitted.
                print(f"Error fetching image for idx {idx}: {e}")

    # Sort results by score descending (since futures complete in random order)
    results.sort(key=lambda x: x['score'], reverse=True)
    return results
|
| 163 |
+
|
| 164 |
+
def get_all_results_metadata(model, filtered_indices, probs):
    """Collect metadata for every retained result, best match first.

    Args:
        model: model wrapper exposing ``df_embed`` (a DataFrame with
            'product_id', 'centre_lat' and 'centre_lon' columns).
        filtered_indices: array of row positions that passed the threshold.
        probs: full array of similarity scores, indexed by row position.

    Returns:
        A list of ``{'id', 'lat', 'lon', 'score'}`` dicts sorted by score
        in descending order; empty when nothing passed the filter.
    """
    if len(filtered_indices) == 0:
        return []

    # Rank the surviving rows from highest to lowest similarity.
    order = np.argsort(probs[filtered_indices])[::-1]
    ranked = filtered_indices[order]

    # Pull matching rows from the metadata frame and attach the scores.
    records = model.df_embed.iloc[ranked].copy()
    records['score'] = probs[ranked]
    records = records.rename(
        columns={'product_id': 'id', 'centre_lat': 'lat', 'centre_lon': 'lon'}
    )
    return records[['id', 'lat', 'lon', 'score']].to_dict('records')
|
| 182 |
+
|
| 183 |
+
def search_text(query, threshold, model_name):
    """Gradio generator: text-to-image retrieval pipeline.

    Encodes ``query`` with the selected model's text encoder, scores it
    against the precomputed image embeddings, and progressively yields
    7-tuples of UI updates: (example-area update, gallery items, status
    text, results figure, download payload, filtered dataframe, geo-map
    update).
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if not query:
        yield None, None, "Please enter a query.", None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Text
        yield None, None, "Encoding text...", None, None, None, None
        t0 = time.time()
        text_features = model.encode_text(query)
        timings['Encoding'] = time.time() - t0

        if text_features is None:
            yield None, None, "Model does not support text encoding or is not initialized.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding text... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        # Slider value is rescaled to a fraction of the dataset to keep.
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to "{query}" ({model_name})')

        # 3. Download Images (only the top 10 raw images are fetched)
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:10]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text=query)
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        # NOTE(review): the status message scales threshold by /100.0 while
        # the search above used /1000.0 - confirm which scale is intended.
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
|
| 247 |
+
|
| 248 |
+
def search_image(image_input, threshold, model_name):
    """Gradio generator: image-to-image retrieval pipeline.

    Same staged pipeline as ``search_text`` but the query is an uploaded
    image encoded with the selected model's image encoder.  Yields the
    same 7-tuple of progressive UI updates at each stage.
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if image_input is None:
        yield None, None, "Please upload an image.", None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Image
        yield None, None, "Encoding image...", None, None, None, None
        t0 = time.time()
        image_features = model.encode_image(image_input)
        timings['Encoding'] = time.time() - t0

        if image_features is None:
            yield None, None, "Model does not support image encoding.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding image... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        # Show geographic distribution (not timed)
        # NOTE(review): unlike search_text there is no `probs is None`
        # guard here - confirm model.search cannot fail for image queries.
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Input Image ({model_name})')

        # 3. Download Images (top 6 only, vs 10 for text search)
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:6]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text="Image Query")
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        # NOTE(review): status uses threshold/100.0 but the search above
        # used threshold/1000.0 - confirm which scale is intended.
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        # Only the first 50 results are rendered as text to keep the UI light.
        results_txt = format_results_to_text(all_results[:50])

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
|
| 308 |
+
|
| 309 |
+
def search_location(lat, lon, threshold):
    """Generator handler: search the embedding index by geographic location.

    Always uses the SatCLIP model (the only one with a location encoder).
    Yields progressive UI updates as 7-tuples matching the Gradio outputs:
    (interactive_plot, gallery, status_text, results_fig, downloadable_figs,
    map_dataframe, static_map). Intermediate yields stream progress text;
    the final yield carries the full result set.

    threshold is the slider value in per-mille (1-30); it is scaled to a
    fraction before being passed to model.search().
    """
    model_name = "SatCLIP"
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Location
        yield None, None, "Encoding location...", None, None, None, None
        t0 = time.time()
        loc_features = model.encode_location(float(lat), float(lon))
        timings['Encoding'] = time.time() - t0

        if loc_features is None:
            yield None, None, "Location encoding failed.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding location... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=threshold/100.0)
        timings['Retrieval'] = time.time() - t0

        # 3. Generate Distribution Map (not timed for location distribution)
        yield None, None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map...", None, None, None, None
        df_embed = model.df_embed
        top_10_indices = top_indices[:10]
        # NOTE(review): top_10_results is built but never used afterwards —
        # looks like leftover scaffolding; confirm before removing.
        top_10_results = []
        for idx in top_10_indices:
            row = df_embed.iloc[idx]
            top_10_results.append({'lat': row['centre_lat'], 'lon': row['centre_lon']})

        # Show geographic distribution (not timed).
        # threshold/1000.0 converts the per-mille slider to a fraction.
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Location ({lat}, {lon})')

        # 4. Download Images (only the top 6 are fetched for display)
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_6_indices = top_indices[:6]
        results = fetch_top_k_images(top_6_indices, probs, df_embed, query_text=f"Loc: {lat},{lon}")

        # Get query tile: the dataset tile nearest (in squared-degree
        # distance) to the requested coordinates, used as the "query image"
        # panel in the overview figure.
        query_tile = None
        try:
            lats = pd.to_numeric(df_embed['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df_embed['centre_lon'], errors='coerce')
            dists = (lats - float(lat))**2 + (lons - float(lon))**2
            nearest_idx = dists.idxmin()
            pid = df_embed.loc[nearest_idx, 'product_id']
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            # Best-effort: fall through to a placeholder image below.
            print(f"Error fetching nearest MajorTOM image: {e}")
        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")
        timings['Download'] = time.time() - t0

        # 5. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 6. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
|
| 388 |
+
|
| 389 |
+
def generate_status_msg(count, threshold, results):
    """Build a human-readable status summary for a completed search.

    count: number of matches inside the threshold.
    threshold: fraction (e.g. 0.007 for 7‰); displayed in per-mille.
    results: list of result dicts with 'id', 'lat', 'lon', 'score' keys;
             only the first three entries are itemised.
    """
    lines = [f"Found {count} matches in top {threshold*100:.0f}‰.\n\nTop {len(results)} similar images:\n"]
    for rank, res in enumerate(results[:3], start=1):
        lines.append(f"{rank}. Product ID: {res['id']}, Location: ({res['lat']:.4f}, {res['lon']:.4f}), Score: {res['score']:.4f}\n")
    return "".join(lines)
|
| 394 |
+
|
| 395 |
+
def get_initial_plot():
    """Render the initial global distribution map.

    Prefers the FarSLIP embedding set, falling back to SigLIP; returns
    (static map update, figure list for download, backing DataFrame,
    hide-interactive-plot update).
    """
    img, df_vis = None, None
    for name in ('FarSLIP', 'SigLIP'):
        if name in models and models[name].df_embed is not None:
            img, df_vis = plot_global_map_static(models[name].df_embed)
            break
    return gr.update(value=img, visible=True), [img], df_vis, gr.update(visible=False)
|
| 405 |
+
|
| 406 |
+
def handle_map_click(evt: gr.SelectData, df_vis):
    """Translate a pixel click on the static world map into (lat, lon).

    The static map is rendered at a fixed 4000x2000 px with known margins,
    so the click position can be inverted back to geographic coordinates
    (equirectangular projection). If a dataset point in df_vis is close
    enough, the result snaps to it and its product_id is returned too.

    Returns (lat, lon, product_id, status_message); the first three are
    None when the click is invalid or outside the map area.
    """
    if evt is None:
        return None, None, None, "No point selected."

    try:
        # Pixel coordinates of the click within the rendered image.
        x, y = evt.index[0], evt.index[1]

        # Image dimensions (New) — must match plot_global_map_static's output.
        img_width = 4000
        img_height = 2000

        # Scaled Margins (Proportional to 4000x2000)
        left_margin = 110
        right_margin = 110
        top_margin = 100
        bottom_margin = 67

        plot_width = img_width - left_margin - right_margin
        plot_height = img_height - top_margin - bottom_margin

        # Adjust for aspect ratio preservation: the world map (2:1) is
        # letterboxed inside the plot area, centred with an offset on the
        # axis that has slack.
        map_aspect = 360.0 / 180.0  # 2.0
        plot_aspect = plot_width / plot_height

        if plot_aspect > map_aspect:
            actual_map_width = plot_height * map_aspect
            actual_map_height = plot_height
            h_offset = (plot_width - actual_map_width) / 2
            v_offset = 0
        else:
            actual_map_width = plot_width
            actual_map_height = plot_width / map_aspect
            h_offset = 0
            v_offset = (plot_height - actual_map_height) / 2

        # Calculate relative position within the plot area
        x_in_plot = x - left_margin
        y_in_plot = y - top_margin

        # Check if click is within the actual map bounds
        if (x_in_plot < h_offset or x_in_plot > h_offset + actual_map_width or
            y_in_plot < v_offset or y_in_plot > v_offset + actual_map_height):
            return None, None, None, "Click outside map area. Please click on the map."

        # Calculate relative position within the map (0 to 1)
        x_rel = (x_in_plot - h_offset) / actual_map_width
        y_rel = (y_in_plot - v_offset) / actual_map_height

        # Clamp to [0, 1]
        x_rel = max(0, min(1, x_rel))
        y_rel = max(0, min(1, y_rel))

        # Convert to geographic coordinates: lon in [-180, 180],
        # lat in [-90, 90] (pixel y grows downwards, hence the inversion).
        lon = x_rel * 360 - 180
        lat = 90 - y_rel * 180

        # Find nearest point in df_vis if available
        pid = ""
        if df_vis is not None:
            # Squared-degree distance to every dataset point.
            dists = (df_vis['centre_lat'] - lat)**2 + (df_vis['centre_lon'] - lon)**2
            min_idx = dists.idxmin()
            nearest_row = df_vis.loc[min_idx]

            # Snap only when within 25 squared degrees (~5 deg radius).
            if dists[min_idx] < 25:
                lat = nearest_row['centre_lat']
                lon = nearest_row['centre_lon']
                pid = nearest_row['product_id']

    except Exception as e:
        print(f"Error handling click: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"Error: {e}"

    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"
|
| 481 |
+
|
| 482 |
+
def download_image_by_location(lat, lon, pid, model_name):
    """Download and return the image at the specified location"""
    if lat is None or lon is None:
        return None, "Please specify coordinates first."

    model, error = get_active_model(model_name)
    if error:
        return None, error

    try:
        # Coordinates may arrive as strings from the UI; normalise first.
        lat, lon = float(lat), float(lon)

        # Resolve a product ID from the nearest dataset point when the
        # click handler did not supply one.
        if not pid:
            meta = model.df_embed
            lat_col = pd.to_numeric(meta['centre_lat'], errors='coerce')
            lon_col = pd.to_numeric(meta['centre_lon'], errors='coerce')
            sq_dist = (lat_col - lat)**2 + (lon_col - lon)**2
            pid = meta.loc[sq_dist.idxmin(), 'product_id']

        # Fetch the 384x384 centre crop for display.
        img_384, _ = download_and_process_image(pid, df_source=model.df_embed, verbose=True)
        if img_384 is None:
            return None, f"Failed to download image for location ({lat:.4f}, {lon:.4f})"
        return img_384, f"Downloaded image at ({lat:.4f}, {lon:.4f})"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
|
| 517 |
+
|
| 518 |
+
def reset_to_global_map():
    """Reset the map to the initial global distribution view"""
    img, df_vis = None, None
    # Same model preference order as get_initial_plot: FarSLIP, then SigLIP.
    for name in ('FarSLIP', 'SigLIP'):
        if name in models and models[name].df_embed is not None:
            img, df_vis = plot_global_map_static(models[name].df_embed)
            break
    return gr.update(value=img, visible=True), [img], df_vis
|
| 528 |
+
|
| 529 |
+
def format_results_to_text(results):
    """Format retrieval results as a plain-text report.

    Each result dict must provide 'id', 'lat', 'lon' and 'score'.
    Returns "No results found." for an empty/None input.
    """
    if not results:
        return "No results found."

    parts = [f"Top {len(results)} Retrieval Results\n", "=" * 30 + "\n\n"]
    for rank, res in enumerate(results, start=1):
        parts.append(f"Rank: {rank}\n")
        parts.append(f"Product ID: {res['id']}\n")
        parts.append(f"Location: Latitude {res['lat']:.6f}, Longitude {res['lon']:.6f}\n")
        parts.append(f"Similarity Score: {res['score']:.6f}\n")
        parts.append("-" * 30 + "\n")
    return "".join(parts)
|
| 542 |
+
|
| 543 |
+
def save_plot(figs):
    """Persist the current result figure(s) to a downloadable file.

    figs can be:
      - a single PIL image (initial map state) -> saved as a .png,
      - a list/tuple [map_img, results_img, results_txt] -> zipped,
      - a single-element list with one PIL image -> saved as a .png,
      - (legacy fallback) a Plotly figure -> saved as .html.
    Returns a filesystem path for Gradio's File component, or None on
    failure or when figs is None.
    """
    if figs is None:
        return None
    try:
        # If it's a single image (initial state), save as png
        if isinstance(figs, PILImage.Image):
            fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
            os.close(fd)  # mkstemp returns an open fd; PIL re-opens by path
            figs.save(path)
            return path

        # If it's a list/tuple of images [map_img, results_img]
        if isinstance(figs, (list, tuple)):
            # If only one image in list, save as PNG
            if len(figs) == 1 and isinstance(figs[0], PILImage.Image):
                fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
                os.close(fd)
                figs[0].save(path)
                return path

            # Multiple artefacts: bundle everything into one zip archive.
            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)

            with zipfile.ZipFile(zip_path, 'w') as zipf:
                # Save Map
                if figs[0] is not None:
                    map_path = os.path.join(tempfile.gettempdir(), 'map_distribution.png')
                    figs[0].save(map_path)
                    zipf.write(map_path, arcname='map_distribution.png')

                # Save Results
                if len(figs) > 1 and figs[1] is not None:
                    res_path = os.path.join(tempfile.gettempdir(), 'retrieval_results.png')
                    figs[1].save(res_path)
                    zipf.write(res_path, arcname='retrieval_results.png')

                # Save Results Text
                if len(figs) > 2 and figs[2] is not None:
                    txt_path = os.path.join(tempfile.gettempdir(), 'results.txt')
                    with open(txt_path, 'w', encoding='utf-8') as f:
                        f.write(figs[2])
                    zipf.write(txt_path, arcname='results.txt')

            return zip_path

        # Fallback for Plotly figure (if any)
        # Create a temporary file
        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)

        # Write to the temporary file
        figs.write_html(path)
        return path
    except Exception as e:
        print(f"Error saving: {e}")
        return None
|
| 599 |
+
|
| 600 |
+
# Gradio Blocks Interface
# Top-level UI wiring: three search tabs (Text / Image / Location) on the
# left; shared map, results figure and gallery on the right. All three
# search handlers write to the same 7 output components.
# Fix: typo in the user-facing intro text ("a simple a click" -> "a simple click").
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    gr.HTML("""
    <div style="font-size: 1.2em;">
    EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
    </div>

    <div style="display: flex; gap: 0.2em; align-items: center; justify-content: center;">
    <a href="https://www.modelscope.cn/studios/VoyagerX/EarthExplorer"><img src="https://img.shields.io/badge/Open in ModelScope.cn-xGPU-624aff"></a>
    <a href="https://www.modelscope.ai/studios/VoyagerX/EarthExplorer"><img src="https://img.shields.io/badge/Open in ModelScope.ai-CPU-624aff"></a>
    <a href="https://huggingface.co/spaces/ML4Sustain/EarthExplorer"><img src="https://img.shields.io/badge/Open in HF Space-CPU-FFD21E"></a>
    <a href="https://modelscope.cn/studios/VoyagerX/EarthExplorer/file/view/master/Tutorial.md?status=1"> <img src="https://img.shields.io/badge/Tutorial-📖-007bff"> </a>
    <a href="https://www.modelscope.cn/learn/3958"> <img src="https://img.shields.io/badge/中文教程-📖-007bff"> </a>
    </div>

    """)

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Text Search") as tab_text:
                    model_selector_text = gr.Dropdown(choices=["SigLIP", "FarSLIP"], value="FarSLIP", label="Model")
                    query_input = gr.Textbox(label="Query", placeholder="e.g., rainforest, glacier")

                    gr.Examples(
                        examples=[
                            ["a satellite image of a river around a city"],
                            ["a satellite image of a rainforest"],
                            ["a satellite image of a slum"],
                            ["a satellite image of a glacier"],
                            ["a satellite image of snow covered mountains"]
                        ],
                        inputs=[query_input],
                        label="Text Examples"
                    )

                    search_btn = gr.Button("Search by Text", variant="primary")

                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(choices=["SigLIP", "FarSLIP", "SatCLIP"], value="FarSLIP", label="Model")

                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")

                    gr.Examples(
                        examples=[
                            ["./examples/example1.png"],
                            ["./examples/example2.png"],
                            ["./examples/example3.png"]
                        ],
                        inputs=[image_input],
                        label="Image Examples"
                    )

                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)

                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")

                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")

                    search_img_btn = gr.Button("Search by Image", variant="primary")

                with gr.TabItem("Location Search") as tab_location:
                    gr.Markdown("Search using **SatCLIP** location encoder.")

                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)

                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")

                    gr.Examples(
                        examples=[
                            [30.32, 120.15],
                            [40.7128, -74.0060],
                            [24.65, 46.71],
                            [-3.4653, -62.2159],
                            [64.4, 16.8]
                        ],
                        inputs=[lat_input, lon_input],
                        label="Location Examples"
                    )

                    search_loc_btn = gr.Button("Search by Location", variant="primary")

            # Shared controls below the tabs.
            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (‰)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)

        with gr.Column(scale=6):
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True
            )
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")

    # States: last rendered figure(s) for the download button, and the
    # DataFrame backing the displayed map (used for click snapping).
    current_fig = gr.State()
    map_data_state = gr.State()

    # Initial Load
    demo.load(fn=get_initial_plot, outputs=[plot_map, current_fig, map_data_state, plot_map_interactive])

    # Reset Map Buttons
    btn_reset_map_img.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    btn_reset_map_loc.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    # Map Click Event - updates Image Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[img_lat, img_lon, img_pid, img_click_status]
    )

    # Map Click Event - also updates Location Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[lat_input, lon_input, loc_pid, loc_click_status]
    )

    # Download Image by Geolocation
    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status]
    )

    # Search Event (Text)
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Search Event (Image)
    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Search Event (Location)
    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Save Event
    save_btn.click(
        fn=save_plot,
        inputs=[current_fig],
        outputs=[download_file]
    )

    # Tab Selection Events: switching tabs always reverts to the static map.
    def show_static_map():
        return gr.update(visible=True), gr.update(visible=False)

    tab_text.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_image.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_location.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
configs/huggingface.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
siglip:
|
| 2 |
+
ckpt_path: "hf"
|
| 3 |
+
model_name: "ViT-SO400M-14-SigLIP-384"
|
| 4 |
+
tokenizer_path: "hf"
|
| 5 |
+
embedding_path: "hf://ML4Sustain/EarthEmbeddings/uniform_sample_250k/siglip/SigLIP_grid_sample_center_384x384_243k.parquet"
|
| 6 |
+
farslip:
|
| 7 |
+
ckpt_path: "hf"
|
| 8 |
+
model_name: "ViT-B-16"
|
| 9 |
+
embedding_path: "hf://ML4Sustain/EarthEmbeddings/uniform_sample_250k/farslip/FarSLIP_grid_sample_center_384x384_243k.parquet"
|
| 10 |
+
satclip:
|
| 11 |
+
ckpt_path: "hf"
|
| 12 |
+
embedding_path: "hf://ML4Sustain/EarthEmbeddings/uniform_sample_250k/satclip/SatCLIP_grid_sample_center_384x384_243k.parquet"
|
countries.geo.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data_utils.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fsspec
|
| 2 |
+
import pyarrow.parquet as pq
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from rasterio.io import MemoryFile
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import cartopy.crs as ccrs
|
| 9 |
+
import cartopy.io.img_tiles as cimgt
|
| 10 |
+
from matplotlib.patches import Rectangle
|
| 11 |
+
import math
|
| 12 |
+
from matplotlib.figure import Figure
|
| 13 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def crop_center(img_array, cropx, cropy):
    """Return the central (cropy, cropx) window of an (H, W, C) array."""
    height, width, _ = img_array.shape
    x0 = width // 2 - cropx // 2
    y0 = height // 2 - cropy // 2
    return img_array[y0:y0 + cropy, x0:x0 + cropx]
|
| 21 |
+
|
| 22 |
+
def read_tif_bytes(tif_bytes):
    """Decode an in-memory GeoTIFF byte string into a numpy array.

    squeeze() drops the leading band axis for single-band rasters, so a
    one-band tile comes back as a 2-D array.
    """
    with MemoryFile(tif_bytes) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            return f.read().squeeze()
|
| 26 |
+
|
| 27 |
+
def read_row_memory(row_dict, columns=["thumbnail"]):
    """Stream one row group of a remote parquet file and decode its columns.

    row_dict must contain 'parquet_url' and 'parquet_row' (the row-group
    index within the file). GeoTIFF columns are decoded to numpy arrays
    via read_tif_bytes(); the 'thumbnail' column is opened as a PIL image.
    Only the requested row group's bytes are fetched, not the whole file.

    NOTE(review): *columns* uses a mutable default list; it is never
    mutated here, so this is safe in practice, but a tuple would be cleaner.
    """
    url = row_dict['parquet_url']
    row_idx = row_dict['parquet_row']

    # Readahead caching with 5 MiB blocks keeps the number of HTTP range
    # requests low when pyarrow seeks around the file.
    fs_options = {
        "cache_type": "readahead",
        "block_size": 5 * 1024 * 1024
    }

    with fsspec.open(url, mode='rb', **fs_options) as f:
        with pq.ParquetFile(f) as pf:
            table = pf.read_row_group(row_idx, columns=columns)

    row_output = {}
    for col in columns:
        # Each row group holds a single sample, hence index [0].
        col_data = table[col][0].as_py()

        if col != 'thumbnail':
            row_output[col] = read_tif_bytes(col_data)
        else:
            stream = BytesIO(col_data)
            row_output[col] = Image.open(stream)

    return row_output
|
| 51 |
+
|
| 52 |
+
def download_and_process_image(product_id, df_source=None, verbose=True):
    """Fetch a MajorTOM tile by product_id and build RGB preview images.

    Looks the product up in *df_source*, downloads the B04/B03/B02 bands
    from the row's remote parquet, stacks and normalises them to 8-bit RGB,
    and returns (img_384, img_full): a 384x384 centre crop (or a resize
    when the tile is smaller) plus the full-size image. Returns
    (None, None) on any failure.
    """
    if df_source is None:
        if verbose: print("❌ Error: No DataFrame provided.")
        return None, None

    row_subset = df_source[df_source['product_id'] == product_id]
    if len(row_subset) == 0:
        if verbose: print(f"❌ Error: Product ID {product_id} not found in DataFrame.")
        return None, None

    row_dict = row_subset.iloc[0].to_dict()

    # Rewrite Hugging Face / mirror URLs to the ModelScope host; this
    # deployment targets networks where huggingface.co may be unreachable.
    if 'parquet_url' in row_dict:
        url = row_dict['parquet_url']
        if 'huggingface.co' in url:
            row_dict['parquet_url'] = url.replace('https://huggingface.co', 'https://modelscope.cn').replace('resolve/main', 'resolve/master')
        elif 'hf-mirror.com' in url:
            row_dict['parquet_url'] = url.replace('https://hf-mirror.com', 'https://modelscope.cn').replace('resolve/main', 'resolve/master')
    else:
        if verbose: print("❌ Error: 'parquet_url' missing in metadata.")
        return None, None

    if verbose: print(f"⬇️ Fetching data for {product_id} from {row_dict['parquet_url']}...")

    try:
        bands_data = read_row_memory(row_dict, columns=['B04', 'B03', 'B02'])

        if not all(b in bands_data for b in ['B04', 'B03', 'B02']):
            if verbose: print(f"❌ Error: Missing bands in fetched data for {product_id}")
            return None, None

        # Sentinel-2 band order: B04=red, B03=green, B02=blue.
        rgb_img = np.stack([bands_data['B04'], bands_data['B03'], bands_data['B02']], axis=-1)

        if verbose:
            print(f"Raw RGB stats: Min={rgb_img.min()}, Max={rgb_img.max()}, Mean={rgb_img.mean()}, Dtype={rgb_img.dtype}")

        # Check if data is already 0-255 or 0-1
        if rgb_img.max() <= 255:
            # Assume it might be uint8 or scaled
            # NOTE(review): this branch is a no-op — the normalisation below
            # always assumes 0-10000 reflectance. Confirm whether already-
            # scaled inputs need different handling.
            pass

        # Standard visualisation: reflectance/10000, x2.5 gain, clipped to [0,1].
        rgb_norm = (2.5 * (rgb_img.astype(float) / 10000.0)).clip(0, 1)
        rgb_uint8 = (rgb_norm * 255).astype(np.uint8)

        if verbose:
            print(f"Processed RGB stats: Min={rgb_uint8.min()}, Max={rgb_uint8.max()}, Mean={rgb_uint8.mean()}")

        img_full = Image.fromarray(rgb_uint8)

        # 384x384 is the model input size; crop when possible, resize otherwise.
        if rgb_uint8.shape[0] >= 384 and rgb_uint8.shape[1] >= 384:
            cropped_array = crop_center(rgb_uint8, 384, 384)
            img_384 = Image.fromarray(cropped_array)
        else:
            if verbose: print(f"⚠️ Image too small {rgb_uint8.shape}, resizing to 384x384.")
            img_384 = img_full.resize((384, 384))

        if verbose: print(f"✅ Successfully processed {product_id}")
        return img_384, img_full

    except Exception as e:
        if verbose: print(f"❌ Error processing {product_id}: {e}")
        import traceback
        traceback.print_exc()
        return None, None
|
| 116 |
+
|
| 117 |
+
# Define Esri Imagery Class
|
| 118 |
+
class EsriImagery(cimgt.GoogleTiles):
    """Cartopy tile source for the Esri World Imagery basemap."""

    def _image_url(self, tile):
        # tile is an (x, y, zoom) triple supplied by cartopy's tiler.
        x, y, z = tile
        return f'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}'
|
| 122 |
+
|
| 123 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 124 |
+
|
| 125 |
+
def get_placeholder_image(text="Image Unavailable", size=(384, 384)):
    """Return a grey RGB placeholder image with *text* drawn on it.

    Used when a tile download or basemap fetch fails, so the UI always
    has something to display. Never raises: falls back to font=None
    (PIL's built-in default) if the default font cannot be loaded.
    """
    img = Image.new('RGB', size, color=(200, 200, 200))
    d = ImageDraw.Draw(img)
    try:
        # Try to load a default font
        font = ImageFont.load_default()
    except Exception:
        # Fix: was a bare `except:`, which would also swallow
        # SystemExit/KeyboardInterrupt.
        font = None

    # Draw text in center (rough approximation)
    # For better centering we would need font metrics, but simple is fine here
    d.text((20, size[1] // 2), text, fill=(0, 0, 0), font=font)
    return img
|
| 138 |
+
|
| 139 |
+
def get_esri_satellite_image(lat, lon, score=None, rank=None, query=None):
    """
    Generates a satellite image visualization using Esri World Imagery via Cartopy.
    Matches the style of the provided notebook.
    Uses OO Matplotlib API for thread safety.

    Returns a PIL image. On failure (typically a network error fetching
    tiles) it returns a grey placeholder image instead of raising.
    """
    try:
        imagery = EsriImagery()

        # Create figure using OO API — avoids pyplot's global figure state,
        # which is not safe to share across request threads.
        fig = Figure(figsize=(5, 5), dpi=100)
        canvas = FigureCanvasAgg(fig)
        ax = fig.add_subplot(1, 1, 1, projection=imagery.crs)

        # Set extent to approx 10km x 10km around the point
        extent_deg = 0.05
        ax.set_extent([lon - extent_deg, lon + extent_deg, lat - extent_deg, lat + extent_deg], crs=ccrs.PlateCarree())

        # Add the imagery
        ax.add_image(imagery, 14)

        # Add a marker for the center
        ax.plot(lon, lat, marker='+', color='yellow', markersize=12, markeredgewidth=2, transform=ccrs.PlateCarree())

        # Add Bounding Box (3840m x 3840m) — footprint of a 384-px tile at 10 m/px
        box_size_m = 384 * 10  # 3840m

        # Convert meters to degrees (approx)
        # 1 deg lat = 111320m
        # 1 deg lon = 111320m * cos(lat)
        dlat = (box_size_m / 111320)
        dlon = (box_size_m / (111320 * math.cos(math.radians(lat))))

        # Bottom-Left corner
        rect_lon = lon - dlon / 2
        rect_lat = lat - dlat / 2

        # Add Rectangle
        rect = Rectangle((rect_lon, rect_lat), dlon, dlat,
                         linewidth=2, edgecolor='red', facecolor='none', transform=ccrs.PlateCarree())
        ax.add_patch(rect)

        # Title: only include the parts the caller supplied.
        title_parts = []
        if query: title_parts.append(f"{query}")
        if rank is not None: title_parts.append(f"Rank {rank}")
        if score is not None: title_parts.append(f"Score: {score:.4f}")

        ax.set_title("\n".join(title_parts), fontsize=10)

        # Save to buffer
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)

        return Image.open(buf)

    except Exception as e:
        # Suppress full traceback for network errors to avoid log spam
        error_msg = str(e)
        if "Connection reset by peer" in error_msg or "Network is unreachable" in error_msg or "urlopen error" in error_msg:
            print(f"⚠️ Network warning: Could not fetch Esri satellite map for ({lat:.4f}, {lon:.4f}). Server might be offline.")
        else:
            print(f"Error generating Esri image for {lat}, {lon}: {e}")
            # Only print traceback for non-network errors
            # import traceback
            # traceback.print_exc()

        # Return a placeholder image with text
        return get_placeholder_image(f"Map Unavailable\n({lat:.2f}, {lon:.2f})")
|
| 209 |
+
|
| 210 |
+
def get_esri_satellite_image_url(lat, lon, zoom=14):
|
| 211 |
+
"""
|
| 212 |
+
Returns the URL for the Esri World Imagery tile at the given location.
|
| 213 |
+
"""
|
| 214 |
+
try:
|
| 215 |
+
imagery = EsriImagery()
|
| 216 |
+
# Calculate tile coordinates
|
| 217 |
+
# This is a simplification, cimgt handles this internally usually
|
| 218 |
+
# But for direct URL we might need more logic or just use the static map approach above
|
| 219 |
+
# For now, let's stick to the static map generation which works
|
| 220 |
+
pass
|
| 221 |
+
except:
|
| 222 |
+
pass
|
| 223 |
+
return None
|
embedding_datasets/grid_sample_center_22k_FarSLIP_384x384.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3555e0279742daa7ee27ba5587a8234f791966ce4411ef804455ee03af52e1aa
|
| 3 |
+
size 23547770
|
embedding_datasets/grid_sample_center_22k_SatCLIP_384x384.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:76484097dea1f0fc65e4f2c8d3e825ec3ccda8914da83e3a65aabd86a4f59ec2
|
| 3 |
+
size 25158503
|
embedding_datasets/grid_sample_center_22k_SigLIP_384x384.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a34d949f704f8f4d9d963f28dbe547c341591645cf86d191587a3cc0a866855f
|
| 3 |
+
size 50178408
|
embedding_datasets/grid_sample_metadata.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:808fde21fdae5ef2dc8183c7e8017b286dc2d2419ed64e6058358291cbeef06c
|
| 3 |
+
size 1999889
|
embedding_datasets/zhejiang_sample_center_2k_FarSLIP_384x384.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a4bc51828dd58d45c62d3168557870e8db6659c0c52e9661865326cafb11c88b
|
| 3 |
+
size 2088911
|
embedding_datasets/zhejiang_sample_center_2k_SatCLIP_384x384.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:81bdf991b1d6a100108d0cad730bf79b1e9f558f261f2d4fb18c3f68c9ff2796
|
| 3 |
+
size 2719357
|
embedding_datasets/zhejiang_sample_center_2k_SigLIP_384x384.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:47ab37984d86b9b15949448f36d52022c24c40e7ff3bc65f44510ff08d0cbe81
|
| 3 |
+
size 4381379
|
embedding_datasets/zhejiang_sample_metadata.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb7c46e7985c05cb010e4fd308489865271b43b250844b8d42a3fe8263d01a78
|
| 3 |
+
size 159438
|
examples/example1.png
ADDED
|
Git LFS Details
|
examples/example2.png
ADDED
|
Git LFS Details
|
examples/example3.png
ADDED
|
Git LFS Details
|
images/CLIP.png
ADDED
|
Git LFS Details
|
images/Image_Search_Amazon.jpg
ADDED
|
Git LFS Details
|
images/Image_Search_Middle_East.jpg
ADDED
|
Git LFS Details
|
images/Location_Search_Amazon.jpg
ADDED
|
Git LFS Details
|
images/Location_Search_Hangzhou.jpg
ADDED
|
Git LFS Details
|
images/Text_Search.jpg
ADDED
|
Git LFS Details
|
images/embedding.png
ADDED
|
Git LFS Details
|
images/framework_en.png
ADDED
|
Git LFS Details
|
images/framework_zh.png
ADDED
|
Git LFS Details
|
images/samples.png
ADDED
|
Git LFS Details
|
models/FarSLIP/.gitignore
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**/logs/
|
| 2 |
+
**/wandb/
|
| 3 |
+
models/
|
| 4 |
+
features/
|
| 5 |
+
results/
|
| 6 |
+
src/open_clip_train/config.py
|
| 7 |
+
src/open_clip_train/output_samples/
|
| 8 |
+
**/results_retrieval/
|
| 9 |
+
**/results_classification/
|
| 10 |
+
checkpoints/
|
| 11 |
+
|
| 12 |
+
tests/data/
|
| 13 |
+
*.pt
|
| 14 |
+
|
| 15 |
+
# Byte-compiled / optimized / DLL files
|
| 16 |
+
__pycache__/
|
| 17 |
+
*.py[cod]
|
| 18 |
+
*$py.class
|
| 19 |
+
|
| 20 |
+
# C extensions
|
| 21 |
+
*.so
|
| 22 |
+
|
| 23 |
+
# Distribution / packaging
|
| 24 |
+
.Python
|
| 25 |
+
build/
|
| 26 |
+
develop-eggs/
|
| 27 |
+
dist/
|
| 28 |
+
downloads/
|
| 29 |
+
eggs/
|
| 30 |
+
.eggs/
|
| 31 |
+
lib/
|
| 32 |
+
lib64/
|
| 33 |
+
parts/
|
| 34 |
+
sdist/
|
| 35 |
+
var/
|
| 36 |
+
wheels/
|
| 37 |
+
pip-wheel-metadata/
|
| 38 |
+
share/python-wheels/
|
| 39 |
+
*.egg-info/
|
| 40 |
+
.installed.cfg
|
| 41 |
+
*.egg
|
| 42 |
+
MANIFEST
|
| 43 |
+
|
| 44 |
+
# PyInstaller
|
| 45 |
+
# Usually these files are written by a python script from a template
|
| 46 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 47 |
+
*.manifest
|
| 48 |
+
*.spec
|
| 49 |
+
|
| 50 |
+
# Installer logs
|
| 51 |
+
pip-log.txt
|
| 52 |
+
pip-delete-this-directory.txt
|
| 53 |
+
|
| 54 |
+
# Unit test / coverage reports
|
| 55 |
+
htmlcov/
|
| 56 |
+
.tox/
|
| 57 |
+
.nox/
|
| 58 |
+
.coverage
|
| 59 |
+
.coverage.*
|
| 60 |
+
.cache
|
| 61 |
+
nosetests.xml
|
| 62 |
+
coverage.xml
|
| 63 |
+
*.cover
|
| 64 |
+
*.py,cover
|
| 65 |
+
.hypothesis/
|
| 66 |
+
.pytest_cache/
|
| 67 |
+
|
| 68 |
+
# Translations
|
| 69 |
+
*.mo
|
| 70 |
+
*.pot
|
| 71 |
+
|
| 72 |
+
# Django stuff:
|
| 73 |
+
*.log
|
| 74 |
+
local_settings.py
|
| 75 |
+
db.sqlite3
|
| 76 |
+
db.sqlite3-journal
|
| 77 |
+
|
| 78 |
+
# Flask stuff:
|
| 79 |
+
instance/
|
| 80 |
+
.webassets-cache
|
| 81 |
+
|
| 82 |
+
# Scrapy stuff:
|
| 83 |
+
.scrapy
|
| 84 |
+
|
| 85 |
+
# Sphinx documentation
|
| 86 |
+
docs/_build/
|
| 87 |
+
|
| 88 |
+
# PyBuilder
|
| 89 |
+
target/
|
| 90 |
+
|
| 91 |
+
# Jupyter Notebook
|
| 92 |
+
.ipynb_checkpoints
|
| 93 |
+
|
| 94 |
+
# IPython
|
| 95 |
+
profile_default/
|
| 96 |
+
ipython_config.py
|
| 97 |
+
|
| 98 |
+
# pyenv
|
| 99 |
+
.python-version
|
| 100 |
+
|
| 101 |
+
# pipenv
|
| 102 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 103 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 104 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 105 |
+
# install all needed dependencies.
|
| 106 |
+
#Pipfile.lock
|
| 107 |
+
|
| 108 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 109 |
+
__pypackages__/
|
| 110 |
+
|
| 111 |
+
# Celery stuff
|
| 112 |
+
celerybeat-schedule
|
| 113 |
+
celerybeat.pid
|
| 114 |
+
|
| 115 |
+
# SageMath parsed files
|
| 116 |
+
*.sage.py
|
| 117 |
+
|
| 118 |
+
# Environments
|
| 119 |
+
.env
|
| 120 |
+
.venv
|
| 121 |
+
env/
|
| 122 |
+
venv/
|
| 123 |
+
ENV/
|
| 124 |
+
env.bak/
|
| 125 |
+
venv.bak/
|
| 126 |
+
|
| 127 |
+
# Spyder project settings
|
| 128 |
+
.spyderproject
|
| 129 |
+
.spyproject
|
| 130 |
+
|
| 131 |
+
# Rope project settings
|
| 132 |
+
.ropeproject
|
| 133 |
+
|
| 134 |
+
# mkdocs documentation
|
| 135 |
+
/site
|
| 136 |
+
|
| 137 |
+
# mypy
|
| 138 |
+
.mypy_cache/
|
| 139 |
+
.dmypy.json
|
| 140 |
+
dmypy.json
|
| 141 |
+
|
| 142 |
+
# Pyre type checker
|
| 143 |
+
.pyre/
|
| 144 |
+
sync.sh
|
| 145 |
+
gpu1sync.sh
|
| 146 |
+
.idea
|
| 147 |
+
*.pdf
|
| 148 |
+
**/._*
|
| 149 |
+
**/*DS_*
|
| 150 |
+
**.jsonl
|
| 151 |
+
src/sbatch
|
| 152 |
+
src/misc
|
| 153 |
+
.vscode
|
| 154 |
+
src/debug
|
| 155 |
+
core.*
|
| 156 |
+
|
| 157 |
+
*.out
|
| 158 |
+
|
| 159 |
+
# Allow
|
| 160 |
+
!src/evaluation/misc/results_dbs/*
|