VoyagerXvoyagerx committed on
Commit
29fab93
·
0 Parent(s):

sync from hf

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +54 -0
  2. .gitignore +10 -0
  3. MajorTOM/MajorTOMDataset.py +64 -0
  4. MajorTOM/__init__.py +5 -0
  5. MajorTOM/embedder/MajorTOM_Embedder.py +191 -0
  6. MajorTOM/embedder/__init__.py +2 -0
  7. MajorTOM/embedder/grid_cell_fragment.py +164 -0
  8. MajorTOM/embedder/models/DINOv2_S2RGB.py +91 -0
  9. MajorTOM/embedder/models/SSL4EO_S1RTC.py +125 -0
  10. MajorTOM/embedder/models/SSL4EO_S2L1C.py +97 -0
  11. MajorTOM/embedder/models/SigLIP_S2RGB.py +65 -0
  12. MajorTOM/embedder/models/__init__.py +4 -0
  13. MajorTOM/extras/coverage-example.png +3 -0
  14. MajorTOM/extras/coverage_vis.py +149 -0
  15. MajorTOM/extras/extract-sample-from-raw-S2.ipynb +0 -0
  16. MajorTOM/extras/thumbnail_dem.py +77 -0
  17. MajorTOM/extras/thumbnail_s1rtc.py +80 -0
  18. MajorTOM/extras/thumbnail_s2.py +68 -0
  19. MajorTOM/grid.py +284 -0
  20. MajorTOM/metadata_helpers.py +159 -0
  21. MajorTOM/sample_helpers.py +20 -0
  22. README.md +28 -0
  23. Tutorial.md +162 -0
  24. Tutorial_zh.md +157 -0
  25. app.py +792 -0
  26. configs/huggingface.yaml +12 -0
  27. countries.geo.json +0 -0
  28. data_utils.py +223 -0
  29. embedding_datasets/grid_sample_center_22k_FarSLIP_384x384.parquet +3 -0
  30. embedding_datasets/grid_sample_center_22k_SatCLIP_384x384.parquet +3 -0
  31. embedding_datasets/grid_sample_center_22k_SigLIP_384x384.parquet +3 -0
  32. embedding_datasets/grid_sample_metadata.parquet +3 -0
  33. embedding_datasets/zhejiang_sample_center_2k_FarSLIP_384x384.parquet +3 -0
  34. embedding_datasets/zhejiang_sample_center_2k_SatCLIP_384x384.parquet +3 -0
  35. embedding_datasets/zhejiang_sample_center_2k_SigLIP_384x384.parquet +3 -0
  36. embedding_datasets/zhejiang_sample_metadata.parquet +3 -0
  37. examples/example1.png +3 -0
  38. examples/example2.png +3 -0
  39. examples/example3.png +3 -0
  40. images/CLIP.png +3 -0
  41. images/Image_Search_Amazon.jpg +3 -0
  42. images/Image_Search_Middle_East.jpg +3 -0
  43. images/Location_Search_Amazon.jpg +3 -0
  44. images/Location_Search_Hangzhou.jpg +3 -0
  45. images/Text_Search.jpg +3 -0
  46. images/embedding.png +3 -0
  47. images/framework_en.png +3 -0
  48. images/framework_zh.png +3 -0
  49. images/samples.png +3 -0
  50. models/FarSLIP/.gitignore +160 -0
.gitattributes ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.db* filter=lfs diff=lfs merge=lfs -text
29
+ *.ark* filter=lfs diff=lfs merge=lfs -text
30
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
31
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
32
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
33
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
34
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
35
+ *.gguf* filter=lfs diff=lfs merge=lfs -text
36
+ *.ggml filter=lfs diff=lfs merge=lfs -text
37
+ *.llamafile* filter=lfs diff=lfs merge=lfs -text
38
+ *.pt2 filter=lfs diff=lfs merge=lfs -text
39
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
40
+ *.npy filter=lfs diff=lfs merge=lfs -text
41
+ *.npz filter=lfs diff=lfs merge=lfs -text
42
+ *.pickle filter=lfs diff=lfs merge=lfs -text
43
+ *.pkl filter=lfs diff=lfs merge=lfs -text
44
+ *.tar filter=lfs diff=lfs merge=lfs -text
45
+ *.wasm filter=lfs diff=lfs merge=lfs -text
46
+ *.zst filter=lfs diff=lfs merge=lfs -text
47
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
48
+ ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
49
+ center_bbx_22k.parquet filter=lfs diff=lfs merge=lfs -text
50
+ embedding_datasets/center_bbx_22k_SigLIP_384x384.parquet filter=lfs diff=lfs merge=lfs -text
51
+ embedding_datasets/center_bbx_22k_FarSLIP_384x384.parquet filter=lfs diff=lfs merge=lfs -text
52
+ *.png filter=lfs diff=lfs merge=lfs -text
53
+ *.jpg filter=lfs diff=lfs merge=lfs -text
54
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ .gradio/
5
+ .vscode/
6
+ .DS_Store
7
+ checkpoints/
8
+ models/FarSLIP/assets
9
+ models/SatCLIP/figures
10
+ configs/local.yaml
MajorTOM/MajorTOMDataset.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ from pathlib import Path
6
+ import rasterio as rio
7
+ from PIL import Image
8
+ import torchvision.transforms as transforms
9
+
10
class MajorTOM(Dataset):
    """MajorTOM Dataset (https://huggingface.co/Major-TOM)

    Reads per-band .tif and .png files for each sample listed in the metadata
    dataframe, laid out on disk as
    ``<local_dir>/<row>/<grid_cell>/<product_id>/<band>.<ext>``.

    Args:
        df ((geo)pandas.DataFrame): Metadata dataframe.
        local_dir (str or pathlib.Path): Root directory of the local dataset version.
        tif_bands (str or sequence): tif file names (without extension) to read.
        png_bands (str or sequence): png file names (without extension) to read.
        tif_transforms (list or None): Transforms composed and applied to each tif
            band; None disables them. Defaults to ``[transforms.ToTensor()]``.
        png_transforms (list or None): Transforms composed and applied to each png
            band; None disables them. Defaults to ``[transforms.ToTensor()]``.
    """

    # Sentinel default so callers can still pass None to explicitly disable
    # transforms, while avoiding mutable default arguments (the original
    # defaults were shared lists with a single ToTensor() built at import time).
    _DEFAULT = object()

    def __init__(self,
                 df,
                 local_dir=None,
                 tif_bands=('B04', 'B03', 'B02'),
                 png_bands=('thumbnail',),
                 tif_transforms=_DEFAULT,
                 png_transforms=_DEFAULT,
                 ):
        super().__init__()
        self.df = df
        self.local_dir = Path(local_dir) if isinstance(local_dir, str) else local_dir

        # Accept either a single band name or a sequence of names.
        self.tif_bands = [tif_bands] if isinstance(tif_bands, str) else list(tif_bands)
        self.png_bands = [png_bands] if isinstance(png_bands, str) else list(png_bands)

        # Build the default transform lists lazily, per instance.
        if tif_transforms is MajorTOM._DEFAULT:
            tif_transforms = [transforms.ToTensor()]
        if png_transforms is MajorTOM._DEFAULT:
            png_transforms = [transforms.ToTensor()]
        self.tif_transforms = transforms.Compose(tif_transforms) if tif_transforms is not None else None
        self.png_transforms = transforms.Compose(png_transforms) if png_transforms is not None else None

    def __len__(self):
        """Number of samples (rows in the metadata dataframe)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load all requested bands for sample ``idx``.

        Returns:
            dict: ``{'meta': <metadata row>, <band name>: <tensor/image>, ...}``
        """
        meta = self.df.iloc[idx]

        # Directory layout: <row>/<grid_cell>/<product_id>
        row = meta.grid_cell.split('_')[0]
        path = self.local_dir / row / meta.grid_cell / meta.product_id

        out_dict = {'meta': meta}

        for band in self.tif_bands:
            with rio.open(path / '{}.tif'.format(band)) as src:
                out = src.read()
            if self.tif_transforms is not None:
                out = self.tif_transforms(out)
            out_dict[band] = out

        for band in self.png_bands:
            # Use a context manager so the underlying file handle is closed
            # (PIL keeps it open lazily otherwise); load() pulls the pixel
            # data in before the file is closed.
            with Image.open(path / '{}.png'.format(band)) as img:
                img.load()
                out = img if self.png_transforms is None else self.png_transforms(img)
            out_dict[band] = out

        return out_dict
MajorTOM/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .sample_helpers import *
2
+ from .metadata_helpers import *
3
+ from .MajorTOMDataset import *
4
+ from .grid import *
5
+ from .embedder import *
MajorTOM/embedder/MajorTOM_Embedder.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import geopandas as gpd
3
+ import hashlib
4
+ from rasterio.io import MemoryFile
5
+
6
+ from .grid_cell_fragment import *
7
+ from .models import *
8
+ import cv2
9
+
10
class MajorTOM_Embedder(torch.nn.Module):
    """Applies an embedding model to fragments of a Major-TOM grid cell.

    The input raster is fragmented into overlapping tiles sized for the wrapped
    embedder, each tile is embedded, and the result is returned as a
    GeoDataFrame with one row per fragment (embedding + spatial metadata).

    Attributes:
        embedder: Model producing an embedding for a batch of fragments; must
            expose ``bands`` (input band names) and ``size`` (input H, W).
        frag_params (dict): Fragmentation settings (fragment_size,
            target_overlap, border_shift) forwarded to ``fragment_fn``.
        column_types (dict): dtype conversions applied to the output
            GeoDataFrame columns (unlisted columns need no conversion).
    """

    def __init__(self, embedder, target_overlap=0.1, border_shift=True):
        """
        Args:
            embedder (torch.nn.Module): Model that generates embeddings for
                image fragments.
            target_overlap (float): Target fractional overlap between
                neighbouring fragments. Default is 0.1.
            border_shift (bool): Shift the outermost fragments to the image
                border so the full extent is covered. Default is True.
        """
        super().__init__()

        self.embedder = embedder

        # Fragmentation settings (square fragments sized to the embedder input).
        self.frag_params = {
            'fragment_size': self.embedder.size[0],
            'target_overlap': target_overlap,
            'border_shift': border_shift,
        }

        # dtype conversions for the output dataframe; columns not listed
        # (unique_id, embedding, timestamp, product_id, grid_cell,
        # utm_footprint, utm_crs, pixel_bbox) need no conversion.
        self.column_types = {
            'grid_row_u': 'int16',
            'grid_col_r': 'int16',
            'centre_lat': 'float32',
            'centre_lon': 'float32',
        }

    def bands(self):
        """Return the list of input bands, in the order the embedder expects."""
        return self.embedder.bands

    def size(self):
        """Return the embedder's expected input size as (height, width)."""
        return self.embedder.size

    def calculate_checksum(self, geometry, timestamp, product_id, embedding):
        """Return a SHA-256 hex digest identifying one fragment.

        The digest is computed over the string concatenation
        ``"{geometry}_{timestamp}_{product_id}_{embedding}"``.

        Args:
            geometry (shapely.geometry): Fragment footprint.
            timestamp (str): Acquisition timestamp.
            product_id (str): Product identifier.
            embedding (np.ndarray): Fragment embedding.

        Returns:
            str: SHA-256 checksum of the combined inputs.
        """
        combined = f"{geometry}_{timestamp}_{product_id}_{embedding}"
        return hashlib.sha256(combined.encode()).hexdigest()

    def _read_image(self, row):
        """Read all embedder bands from ``row`` into a single tensor.

        Bands are stored as in-memory GeoTIFF blobs (pyarrow binary cells);
        mismatched resolutions are upsampled (nearest neighbour) to the
        largest band resolution.

        Args:
            row (pandas.Series): Row holding one GeoTIFF blob per band.

        Returns:
            tuple: (float32 image tensor of shape (H, W, C), footprint polygon
            in the raster CRS, rasterio CRS of the raster).
        """
        img = []
        for band in self.embedder.bands:
            with MemoryFile(row[band][0].as_py()) as mem_f:
                with mem_f.open(driver='GTiff') as f:
                    crs = f.crs
                    footprint = box(*f.bounds)
                    img.append(f.read()[0])

        # Optional upsampling when bands come at different resolutions.
        shapes = [layer.shape for layer in img]
        if any(el != shapes[0] for el in shapes):
            h = max(el[0] for el in shapes)
            w = max(el[1] for el in shapes)
            for layer_idx, layer in enumerate(img):
                if layer.shape != (h, w):
                    # BUG FIX: cv2.resize expects dsize as (width, height);
                    # the original passed (h, w), which transposes the target
                    # size for non-square rasters.
                    img[layer_idx] = cv2.resize(layer, (w, h), interpolation=cv2.INTER_NEAREST)

        img = torch.from_numpy(np.stack(img, -1).astype(np.float32))
        return img, footprint, crs

    def forward(self, row, row_meta, device='cuda'):
        """Embed every fragment of one grid-cell sample.

        Reads the image, fragments it, computes an embedding per fragment and
        packs the result into a GeoDataFrame.

        Args:
            row (pandas.Series): Row with the raw image data per band.
            row_meta (pandas.Series): Sample metadata (timestamp, product_id,
                grid-cell identifiers).
            device (str): Device to run the embedder on ('cpu' or 'cuda').

        Returns:
            geopandas.GeoDataFrame: One row per fragment with its embedding
            and spatial metadata.
        """
        img, footprint, crs = self._read_image(row)

        # Fragment the sample into overlapping tiles plus pixel offsets.
        fragments, xys = fragment_fn(img, **self.frag_params, return_indices=True, verbose=False)
        nrows, ncols, c, h, w = fragments.shape

        with torch.no_grad():
            embeddings = self.embedder(fragments.reshape(-1, c, h, w).to(device)).view(nrows, ncols, -1)

        # Loop-invariant values, hoisted out of the per-fragment loop.
        transformer = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True)
        timestamp = row_meta.timestamp.item()
        product_id = row_meta.product_id.item()
        grid_cell = row_meta.grid_cell.item()
        grid_row_u = row_meta.grid_row_u.item()
        grid_col_r = row_meta.grid_col_r.item()

        df_rows = []
        for r_idx in range(nrows):
            for c_idx in range(ncols):
                embedding = embeddings[r_idx, c_idx].cpu().numpy()
                # Spatial features per fragment.
                x_offset, y_offset = xys[r_idx, c_idx].int().tolist()
                pixel_bbox = [x_offset, y_offset, x_offset + h, y_offset + w]  # in pixels
                utm_footprint = crop_footprint(footprint, *img.shape[:2], pixel_bbox)
                # The main footprint is WGS84 (must be consistent across parquet).
                geometry = transform(transformer.transform, utm_footprint)
                centre_lon, centre_lat = geometry.centroid.coords[0]

                df_rows.append({
                    'unique_id': self.calculate_checksum(geometry, timestamp, product_id, embedding),
                    'embedding': embedding,
                    'timestamp': timestamp,
                    'product_id': product_id,
                    'grid_cell': grid_cell,
                    'grid_row_u': grid_row_u,
                    'grid_col_r': grid_col_r,
                    'geometry': geometry,
                    'centre_lat': centre_lat,
                    'centre_lon': centre_lon,
                    'utm_footprint': utm_footprint.wkt,
                    'utm_crs': crs.to_string(),
                    'pixel_bbox': pixel_bbox,
                })

        return gpd.GeoDataFrame(df_rows).astype(self.column_types)
MajorTOM/embedder/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .MajorTOM_Embedder import *
2
+ from .grid_cell_fragment import *
MajorTOM/embedder/grid_cell_fragment.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+ from shapely.ops import transform
5
+ from pyproj import CRS, Transformer
6
+ import geopandas as gpd
7
+ import pandas as pd
8
+ import numpy as np
9
+ from shapely.geometry import Polygon, box
10
+ from rasterio.transform import from_bounds, xy
11
+ #from rasterio.windows import Window, from_bounds
12
+ import rasterio as rio
13
+
14
def crop_footprint(footprint, height, width, crop_bbox):
    """Crop a georeferenced footprint to a pixel-space bounding box.

    Args:
        footprint (shapely.geometry.Polygon): Footprint of the full image/area.
        height (int): Image height in pixels.
        width (int): Image width in pixels.
        crop_bbox (list): Pixel bounding box [col_start, row_start, col_end, row_end],
            where (col_start, row_start) is the top-left corner and
            (col_end, row_end) the bottom-right corner.

    Returns:
        shapely.geometry.Polygon: The crop's bounding box in the same CRS as
        the original footprint.
    """
    # Affine transform mapping (col, row) pixel coordinates to the footprint's
    # CRS. Renamed from 'transform' to avoid shadowing shapely.ops.transform,
    # which is imported at module level.
    affine = from_bounds(*footprint.bounds, width, height)

    # Convert pixel corners to spatial coordinates.
    # NOTE(review): with a north-up transform, row 0 maps to the *largest* y,
    # so the 'min_y'/'max_y' names are nominal; shapely's box still produces
    # the intended axis-aligned rectangle either way.
    min_x, min_y = affine * (crop_bbox[0], crop_bbox[1])  # (col_start, row_start)
    max_x, max_y = affine * (crop_bbox[2], crop_bbox[3])  # (col_end, row_end)

    return box(min_x, min_y, max_x, max_y)
40
+
41
def fragment_unfold(image, fragment_size, overlap):
    """Extract overlapping patches ("fragments") from an image.

    Args:
        image (torch.Tensor or np.ndarray): Input image. A numpy array is
            expected channel-last (H, W, C); a tensor is expected channel-first.
        fragment_size (int or list): Patch size; an int yields square patches,
            a pair of ints yields rectangular ones.
        overlap (int or list): Overlap between adjacent patches, as an int or
            a pair of ints.

    Returns:
        torch.Tensor: Unfolded patches of shape (num_patches, C, h, w), in
        row-major sliding order.
    """
    # Numpy input arrives channel-last; move channels first.
    if not torch.is_tensor(image):
        image = torch.from_numpy(image).permute(2, 0, 1)
    # Ensure a leading batch dimension for F.unfold.
    if len(image.shape) < 4:
        image = image.unsqueeze(0)

    batch, channels = image.shape[0], image.shape[1]

    # Promote scalar arguments to (height, width) pairs.
    if isinstance(fragment_size, int):
        fragment_size = [fragment_size] * 2
    if isinstance(overlap, int):
        overlap = [overlap] * 2

    # The stride between patch origins is size minus overlap.
    stride = [size - lap for size, lap in zip(fragment_size, overlap)]

    patches = torch.nn.functional.unfold(
        image, fragment_size, dilation=1, padding=0, stride=stride
    )

    # (B, C*h*w, N) -> (N, C, h, w) for the first (only) batch element.
    patches = patches.view(batch, channels, *fragment_size, -1)
    return patches.permute(0, 4, 1, 2, 3)[0]
78
+
79
def fragment_fn(img,
                fragment_size,
                target_overlap,
                border_shift=True, # determines whether the outer border is shifted to ensure full coverage
                return_indices=False,
                verbose=False
                ):
    """
    Fragment a square image into smaller patches with a specified fragment size and overlap.

    Handles several scenarios based on image vs. fragment size: rejects images
    smaller than the fragment (Scenario 1), returns the whole image as a single
    fragment when it matches the fragment size exactly (Scenario 2), and
    otherwise unfolds overlapping fragments, optionally shifting the last
    row/column of fragments so the image border is fully covered (Scenarios 3/4).

    Args:
        img (np.ndarray or torch.Tensor): Input image, channel-last (height, width, channels).
        fragment_size (int): Side length of the (square) fragments.
        target_overlap (float): Target fractional overlap between adjacent fragments.
        border_shift (bool): Shift the outermost fragments to the image border for full coverage. Default True.
        return_indices (bool): Also return the (pixel) offsets for each fragment. Default False.
        verbose (bool): Print the achieved overlap. Default False.

    Returns:
        torch.Tensor or tuple:
            - fragments of shape (n_rows, n_cols, c, hf, wf) if `return_indices` is False;
            - (fragments, offsets) where offsets has shape (n_rows, n_cols, 2), otherwise.
    """

    h,w,c=img.shape

    assert h==w # SQUARE IMAGES SUPPORT ONLY

    # Square fragments; overlap is expressed in pixels (fraction * size).
    hf, wf = fragment_size, fragment_size
    ho, wo = target_overlap*hf, target_overlap*wf

    assert h >= hf and w >= wf # reject Scenario 1

    # Scenario 2: the image is exactly one fragment.
    if h == hf or w == wf:
        if not torch.is_tensor(img):
            img=torch.from_numpy(img).permute(2,0,1)
        # NOTE(review): if img was already a (h, w, c) tensor, .view(1,1,c,h,w)
        # reinterprets memory instead of permuting axes — likely a bug for
        # tensor inputs; confirm callers never hit this branch with tensors.
        return img.view(1,1,c,h,w)

    # Scenario 3 & 4

    # determine number of segments between the centers of outermost fragments
    h_n = max(1, int(np.round((h-hf)/(hf-ho))))
    w_n = max(1, int(np.round((w-wf)/(wf-wo))))

    # adjust practical overlap (divide the distance between the centers of outermost fragments by the true number of segments)
    aho = int(np.ceil(hf-(h-hf)/(h_n)))
    awo = int(np.ceil(wf-(w-wf)/(w_n)))

    # compute fragments (might not exactly fill the outermost border)
    topleft = fragment_unfold(img.permute(2,0,1),fragment_size=(hf,wf), overlap=(aho,awo)).view(1+h_n, 1+w_n, c, hf, wf)

    full = topleft

    if border_shift:

        # Only needed when the regular grid does not reach the image border.
        if h > hf+h_n*(hf-aho) or w > wf+w_n*(wf-awo):
            #print('Outers...')
            # Fragments flush with the bottom edge / right edge of the image.
            bottomleft = fragment_unfold(img[-hf:,:,:],fragment_size=(hf,wf), overlap=(aho,awo)).view(1,1+w_n,c,hf,wf)
            topright = fragment_unfold(img[:,-wf:,:],fragment_size=(hf,wf), overlap=(aho,awo)).view(1+h_n,1,c,hf,wf)

            # Shift last row and col to the border of the original
            full[:,-1,None] = topright
            full[-1] = bottomleft

    if verbose:
        print('Target Overlap: {} pixels. Feasible Overlap: {} pixels.'.format(ho,aho))

    if not return_indices:
        return full
    else:
        # Offsets initialized to -1 (sentinel), then filled per grid position.
        offset=-1*torch.ones(*full.shape[:2],2)
        for ridx in range(full.shape[0]):
            for cidx in range(full.shape[1]):
                # NOTE(review): channel 0 grows with the row index and channel 1
                # with the column index, while the caller unpacks them as
                # (x_offset, y_offset); likewise the border_shift fix below
                # mixes h/w with column/row. This is only consistent because
                # square images (h == w, hf == wf, aho == awo) are asserted
                # above — confirm before lifting the square-image restriction.
                offset[ridx,cidx,1] = cidx * (hf-aho)
                offset[ridx,cidx,0] = ridx * (wf-awo)

                if border_shift:
                    # Outermost fragments were shifted flush with the border.
                    offset[ridx,-1,1] = h-hf
                    offset[-1,cidx,0] = w-wf

        return full,offset
MajorTOM/embedder/models/DINOv2_S2RGB.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoImageProcessor, AutoModel
3
+
4
class DINOv2_S2RGB_Embedder(torch.nn.Module):
    """DINOv2-based embedder for Sentinel-2 RGB (true-color) data.

    Raw Sentinel-2 reflectance (bands B04/B03/B02) is mapped to true-color
    values in [0, 1] (divide by 10,000, multiply by 2.5, clip), preprocessed
    by the DINOv2 image processor, and passed through the 'facebook/dinov2-base'
    transformer. The token embeddings are averaged over the sequence dimension
    to yield one fixed-size vector per image.

    Attributes:
        processor (AutoImageProcessor): DINOv2 preprocessing pipeline.
        model (AutoModel): Pre-trained DINOv2 backbone.
        bands (list): Sentinel-2 bands used as RGB input (B04, B03, B02).
        size (tuple): Expected input size (height, width), taken from the
            processor's crop settings.
    """

    def __init__(self):
        """Load the pre-trained DINOv2 processor/model and record the input spec."""
        super().__init__()

        # Pre-trained DINOv2 components from Hugging Face.
        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.model = AutoModel.from_pretrained('facebook/dinov2-base')

        # Sentinel-2 RGB bands: B04 (red), B03 (green), B02 (blue).
        self.bands = ['B04', 'B03', 'B02']

        # Input size comes from the processor's crop configuration.
        crop = self.processor.crop_size
        self.size = crop['height'], crop['width']

    def normalize(self, input):
        """Map raw Sentinel-2 reflectance to true-color values in [0, 1].

        Division by 10,000 converts digital numbers to reflectance; the 2.5
        gain brightens to a conventional true-color rendering before clipping.

        Args:
            input (torch.Tensor): Raw Sentinel-2 image tensor.

        Returns:
            torch.Tensor: Normalized true-color image in [0, 1].
        """
        reflectance = input / 1e4
        return (2.5 * reflectance).clip(0, 1)

    def forward(self, input):
        """Compute a fixed-size embedding for the input image.

        Normalizes the image, runs the DINOv2 processor and backbone, and
        averages the last hidden states over the sequence dimension.

        Args:
            input (torch.Tensor): Sentinel-2 image tensor [C, H, W], C=3 (RGB).

        Returns:
            torch.Tensor: Embedding of shape [embedding_dim], on CPU.
        """
        batch = self.processor(self.normalize(input), return_tensors="pt")
        pixel_values = batch['pixel_values'].to(self.model.device)
        hidden = self.model(pixel_values).last_hidden_state
        return hidden.mean(dim=1).cpu()
MajorTOM/embedder/models/SSL4EO_S1RTC.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchgeo.models import ResNet50_Weights
3
+ import timm
4
+ import numpy as np
5
+
6
class SSL4EO_S1RTC_Embedder(torch.nn.Module):
    """SSL4EO embedder for Sentinel-1 RTC (VV/VH) data.

    Wraps a ResNet50 backbone pre-trained on Sentinel-1 SAR data via SSL4EO
    (Self-Supervised Learning for Earth Observation); the classification head
    is removed so the forward pass yields feature embeddings directly.

    Project Code:
        https://github.com/zhu-xlab/SSL4EO-S12

    Publication:
        https://arxiv.org/abs/2211.07044
    """

    def __init__(self, s1_mean=(-12.54847273, -20.19237134), s1_std=(5.25697717, 5.91150917)):
        """Set up the normalization statistics and load the pre-trained model.

        Args:
            s1_mean (sequence, optional): Per-band (VV, VH) mean in dB.
                Defaults to SSL4EO's values. (Tuples instead of the original
                mutable list defaults; converted to tensors below.)
            s1_std (sequence, optional): Per-band (VV, VH) std in dB.
                Defaults to SSL4EO's values.

        Attributes:
            s1_mean (torch.FloatTensor): Mean values for normalization.
            s1_std (torch.FloatTensor): Std values for normalization.
            model (torch.nn.Module): ResNet50 with pre-trained SSL4EO weights.
            bands (list): Input band names ('vv', 'vh'), in order.
            size (tuple): Expected input size, 224x224 pixels.
        """
        super().__init__()

        self.s1_mean = torch.FloatTensor(s1_mean)
        self.s1_std = torch.FloatTensor(s1_std)

        self.model = self.init_model()
        self.bands = ['vv', 'vh']
        self.size = 224, 224

    def init_model(self):
        """Create a ResNet50 with SSL4EO Sentinel-1 MoCo weights.

        Loads `ResNet50_Weights.SENTINEL1_ALL_MOCO` and replaces the fully
        connected head with Identity so the model outputs the pooled feature
        vector of the last convolutional stage.

        Returns:
            torch.nn.Module: The initialized backbone.
        """
        weights = ResNet50_Weights.SENTINEL1_ALL_MOCO
        model = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
        model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
        model.fc = torch.nn.Identity()
        return model

    def normalize(self, img, scale=1.0):
        """Normalize dB-scale SAR data to [0, scale] using mean ± 2·std.

        Each band is mapped linearly so that (mean - 2·std) -> 0 and
        (mean + 2·std) -> scale, then clipped to that range.

        Args:
            img (torch.Tensor): dB-scale image, channel dimension matching
                s1_mean/s1_std (works for (C, H, W) and (B, C, H, W) inputs
                via broadcasting).
            scale (float, optional): Upper bound of the output range. Default 1.0.

        Returns:
            torch.Tensor: Normalized, clipped float tensor.
        """
        min_value = (self.s1_mean - 2 * self.s1_std).to(img.device)
        max_value = (self.s1_mean + 2 * self.s1_std).to(img.device)
        img = (img - min_value[:, None, None]) / (max_value - min_value)[:, None, None] * scale
        return img.clip(0, scale).float()

    def preprocess(self, input):
        """Convert linear SAR backscatter to dB scale and normalize.

        Args:
            input (torch.Tensor): Linear-scale Sentinel-1 image (VV/VH).

        Returns:
            torch.Tensor: Normalized dB-scale image.
        """
        # BUG FIX: Tensor.log10() takes no arguments — the original
        # `input.log10(input.clip(min=1e-10))` raised a TypeError at runtime.
        dB_input = 10 * torch.log10(input.clip(min=1e-10))  # clip prevents log(0)
        return self.normalize(dB_input)

    def forward(self, input):
        """Preprocess the input and return the backbone embedding.

        Args:
            input (torch.Tensor): Linear-scale Sentinel-1 image (e.g. [C, H, W]).

        Returns:
            torch.Tensor: Output embedding from the model.
        """
        return self.model(self.preprocess(input))
MajorTOM/embedder/models/SSL4EO_S2L1C.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchgeo.models import ResNet50_Weights
3
+ import timm
4
+
5
class SSL4EO_S2L1C_Embedder(torch.nn.Module):
    """
    Sentinel-2 L1C embedder built on SSL4EO pre-trained weights.

    Wraps a ResNet50 backbone carrying torchgeo's
    `ResNet50_Weights.SENTINEL2_ALL_DINO` weights (self-supervised DINO
    pre-training on Sentinel-2 data) and exposes it as a feature extractor.

    Project Code:
        https://github.com/zhu-xlab/SSL4EO-S12

    Publication:
        https://arxiv.org/abs/2211.07044
    """

    def __init__(self):
        """
        Load the pre-trained backbone and record the expected inputs.

        Attributes:
            model (torch.nn.Module): ResNet50 with Sentinel-2 pre-trained weights
                and its classification head removed.
            bands (list): Sentinel-2 L1C band names expected as input channels.
            size (tuple): Input spatial resolution expected by the model (224x224).
        """
        super().__init__()

        # Pre-trained SSL4EO ResNet50 backbone.
        self.model = self.init_model()

        # All 13 Sentinel-2 L1C spectral bands, in sensor order.
        self.bands = [
            'B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07',
            'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'
        ]

        # Input resolution (H, W) expected by the backbone.
        self.size = (224, 224)

    def init_model(self):
        """
        Build the ResNet50 backbone and load the Sentinel-2 weights.

        The network is created with `timm` using the channel count declared by
        the torchgeo weight metadata, the pre-trained state dict is loaded
        (non-strict, since the classifier weights are not needed), and the
        fully connected head is replaced with an identity so that the forward
        pass yields pooled features rather than class logits.

        Returns:
            torch.nn.Module: The initialized ResNet50 feature extractor.
        """
        weights = ResNet50_Weights.SENTINEL2_ALL_DINO
        backbone = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
        backbone.load_state_dict(weights.get_state_dict(progress=True), strict=False)
        # Drop the classification head to obtain embeddings directly.
        backbone.fc = torch.nn.Identity()
        return backbone

    def preprocess(self, input):
        """
        Scale raw Sentinel-2 reflectance values into the model's input range.

        Raw reflectances (roughly 0..10,000) are divided by 10,000 so the
        values land in approximately [0, 1].

        Args:
            input (torch.Tensor): Sentinel-2 reflectance image (e.g., shape [C, H, W]).

        Returns:
            torch.Tensor: Input divided by 10,000.
        """
        return input / 1e4

    def forward(self, input):
        """
        Embed a Sentinel-2 L1C image.

        Args:
            input (torch.Tensor): Raw Sentinel-2 image (e.g., shape [C, H, W]).

        Returns:
            torch.Tensor: The embedding produced by the backbone.
        """
        return self.model(self.preprocess(input))
MajorTOM/embedder/models/SigLIP_S2RGB.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from open_clip import create_model_from_pretrained, get_tokenizer
2
+ import torch
3
+
4
class SigLIP_S2RGB_Embedder(torch.nn.Module):
    """
    Embedding wrapper around SigLIP for Sentinel-2 RGB data.

    Sentinel-2 reflectances are first mapped to a True-Colour image
    (divide by 10,000, multiply by 2.5, clip to [0, 1]) and then passed
    through the SigLIP image tower to produce a feature vector.

    Preprocessing:
        - Sentinel-2 bands are divided by 10,000 to scale the reflectance values.
        - The values are multiplied by 2.5 and clipped into [0, 1].
        - The model's own final normalization transform is applied.

    Model:
        - Takes an RGB input of 384x384 pixels and produces an embedding vector.
    """

    def __init__(self):
        super().__init__()

        # Pre-trained SigLIP (ViT-SO400M-14, 384 px) and its transform pipeline.
        self.model, self.preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
        # Sentinel-2 RGB bands (B04 - Red, B03 - Green, B02 - Blue)
        self.bands = ['B04', 'B03', 'B02']
        # Input resolution taken from the first transform of the pipeline.
        self.size = self.preprocess.transforms[0].size

    def normalize(self, input):
        """
        Map Sentinel-2 reflectances to a True-Colour image in [0, 1].

        Divides by 10,000 to scale raw reflectance values, multiplies by 2.5
        to brighten into the True-Colour range, and clips to [0, 1].

        Args:
            input (torch.Tensor or np.ndarray): Sentinel-2 reflectance values.

        Returns:
            torch.Tensor: Normalized True-Colour image, clipped to [0, 1].
        """
        scaled = input / 1e4
        return (2.5 * scaled).clip(0, 1)

    def forward(self, input):
        """
        Embed a Sentinel-2 RGB image with SigLIP.

        Args:
            input (torch.Tensor): A Sentinel-2 image of shape (C, H, W) with
                C=3 (RGB), H=384, W=384.

        Returns:
            torch.Tensor: The image embedding produced by the model.
        """
        true_colour = self.normalize(input)

        # Only the pipeline's final (normalization) transform is applied;
        # resizing is expected to have happened upstream.
        model_input = self.preprocess.transforms[-1](true_colour)

        return self.model.encode_image(model_input)
MajorTOM/embedder/models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .SigLIP_S2RGB import *
2
+ from .DINOv2_S2RGB import *
3
+ from .SSL4EO_S2L1C import *
4
+ from .SSL4EO_S1RTC import *
MajorTOM/extras/coverage-example.png ADDED

Git LFS Details

  • SHA256: a2ed4c9e1b6516b07b803cdced733213d3db3692665c119814fb495089231627
  • Pointer size: 132 Bytes
  • Size of remote file: 2.97 MB
MajorTOM/extras/coverage_vis.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from mpl_toolkits.basemap import Basemap
5
+ import PIL
6
+
7
def get_mask(df):
    """
    Build the binary coverage mask of a Major TOM dataframe.

    Produces a 2004x4008 uint8 image where cells present in `df` (and not
    flagged as nodata) are 255 and everything else is 0.

    Args:
        df: Major TOM metadata dataframe with `grid_row_u`, `grid_col_r`
            and `nodata` columns.

    Returns:
        PIL.Image: the coverage mask.
    """
    mask = np.zeros((2004, 4008), dtype=np.uint8)
    row_offset = -1002
    col_offset = -2004

    # keep only cells whose nodata fraction is at most 0.5
    valid = ~(df['nodata'].values > 0.5)

    # grid rows count upwards; flip vertically into image coordinates
    ys = mask.shape[0] - (np.array(df['grid_row_u']) - row_offset) - 1
    xs = np.array(df['grid_col_r']) - col_offset

    mask[ys[valid], xs[valid]] = 255

    return PIL.Image.fromarray(mask)
27
+
28
def fig2img(fig):
    """Render a Matplotlib figure into an in-memory PIL Image."""
    import io
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    # The buffer is intentionally left open: PIL reads from it lazily.
    return PIL.Image.open(buffer)
36
+
37
def light_basemap():
    """
    Render a bright-coloured sinusoidal world basemap.

    Continents are filled light green, water light blue, with grey country
    borders and coastlines. All axes decorations and margins are removed so
    the figure contains only the map itself.

    Returns:
        PIL.Image: the rendered basemap.
    """

    with plt.ioff():  # suppress interactive display while rendering
        fig, ax = plt.subplots(figsize=(48,24), dpi=167)

        m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
        m.fillcontinents(color="#9eba9b", lake_color='#CCDDFF')
        m.drawmapboundary(fill_color="#CCDDFF")
        m.drawcountries(color="#666666", linewidth=1)
        m.drawcoastlines(color="#666666", linewidth=1)

        # strip axes, padding and margins so only the map area is drawn
        plt.gca().set_axis_off()
        plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
                    hspace = 0, wspace = 0)
        plt.margins(0,0)

    return fig2img(fig)
57
+
58
def dark_basemap():
    """
    Render a dark sinusoidal world basemap.

    Land and water are drawn in near-black tones with black borders and
    coastlines; used as the background layer behind the bright coverage
    overlay. All axes decorations and margins are removed.

    Returns:
        PIL.Image: the rendered basemap.
    """

    with plt.ioff():  # suppress interactive display while rendering
        fig, ax = plt.subplots(figsize=(48,24), dpi=167)

        m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
        m.fillcontinents(color="#242424", lake_color='#242424')
        m.drawmapboundary(fill_color="#242424")
        m.drawcountries(color="#000000", linewidth=1)
        m.drawcoastlines(color="#000000", linewidth=1)

        # strip axes, padding and margins so only the map area is drawn
        plt.gca().set_axis_off()
        plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
                    hspace = 0, wspace = 0)
        plt.margins(0,0)

    return fig2img(fig)
78
+
79
def get_coveragemap(input, input2=None):
    """
    Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation

    Optionally, input2 can be provided and then, the map plots a map with extra colours indicating cells available only in input (green) or only input2 (blue); cells present in both appear red.
    """

    if input2 is None:
        return single_coveragemap(input)
    else:
        cmap1 = single_coveragemap(input)
        cmap2 = single_coveragemap(input2)

        # arrays for mixing (drop the alpha channel)
        inp1_arr = np.array(cmap1)[...,:3]
        inp2_arr = np.array(cmap2)[...,:3]

        # Pixels whose channel sums match in both maps are treated as common
        # coverage. NOTE(review): two different colours with equal channel
        # sums would also match — confirm this is acceptable here.
        common_arr = inp1_arr*(inp1_arr.sum(-1) == inp2_arr.sum(-1))[:,:,None]
        common_arr[:,:,(1,2)] = 0  # keep only the red channel for common cells
        inp1_arr[:,:,(0,2)] = 0 # Green - indicates presence in input only
        inp2_arr[:,:,(0,1)] = 0 # Blue - indicates presence in input2 only

        return PIL.Image.fromarray(((common_arr + inp1_arr + inp2_arr)).astype(np.uint8))
102
+
103
+
104
def single_coveragemap(input):
    """
    Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation

    Args:
        input: either a Major TOM metadata DataFrame (a mask is computed via
            `get_mask`) or an already-computed PIL mask image.

    Returns:
        PIL.Image: dark basemap with covered cells shown in bright colours.
    """

    # compute mask if df is provided
    if isinstance(input, pd.DataFrame):
        mask = get_mask(input)
    else:
        mask = input

    basemap = light_basemap()
    basemap_d = dark_basemap()

    # Pixels with a saturated red channel in the light basemap are taken to
    # lie outside the globe outline — presumably the white page background.
    outside_earth = np.array(basemap.convert('RGBA'))[:, :, 0] == 255
    outside_earth = PIL.Image.fromarray(outside_earth)

    # nearest-neighbour keeps the mask binary when upscaling to map size
    mask = mask.resize(basemap.size, PIL.Image.NEAREST)

    # show the bright basemap only where coverage exists
    basemap.putalpha(mask)

    # Mask outside of earth
    basemap.paste(outside_earth, (0,0), outside_earth)

    # composite the bright coverage layer over the dark background
    basemap_d.paste(basemap, (0,0), basemap)

    return basemap_d
131
+
132
if __name__ == '__main__':
    # Example: single-dataset coverage map from the hosted metadata parquet.
    DATASET_NAME = 'Major-TOM/Core-S2L2A'
    meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
    df = pd.read_parquet(meta_path)

    # This is how you make a coverage figure!
    coverage_img = get_coveragemap(df)

    coverage_img.save('coverage-example.png', format='PNG')

    # and this is how you can create an overlap map for 2 datasets!
    DATASET_NAME = 'Major-TOM/Core-DEM'
    meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
    dem_df = pd.read_parquet(meta_path)

    coverage_img = get_coveragemap(df,dem_df)

    coverage_img.save('overlap-coverage-example.png', format='PNG')
MajorTOM/extras/extract-sample-from-raw-S2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
MajorTOM/extras/thumbnail_dem.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
3
+
4
+ Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
5
+ """
6
+
7
+ from rasterio.io import MemoryFile
8
+ from PIL import Image
9
+ import numpy as np
10
+ import os
11
+ from pathlib import Path
12
+ import rasterio as rio
13
+ from matplotlib.colors import LightSource
14
+
15
+ def get_grayscale(x):
16
+ """
17
+ Normalized grayscale visualisation
18
+ """
19
+
20
+ # normalize
21
+ x_n = x-x.min()
22
+ x_n = x_n/x_n.max()
23
+
24
+ return np.uint8(x_n*255)
25
+
26
+ def get_hillshade(x, azdeg=315, altdeg=45,ve=1):
27
+ """
28
+ Hillshade visualisation for DEM
29
+ """
30
+ ls = LightSource(azdeg=azdeg, altdeg=altdeg)
31
+
32
+ return np.uint8(255*ls.hillshade(x, vert_exag=ve))
33
+
34
+ def dem_thumbnail(dem, dem_NODATA = -32768.0, hillshade=True):
35
+ """
36
+ Takes vv and vh numpy arrays along with the corresponding NODATA values (default is -32768.0)
37
+
38
+ Returns a numpy array with the thumbnail
39
+ """
40
+ if hillshade:
41
+ return get_hillshade(dem)
42
+ else:
43
+ return get_grayscale(dem)
44
+
45
+
46
+ def dem_thumbnail_from_datarow(datarow):
47
+ """
48
+ Takes a datarow directly from one of the data parquet files
49
+
50
+ Returns a PIL Image
51
+ """
52
+
53
+ with MemoryFile(datarow['DEM'][0].as_py()) as mem_f:
54
+ with mem_f.open(driver='GTiff') as f:
55
+ dem=f.read().squeeze()
56
+ dem_NODATA = f.nodata
57
+
58
+ img = dem_thumbnail(dem, dem_NODATA)
59
+
60
+ return Image.fromarray(img,'L')
61
+
62
+ if __name__ == '__main__':
63
+ from fsspec.parquet import open_parquet_file
64
+ import pyarrow.parquet as pq
65
+
66
+ print('[example run] reading file from HuggingFace...')
67
+ url = "https://huggingface.co/datasets/Major-TOM/Core-DEM/resolve/main/images/part_01001.parquet"
68
+ with open_parquet_file(url) as f:
69
+ with pq.ParquetFile(f) as pf:
70
+ first_row_group = pf.read_row_group(1)
71
+
72
+ print('[example run] computing the thumbnail...')
73
+ thumbnail = dem_thumbnail_from_datarow(first_row_group)
74
+
75
+ thumbnail_fname = 'example_thumbnail.png'
76
+ thumbnail.save(thumbnail_fname, format = 'PNG')
77
+ print('[example run] saved as "{}"'.format(thumbnail_fname))
MajorTOM/extras/thumbnail_s1rtc.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
3
+
4
+ Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
5
+ """
6
+
7
+ from rasterio.io import MemoryFile
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
def s1rtc_thumbnail(vv, vh, vv_NODATA = -32768.0, vh_NODATA = -32768.0):
    """
    Build a false-colour RGB thumbnail from VV and VH backscatter arrays.

    Channels: R = scaled VV (dB), G = scaled VV+VH sum, B = scaled VH (dB).
    Nodata pixels are rendered as 0. The input arrays are left unmodified.

    Args:
        vv (np.ndarray): VV polarization, linear backscatter scale.
        vh (np.ndarray): VH polarization, linear backscatter scale.
        vv_NODATA (float): VV nodata sentinel value (default -32768.0).
        vh_NODATA (float): VH nodata sentinel value (default -32768.0).

    Returns:
        np.ndarray: (H, W, 3) uint8 thumbnail.
    """

    # valid data masks
    vv_mask = vv != vv_NODATA
    vh_mask = vh != vh_NODATA

    # work on copies so the caller's arrays are not mutated when clamping
    vv = vv.copy()
    vh = vh.copy()

    # remove invalid (negative) values before the log op
    vv[vv<0] = vv[vv>=0].min()
    vh[vh<0] = vh[vh>=0].min()

    # apply log op (linear -> dB)
    vv_dB = 10*np.log10(vv)
    vh_dB = 10*np.log10(vh)

    # scale the valid-data range to 0-255
    vv_dB = (vv_dB - vv_dB[vv_mask].min()) / (vv_dB[vv_mask].max() - vv_dB[vv_mask].min()) * 255
    vh_dB = (vh_dB - vh_dB[vh_mask].min()) / (vh_dB[vh_mask].max() - vh_dB[vh_mask].min()) * 255

    # represent nodata as 0
    vv_dB[vv_mask==0] = 0
    vh_dB[vh_mask==0] = 0

    # false colour composite
    return np.stack([vv_dB,
                     255*(vv_dB+vh_dB)/np.max(vv_dB+vh_dB),
                     vh_dB
                    ],-1).astype(np.uint8)
43
+
44
def s1rtc_thumbnail_from_datarow(datarow):
    """
    Takes a datarow directly from one of the data parquet files

    The 'vv' and 'vh' columns hold in-memory GeoTIFF blobs (pyarrow values);
    each is decoded with rasterio before being passed to `s1rtc_thumbnail`.

    Returns a PIL Image
    """

    # decode the VV GeoTIFF bytes entirely in memory
    with MemoryFile(datarow['vv'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            vv=f.read().squeeze()
            vv_NODATA = f.nodata

    # decode the VH GeoTIFF bytes entirely in memory
    with MemoryFile(datarow['vh'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            vh=f.read().squeeze()
            vh_NODATA = f.nodata

    img = s1rtc_thumbnail(vv, vh, vv_NODATA=vv_NODATA, vh_NODATA=vh_NODATA)

    return Image.fromarray(img)
64
+
65
if __name__ == '__main__':
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    # Example: stream a single row group of a hosted Core-S1RTC parquet file
    # and render its thumbnail (avoids downloading the whole file).
    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-S1RTC/resolve/main/images/part_00001.parquet"
    with open_parquet_file(url) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1)

    print('[example run] computing the thumbnail...')
    thumbnail = s1rtc_thumbnail_from_datarow(first_row_group)

    thumbnail_fname = 'example_thumbnail.png'
    thumbnail.save(thumbnail_fname, format = 'PNG')
    print('[example run] saved as "{}"'.format(thumbnail_fname))
MajorTOM/extras/thumbnail_s2.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
3
+
4
+ Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
5
+ """
6
+
7
+ from rasterio.io import MemoryFile
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+ def s2l2a_thumbnail(B04, B03, B02, gain=1.3, gamma=0.6):
12
+ """
13
+ Takes B04, B03, B02 numpy arrays along with the corresponding NODATA values (default is -32768.0)
14
+
15
+ Returns a numpy array with the thumbnail
16
+ """
17
+
18
+ # concatenate
19
+ thumb = np.stack([B04, B03, B02], -1)
20
+
21
+ # apply gain & gamma
22
+ thumb = gain*((thumb/10_000)**gamma)
23
+
24
+ return (thumb.clip(0,1)*255).astype(np.uint8)
25
+
26
+ def s2l2a_thumbnail_from_datarow(datarow):
27
+ """
28
+ Takes a datarow directly from one of the data parquet files
29
+
30
+ Returns a PIL Image
31
+ """
32
+
33
+ # red
34
+ with MemoryFile(datarow['B04'][0].as_py()) as mem_f:
35
+ with mem_f.open(driver='GTiff') as f:
36
+ B04=f.read().squeeze()
37
+ B04_NODATA = f.nodata
38
+
39
+ # green
40
+ with MemoryFile(datarow['B03'][0].as_py()) as mem_f:
41
+ with mem_f.open(driver='GTiff') as f:
42
+ B03=f.read().squeeze()
43
+ B03_NODATA = f.nodata
44
+
45
+ # blue
46
+ with MemoryFile(datarow['B02'][0].as_py()) as mem_f:
47
+ with mem_f.open(driver='GTiff') as f:
48
+ B02=f.read().squeeze()
49
+ B02_NODATA = f.nodata
50
+
51
+ img = s2l2a_thumbnail(B04,B03,B02)
52
+
53
+ return Image.fromarray(img)
54
+
55
+ if __name__ == '__main__':
56
+ from fsspec.parquet import open_parquet_file
57
+ import pyarrow.parquet as pq
58
+
59
+ print('[example run] reading file from HuggingFace...')
60
+ url = "https://huggingface.co/datasets/Major-TOM/Core-S2L2A/resolve/main/images/part_01000.parquet"
61
+ with open_parquet_file(url, columns = ["B04", "B03", "B02"]) as f:
62
+ with pq.ParquetFile(f) as pf:
63
+ first_row_group = pf.read_row_group(1, columns = ["B04", "B03", "B02"])
64
+
65
+ print('[example run] computing the thumbnail...')
66
+ thumbnail = s2l2a_thumbnail_from_datarow(first_row_group)
67
+
68
+ thumbnail.save('example_thumbnail.png', format = 'PNG')
MajorTOM/grid.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+ import pandas as pd
4
+ import geopandas as gpd
5
+ from shapely.geometry import LineString, Polygon
6
+ from tqdm import tqdm
7
+ import re
8
+
9
+
10
+
11
class Grid():
    """
    Major TOM global grid of approximately equidistant points.

    Rows are latitude bands labelled `0U..NU` (equator and northwards) and
    `1D..ND` (southwards); columns are labelled `0R..NR` (Greenwich meridian
    and eastwards) and `1L..NL` (westwards). Each grid point is the
    bottom-left corner of its cell.
    """

    # Earth's equatorial radius (WGS84).
    RADIUS_EQUATOR = 6378.137 # km

    def __init__(self,dist,latitude_range=(-85,85),longitude_range=(-180,180),utm_definition='bottomleft'):
        """
        Args:
            dist (float): Target spacing between neighbouring points in km.
            latitude_range (tuple): Inclusive (min, max) latitude bounds.
            longitude_range (tuple): Inclusive (min, max) longitude bounds.
            utm_definition (str): 'bottomleft' or 'center' — which point of a
                cell is used to assign its UTM zone / EPSG code.
        """
        self.dist = dist
        self.latitude_range = latitude_range
        self.longitude_range = longitude_range
        self.utm_definition = utm_definition
        self.rows,self.lats = self.get_rows()
        self.points, self.points_by_row = self.get_points()

    def get_rows(self):
        """
        Compute the row labels and latitudes of all grid rows.

        Returns:
            (np.ndarray, np.ndarray): row names (e.g. '3U', '2D') and their
            latitudes, ordered south to north and bounded to `latitude_range`.
        """
        # Define set of latitudes to use, based on the grid distance
        arc_pole_to_pole = math.pi * self.RADIUS_EQUATOR
        num_divisions_in_hemisphere = math.ceil(arc_pole_to_pole / self.dist)

        latitudes = np.linspace(-90, 90, num_divisions_in_hemisphere+1)[:-1]
        latitudes = np.mod(latitudes, 180) - 90

        # order should be from south to north
        latitudes = np.sort(latitudes)

        # index of the first non-negative latitude (the '0U' row)
        zeroth_row = np.searchsorted(latitudes,0)

        # From 0U-NU and 1D-ND
        rows = [None] * len(latitudes)
        rows[zeroth_row:] = [f'{i}U' for i in range(len(latitudes)-zeroth_row)]
        rows[:zeroth_row] = [f'{abs(i-zeroth_row)}D' for i in range(zeroth_row)]

        # bound to range
        idxs = (latitudes>=self.latitude_range[0]) * (latitudes<=self.latitude_range[1])
        rows,latitudes = np.array(rows), np.array(latitudes)
        rows,latitudes = rows[idxs],latitudes[idxs]

        return rows,latitudes

    def get_circumference_at_latitude(self,lat):
        """Circumference (km) of the Earth's cross-section at latitude `lat`."""

        # Circumference of the cross-section of a sphere at a given latitude

        radius_at_lat = self.RADIUS_EQUATOR * math.cos(lat * math.pi / 180)
        circumference = 2 * math.pi * radius_at_lat

        return circumference

    def subdivide_circumference(self,lat,return_cols=False):
        """
        Subdivide the parallel at `lat` into arcs as close as possible to
        `dist` km, returning the resulting longitudes (and, optionally, the
        matching column labels such as '2R' / '1L').
        """
        # Provide a list of longitudes that subdivide the circumference of the earth at a given latitude
        # into equal parts as close as possible to dist

        circumference = self.get_circumference_at_latitude(lat)
        num_divisions = math.ceil(circumference / self.dist)
        longitudes = np.linspace(-180,180, num_divisions+1)[:-1]
        longitudes = np.mod(longitudes, 360) - 180
        longitudes = np.sort(longitudes)


        if return_cols:
            # label columns outwards from the 0-degree meridian
            cols = [None] * len(longitudes)
            zeroth_idx = np.where(longitudes==0)[0][0]
            cols[zeroth_idx:] = [f'{i}R' for i in range(len(longitudes)-zeroth_idx)]
            cols[:zeroth_idx] = [f'{abs(i-zeroth_idx)}L' for i in range(zeroth_idx)]
            return np.array(cols),np.array(longitudes)

        return np.array(longitudes)

    def get_points(self):
        """
        Build the full point table.

        Returns:
            (gpd.GeoDataFrame, list[gpd.GeoDataFrame]): all grid points with
            their names, row/column labels and indices, UTM zone and EPSG
            code; plus the same data split into one GeoDataFrame per row.
        """
        r_idx = 0
        points_by_row = [None]*len(self.rows)
        for r,lat in zip(self.rows,self.lats):
            point_names,grid_row_names,grid_col_names,grid_row_idx,grid_col_idx,grid_lats,grid_lons,utm_zones,epsgs = [],[],[],[],[],[],[],[],[]
            cols,lons = self.subdivide_circumference(lat,return_cols=True)

            cols,lons = self.filter_longitude(cols,lons)
            c_idx = 0
            for c,lon in zip(cols,lons):
                point_names.append(f'{r}_{c}')
                grid_row_names.append(r)
                grid_col_names.append(c)
                grid_row_idx.append(r_idx)
                grid_col_idx.append(c_idx)
                grid_lats.append(lat)
                grid_lons.append(lon)
                if self.utm_definition == 'bottomleft':
                    utm_zones.append(get_utm_zone_from_latlng([lat,lon]))
                elif self.utm_definition == 'center':
                    # offset by half a cell; dist is km, 111,120 m per degree
                    center_lat = lat + (1000*self.dist/2)/111_120
                    center_lon = lon + (1000*self.dist/2)/(111_120*math.cos(center_lat*math.pi/180))
                    utm_zones.append(get_utm_zone_from_latlng([center_lat,center_lon]))
                else:
                    raise ValueError(f'Invalid utm_definition {self.utm_definition}')
                epsgs.append(f'EPSG:{utm_zones[-1]}')

                c_idx += 1
            points_by_row[r_idx] = gpd.GeoDataFrame({
                        'name':point_names,
                        'row':grid_row_names,
                        'col':grid_col_names,
                        'row_idx':grid_row_idx,
                        'col_idx':grid_col_idx,
                        'utm_zone':utm_zones,
                        'epsg':epsgs
                        },geometry=gpd.points_from_xy(grid_lons,grid_lats))
            r_idx += 1
        points = gpd.GeoDataFrame(pd.concat(points_by_row))
        # points.reset_index(inplace=True,drop=True)
        return points, points_by_row

    def group_points_by_row(self):
        """Split `self.points` into a list of GeoDataFrames, one per row."""
        # Make list of different gdfs for each row
        points_by_row = [None]*len(self.rows)
        for i,row in enumerate(self.rows):
            points_by_row[i] = self.points[self.points.row==row]
        return points_by_row

    def filter_longitude(self,cols,lons):
        """Keep only columns/longitudes inside `self.longitude_range`."""
        idxs = (lons>=self.longitude_range[0]) * (lons<=self.longitude_range[1])
        cols,lons = cols[idxs],lons[idxs]
        return cols,lons

    def latlon2rowcol(self,lats,lons,return_idx=False,integer=False):
        """
        Convert latitude and longitude to row and column number from the grid

        Args:
            lats, lons: sequences of coordinates.
            return_idx (bool): also return the `self.points` table index of
                each matched cell.
            integer (bool): return signed integer row/col numbers instead of
                labels (U/R positive, D/L negative).
        """
        # Always take bottom left corner of grid cell
        rows = np.searchsorted(self.lats,lats)-1

        # Get the possible points of the grid cells at the given latitude
        possible_points = [self.points_by_row[row] for row in rows]

        # For each point, find the rightmost point that is still to the left of the given longitude
        cols = [poss_points.iloc[np.searchsorted(poss_points.geometry.x,lon)-1].col for poss_points,lon in zip(possible_points,lons)]
        rows = self.rows[rows].tolist()

        outputs = [rows, cols]
        if return_idx:
            # Get the table index for self.points with each row,col pair in rows, cols
            idx = [self.points[(self.points.row==row) & (self.points.col==col)].index.values[0] for row,col in zip(rows,cols)]
            outputs.append(idx)

        # return raw numbers
        if integer:
            outputs[0] = [int(el[:-1]) if el[-1] == 'U' else -int(el[:-1]) for el in outputs[0]]
            outputs[1] = [int(el[:-1]) if el[-1] == 'R' else -int(el[:-1]) for el in outputs[1]]

        return outputs

    def rowcol2latlon(self,rows,cols):
        """Convert row/column labels back to the latitude/longitude of each grid point."""
        point_geoms = [self.points.loc[(self.points.row==row) & (self.points.col==col),'geometry'].values[0] for row,col in zip(rows,cols)]
        lats = [point.y for point in point_geoms]
        lons = [point.x for point in point_geoms]
        return lats,lons

    def get_bounded_footprint(self,point,buffer_ratio=0):
        """
        Polygon footprint of the grid cell for `point`, bounded by the
        neighbouring grid points' cells and optionally buffered.

        Args:
            point: a row of the points table (needs `geometry`, `row_idx`,
                `col_idx`).
            buffer_ratio (float): fraction of the cell width/height to expand
                the footprint by on each side.

        Returns:
            shapely.geometry.Polygon: the (possibly buffered) bounding box.
        """
        # Gets the polygon footprint of the grid cell for a given point, bounded by the other grid points' cells.
        # Grid point defined as bottom-left corner of polygon. Buffer ratio is the ratio of the grid cell's width/height to buffer by.

        bottom,left = point.geometry.y,point.geometry.x
        row_idx = point.row_idx
        col_idx = point.col_idx
        next_row_idx = row_idx+1
        next_col_idx = col_idx+1

        if next_row_idx >= len(self.lats): # If at top row, use difference between top and second-to-top row for height
            height = (self.lats[row_idx] - self.lats[row_idx-1])
            top = self.lats[row_idx] + height
        else:
            top = self.lats[next_row_idx]

        max_col = len(self.points_by_row[row_idx].col_idx)-1
        if next_col_idx > max_col: # If at rightmost column, use difference between rightmost and second-to-rightmost column for width
            width = (self.points_by_row[row_idx].iloc[col_idx].geometry.x - self.points_by_row[row_idx].iloc[col_idx-1].geometry.x)
            right = self.points_by_row[row_idx].iloc[col_idx].geometry.x + width
        else:
            right = self.points_by_row[row_idx].iloc[next_col_idx].geometry.x

        # Buffer the polygon by the ratio of the grid cell's width/height
        width = right - left
        height = top - bottom

        buffer_horizontal = width * buffer_ratio
        buffer_vertical = height * buffer_ratio

        new_left = left - buffer_horizontal
        new_right = right + buffer_horizontal

        new_bottom = bottom - buffer_vertical
        new_top = top + buffer_vertical

        bbox = Polygon([(new_left,new_bottom),(new_left,new_top),(new_right,new_top),(new_right,new_bottom)])

        return bbox
205
+
206
def get_utm_zone_from_latlng(latlng):
    """
    Get the UTM zone from a latlng list and return the corresponding EPSG code.

    Zone numbering follows the standard 6-degree bands, with the special
    zone overrides for south-west Norway and Svalbard.

    Parameters
    ----------
    latlng : List[Union[int, float]]
        The latlng list to get the UTM zone from.

    Returns
    -------
    str
        The EPSG code for the UTM zone (326xx north, 327xx south).
    """
    assert isinstance(latlng, (list, tuple)), "latlng must be in the form of a list or tuple."

    latitude, longitude = latlng[0], latlng[1]

    # regular 6-degree UTM band
    zone_number = (math.floor((longitude + 180) / 6)) % 60 + 1

    # Special zones for Svalbard and Norway
    if 56.0 <= latitude < 64.0 and 3.0 <= longitude < 12.0:
        zone_number = 32
    elif 72.0 <= latitude < 84.0:
        for special_zone, (lon_lo, lon_hi) in (
            (31, (0.0, 9.0)),
            (33, (9.0, 21.0)),
            (35, (21.0, 33.0)),
            (37, (33.0, 42.0)),
        ):
            if lon_lo <= longitude < lon_hi:
                zone_number = special_zone
                break

    # Determine the hemisphere and construct the EPSG code
    hemisphere_prefix = "327" if latitude < 0 else "326"
    epsg_code = f"{hemisphere_prefix}{zone_number:02d}"

    # sanity check: zone must land in 01..60
    if not re.match(r"32[6-7](0[1-9]|[1-5][0-9]|60)",epsg_code):
        print(f"latlng: {latlng}, epsg_code: {epsg_code}")
        raise ValueError(f"out of bound latlng resulted in incorrect EPSG code for the point")

    return epsg_code
250
+
251
+
252
if __name__ == '__main__':

    # sanity checks for the UTM-zone EPSG assignment,
    # including the Norway/Svalbard special cases
    assert get_utm_zone_from_latlng([-1,-174.34]) == "32701"
    assert get_utm_zone_from_latlng([48,-4]) == "32630"
    assert get_utm_zone_from_latlng([78,13]) == "32633"
    assert get_utm_zone_from_latlng([-34,19.7]) == "32734"
    assert get_utm_zone_from_latlng([-36,175.7]) == "32760"


    dist = 100
    grid = Grid(dist)

    # random sample of coordinates for a round-trip test
    np.random.seed(0)
    test_lons = np.random.uniform(-20,20,size=(1000)) % 180 # Checks edge-case of crossing 180th meridian
    test_lats = np.random.uniform(-20,68,size=(1000))

    # round-trip: latlon -> grid cell labels -> latlon of the cell corner
    test_rows,test_cols = grid.latlon2rowcol(test_lats,test_lons)
    test_lats2,test_lons2 = grid.rowcol2latlon(test_rows,test_cols)

    print(test_lons[:10])
    print(test_lats[:10])
    print(test_rows[:10])
    print(test_cols[:10])

    # Make line segments from the points to their corresponding grid points
    lines = []
    for i in range(len(test_lats)):
        lines.append([(test_lons[i],test_lats[i]),(test_lons2[i],test_lats2[i])])

    lines = gpd.GeoDataFrame(geometry=gpd.GeoSeries([LineString(line) for line in lines]))

    # export both layers for visual inspection in GIS software
    lines.to_file(f'testlines_{dist}km.geojson',driver='GeoJSON')
    grid.points.to_file(f'testgrid_{dist}km.geojson',driver='GeoJSON')
MajorTOM/metadata_helpers.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pyarrow.parquet as pq
2
+ import pandas as pd
3
+ import geopandas as gpd
4
+ from pathlib import Path
5
+ import urllib.request
6
+ import fsspec
7
+ from fsspec.parquet import open_parquet_file
8
+ from io import BytesIO
9
+ from PIL import Image
10
+ from rasterio.io import MemoryFile
11
+ from tqdm.notebook import tqdm
12
+ import os
13
+
14
+ from .sample_helpers import *
15
+
16
def metadata_from_url(access_url, local_url):
    """Download a Major-TOM metadata parquet and load it as a GeoDataFrame.

    Args:
        access_url (str): Remote URL of the metadata parquet file.
        local_url (str or Path): Local destination path for the download.

    Returns:
        geopandas.GeoDataFrame: Metadata with point geometry built from the
        ``centre_lon``/``centre_lat`` columns and the CRS taken from the
        first row's ``crs`` column.
    """
    downloaded_path, _response = urllib.request.urlretrieve(access_url, local_url)
    frame = pq.read_table(downloaded_path).to_pandas()
    frame['timestamp'] = pd.to_datetime(frame.timestamp)
    centre_points = gpd.points_from_xy(frame.centre_lon, frame.centre_lat)
    return gpd.GeoDataFrame(frame, geometry=centre_points, crs=frame.crs.iloc[0])
24
+
25
def filter_metadata(df,
                    region=None,
                    daterange=None,
                    cloud_cover=(0, 100),
                    nodata=(0, 1.0)
                    ):
    """Filters the Major-TOM dataframe based on several parameters

    Args:
        df (geopandas dataframe): Parent dataframe
        region (shapely geometry object) : Region of interest
        daterange (tuple) : Inclusive range of dates (example format: '2020-01-01')
        cloud_cover (tuple) : Inclusive percentage range (0-100) of cloud cover
        nodata (tuple) : Inclusive fraction (0.0-1.0) of no data allowed in a sample

    Returns:
        df: a filtered dataframe
    """
    # temporal filtering (inclusive on both ends)
    if daterange is not None:
        assert isinstance(daterange, (list, tuple)) and len(daterange) == 2
        df = df[df.timestamp >= daterange[0]]
        df = df[df.timestamp <= daterange[1]]

    # spatial filtering via the spatial index of the geodataframe
    if region is not None:
        idxs = df.sindex.query(region)
        df = df.take(idxs)

    # cloud-cover filtering
    if cloud_cover is not None:
        df = df[df.cloud_cover >= cloud_cover[0]]
        df = df[df.cloud_cover <= cloud_cover[1]]

    # nodata filtering (fixed: this comment was previously mislabelled "spatial filtering")
    if nodata is not None:
        df = df[df.nodata >= nodata[0]]
        df = df[df.nodata <= nodata[1]]

    return df
64
+
65
def read_row(row, columns=["thumbnail"]):
    """Reads a row from a Major-TOM dataframe

    Args:
        row (row from geopandas dataframe): The row of metadata; must carry
            ``parquet_url`` and ``parquet_row`` fields.
        columns (list): columns to be read from the file

    Returns:
        data: a PIL.Image when only the thumbnail is requested, otherwise a
        dictionary with decoded data for every requested column
    """
    # Large footer_sample_size fetches the parquet footer in a single request
    with open_parquet_file(row.parquet_url, columns=columns, footer_sample_size=2000000) as f:
        with pq.ParquetFile(f) as pf:
            row_group = pf.read_row_group(row.parquet_row, columns=columns)

    if columns == ["thumbnail"]:
        stream = BytesIO(row_group['thumbnail'][0].as_py())
        return Image.open(stream)
    else:
        row_output = {}
        for col in columns:
            # renamed from `bytes` to avoid shadowing the builtin
            col_bytes = row_group[col][0].as_py()

            if col != 'thumbnail':
                # band columns are GeoTIFF-encoded
                row_output[col] = read_tif_bytes(col_bytes)
            else:
                stream = BytesIO(col_bytes)
                row_output[col] = Image.open(stream)

        return row_output
94
+
95
def filter_download(df, local_dir, source_name, by_row=False, verbose=False, tif_columns=None):
    """Downloads and unpacks the data of Major-TOM based on a metadata dataframe

    Args:
        df (geopandas dataframe): Metadata dataframe
        local_dir (str or Path) : Path to where the data is to be stored locally
        source_name (str) : Name alias of the resulting dataset
        by_row (bool): If True, it will access individual rows of parquet via
            http - otherwise entire parquets are downloaded temporarily
        verbose (bool) : option for potential internal state printing
        tif_columns (list of str) : Optionally specified columns to be
            downloaded as .tifs, e.g. ['B04', 'B03', 'B02']

    Returns:
        None
    """
    if isinstance(local_dir, str):
        local_dir = Path(local_dir)
    # ensure the destination (also used for the temp file) exists before writing
    local_dir.mkdir(exist_ok=True, parents=True)

    temp_file = local_dir / 'temp.parquet'

    # identify all parquets that need to be downloaded (group them)
    urls = df.parquet_url.unique()
    if verbose:
        print('Starting download of {} parquet files.'.format(len(urls)))

    for url in tqdm(urls, desc='Downloading and unpacking...', disable=not verbose):
        # identify all relevant rows
        rows = df[df.parquet_url == url].parquet_row.unique()

        remote_file = None
        if not by_row:  # (downloads entire parquet)
            temp_path, _ = urllib.request.urlretrieve(url, temp_file)
        else:  # stream only the needed byte ranges over http
            remote_file = fsspec.open(url)
            temp_path = remote_file.open()

        try:
            # populate the bands
            with pq.ParquetFile(temp_path) as pf:
                for row_idx in rows:
                    table = pf.read_row_group(row_idx)

                    product_id = table['product_id'][0].as_py()
                    grid_cell = table['grid_cell'][0].as_py()
                    row = grid_cell.split('_')[0]

                    dest = local_dir / Path("{}/{}/{}/{}".format(source_name, row, grid_cell, product_id))
                    dest.mkdir(exist_ok=True, parents=True)

                    # default to every band column (B*) plus the cloud mask
                    columns = [col for col in table.column_names if col[0] == 'B'] + ['cloud_mask'] if tif_columns is None else tif_columns
                    # tifs
                    for col in columns:
                        with open(dest / "{}.tif".format(col), "wb") as f:
                            # Write bytes to file
                            f.write(table[col][0].as_py())

                    # thumbnail (png)
                    col = 'thumbnail'
                    with open(dest / "{}.png".format(col), "wb") as f:
                        # Write bytes to file
                        f.write(table[col][0].as_py())
        finally:
            # cleanup now also happens when unpacking raises
            if not by_row:
                # remove downloaded file
                os.remove(temp_path)
            else:
                # close both the open byte stream and the fsspec handle
                temp_path.close()
                remote_file.close()
MajorTOM/sample_helpers.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rasterio.io import MemoryFile
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from PIL import Image
5
+ from io import BytesIO
6
+
7
def plot(sample, bands=['B04', 'B03', 'B02'], scaling=2e3):
    """Display selected bands of a Major-TOM sample as an RGB image.

    Args:
        sample (dict): Mapping of band name to GeoTIFF bytes.
        bands (list of str): Bands to stack as image channels (RGB order).
        scaling (float): Divisor applied to the raw band values.
            Fixed: this parameter was previously ignored and 2e3 was
            hard-coded in the division.
    """
    img = [read_tif_bytes(sample[b]) for b in bands]
    plt.imshow(np.stack(img, -1) / scaling)
12
+
13
def read_tif_bytes(tif_bytes):
    """Decode in-memory GeoTIFF bytes into a squeezed numpy array."""
    with MemoryFile(tif_bytes) as memory_file, memory_file.open(driver='GTiff') as dataset:
        return dataset.read().squeeze()
17
+
18
def read_png_bytes(png_bytes):
    """Decode in-memory PNG bytes into a (lazily loaded) PIL image."""
    return Image.open(BytesIO(png_bytes))
README.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: EarthExplorer
3
+ emoji: 🌍
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ # EarthExplorer
14
+
15
+ A tool for searching satellite images of Earth using natural language descriptions, images, geolocations, or a simple click on the map.
16
+
17
+ ## Features
18
+
19
+ - Text-based satellite image search
20
+ - Image-based similarity search
21
+ - Location-based search
22
+ - Interactive map interface
23
+
24
+ ## Clone
25
+
26
+ ```bash
27
+ git clone https://huggingface.co/spaces/ML4Sustain/EarthExplorer
28
+ ```
Tutorial.md ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Tutorial: EarthEmbeddingExplorer
2
+
3
+ ## Background
4
+
5
+ ### What is this project about?
6
+ EarthEmbeddingExplorer is a tool that lets you search satellite imagery using **natural language**, **images**, or **geographic locations**. In simple terms, you can enter prompts like “a satellite image of a glacier” or “a satellite image of a city with a coastline”, and the system will find places on Earth that match your description and visualize them on a map.
7
+
8
+ EarthEmbeddingExplorer enables users to explore the Earth in multiple ways without leaving their desk, and it can be useful for many geoscience tasks. For example, geologists can quickly locate glacier regions; biologists can rapidly map forest cover; and architects can study urban patterns across different parts of the world.
9
+
10
+ ## How does it work? (Core ideas)
11
+
12
+ ### Satellite imagery dataset
13
+ We use **MajorTOM** (Major TOM: Expandable Datasets for Earth Observation) released by the European Space Agency (ESA) [1]. Specifically, we use the [Core-S2L2A](https://modelscope.cn/datasets/Major-TOM/Core-S2L2A) subset.
14
+
15
+ | Dataset | Imagery source | Number of samples | Sensor type |
16
+ | :--- | :--- | :--- | :--- |
17
+ | MajorTOM-Core-S2L2A | Sentinel-2 Level 2A | 2,245,886 | Multispectral |
18
+
19
+ MajorTOM Core-S2L2A provides global Sentinel-2 multispectral imagery (10 m resolution). We convert the RGB bands into embeddings using CLIP-like models (e.g., SigLIP), which saves substantial time because we do not need to preprocess raw imagery ourselves. In addition, embeddings (vectors) are much smaller than raw imagery, and they are significantly faster to search.
20
+
21
+ To keep EarthEmbeddingExplorer responsive, we build a smaller but representative version of the dataset.
22
+
23
+ The original tiles in Core-S2L2A are large (1068×1068 pixels), but most AI models expect smaller inputs (384×384 or 224×224 pixels).
24
+ 1. **Cropping**: for simplicity, from each original tile we only take the **center** 384×384 (or 224×224) crop to generate an embedding.
25
+ 2. **Uniform sampling**: using MajorTOM’s grid coding system, we sample **1%** of the data (about 22,000 images). This preserves global coverage while keeping search fast.
26
+
27
+ <div align="center">
28
+ <img src="images/samples.png" width="50%" />
29
+ <br>
30
+ <em>Figure 1: Geographic distribution of our sampled satellite image embeddings.</em>
31
+ </div>
32
+
33
+ ### Retrieval models
34
+ The core of image retrieval is a family of models known as **CLIP (Contrastive Language-Image Pre-training)** [2]. We use its improved variants such as **SigLIP (Sigmoid Language-Image Pre-training)** [3], **FarSLIP (Fine-grained Aligned Remote Sensing Language Image Pretraining)** [4], and **SatCLIP (Satellite Location-Image Pretraining)** [5].
35
+
36
+ An analogy: when teaching a child, you show a picture of a glacier and say “glacier”. After seeing many examples, the child learns to associate the visual concept with the word.
37
+
38
+ CLIP-like models learn in a similar way, but at much larger scale.
39
+ - An image encoder turns an **image** into an **embedding** (a vector of numbers).
40
+ - A text (or location) encoder turns **text** (or **latitude/longitude**) into an embedding.
41
+
42
+ The key property is: if an image matches a text description (or location), their embeddings will be close; otherwise they will be far apart.
43
+
44
+ <div align="center">
45
+ <img src="images/CLIP.png" width="40%" />
46
+ <br>
47
+ <em>Figure 2: How CLIP-like models connect images and text.</em>
48
+ </div>
49
+
50
+ The three models we use differ in their encoders and training data:
51
+
52
+ | Model | Encoder type | Training data |
53
+ | :--- | :--- | :--- |
54
+ | SigLIP | image encoder + text encoder | natural image–text pairs from the web |
55
+ | FarSLIP | image encoder + text encoder | satellite image–text pairs |
56
+ | SatCLIP | image encoder + location encoder | satellite image–location pairs |
57
+
58
+ <div align="center">
59
+ <img src="images/embedding.png" width="30%" />
60
+ <br>
61
+ <em>Figure 3: Converting satellite images into embedding vectors.</em>
62
+ </div>
63
+
64
+ In EarthEmbeddingExplorer:
65
+ 1. We precompute embeddings for ~22k globally distributed satellite images using SigLIP, FarSLIP, and SatCLIP.
66
+ 2. When you provide a query (text like “a satellite image of glacier”, an image, or a location such as (-89, 120)), we encode the query into an embedding using the corresponding encoder.
67
+ 3. We compare the query embedding with all image embeddings, visualize similarities on a map, and show the top-5 most similar images.
68
+
69
+ ## System architecture
70
+
71
+ <div align="center">
72
+ <img src="images/framework_en.png" width="70%" />
73
+ <br>
74
+ <em>Figure 4: EarthEmbeddingExplorer system architecture on ModelScope.</em>
75
+ </div>
76
+
77
+ We deploy EarthEmbeddingExplorer on ModelScope: the models, embedding datasets, and raw imagery datasets are all hosted on the platform. The app runs on [xGPU](https://www.modelscope.cn/brand/view/xGPU), allowing flexible access to GPU resources and faster retrieval.
78
+
79
+ ### How is the raw imagery stored?
80
+
81
+ MajorTOM Core-S2L2A is large (about 23 TB), so we do not download the full dataset. Instead, the raw imagery is stored as **Parquet shards**:
82
+
83
+ - **Shard storage**: the dataset is split into many remote Parquet files (shards), each containing a subset of the samples.
84
+ - **Columnar storage**: different fields/bands (e.g., B04/B03/B02, thumbnail) are stored as separate columns; we only read what we need.
85
+ - **Metadata index**: we maintain a small index table mapping `product_id → (parquet_url, parquet_row)` so the system can locate “which shard and which position” contains a given image.
86
+
87
+ With this design, when a user only needs a small number of images from the retrieval results, the system can use **HTTP Range requests** to download only a small byte range from a Parquet file (the target row/row group and the requested columns), rather than downloading the full 23 TB dataset—enabling near real-time retrieval of raw images.
88
+
89
+ ### What happens when you use the app?
90
+
91
+ 1. **Enter a query**: you can enter text, upload an image, or input a latitude/longitude. You can also click on the map to use the clicked location as a query.
92
+ 2. **Compute similarity**: the app encodes your query into an embedding vector and computes similarity scores against all satellite image embeddings.
93
+ 3. **Show results**: the system filters out low-similarity results and shows the highest-scoring locations (and scores) on the map. You can adjust the threshold using a slider.
94
+ 4. **Download raw images on demand**: for the top-5 most similar images, the system looks up their `parquet_url` and row position via the metadata index, then uses HTTP Range to fetch only the required data (RGB bands) and displays the images quickly in the UI.
95
+
96
+ ## Examples
97
+ <div align="center">
98
+ <img src="images/Text_Search.jpg" width="99%" />
99
+ <br>
100
+ <em>Figure 5: Search by text.</em>
101
+ </div>
102
+ <br>
103
+
104
+ <div align="center">
105
+ <img src="images/Image_Search_Amazon.jpg" width="99%" />
106
+ <br>
107
+ <em>Figure 6: Search by image.</em>
108
+ </div>
109
+ <br>
110
+
111
+ <div align="center">
112
+ <img src="images/Location_Search_Amazon.jpg" width="99%" />
113
+ <br>
114
+ <em>Figure 7: Search by location.</em>
115
+ </div>
116
+
117
+ ## Limitations
118
+
119
+ While EarthEmbeddingExplorer has strong potential, it also has limitations. SigLIP is primarily trained on “natural images” from the internet (people, pets, cars, everyday objects) rather than satellite imagery. This domain gap can make it harder for the model to understand certain scientific terms or distinctive geographic patterns that are uncommon in typical web photos.
120
+
121
+ FarSLIP may perform poorly on non-remote-sensing concepts described in text, such as queries like “an image of face”.
122
+
123
+ ## Acknowledgements
124
+
125
+ We thank the following open-source projects and datasets that made EarthEmbeddingExplorer possible:
126
+
127
+ **Models:**
128
+ - [SigLIP](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384) - Vision Transformer model for image-text alignment
129
+ - [FarSLIP](https://github.com/NJU-LHRS/FarSLIP) - Fine-grained satellite image-text pretraining model
130
+ - [SatCLIP](https://github.com/microsoft/satclip) - Satellite location-image pretraining model
131
+
132
+ **Datasets:**
133
+ - [MajorTOM](https://github.com/ESA-PhiLab/MajorTOM) - Expandable datasets for Earth observation by ESA
134
+
135
+ We are grateful to the research communities and organizations that developed and shared these resources.
136
+
137
+ ## Contributors
138
+ - [Yijie Zheng](https://voyagerxvoyagerx.github.io/)
139
+ - [Weijie Wu](https://github.com/go-bananas-wwj)
140
+ - [Bingyue Wu](https://brynn-wu.github.io/Brynn-Wu)
141
+
142
+ ## Roadmap
143
+ - [ ] Increase the geographical coverage (sample rate) to 1.2% of the Earth's land surface. (coming by 16 Jan!)
144
+ - [ ] Support DINOv2 Embedding model and embedding datasets.
145
+ - [ ] Support FAISS for faster similarity search.
146
+ - [ ] What features do you want? Leave an issue [here](https://huggingface.co/spaces/ML4Sustain/EarthExplorer/discussions)!
147
+
148
+ We warmly welcome new contributors!
149
+
150
+ ## References
151
+
152
+ [1] Francis, A., & Czerkawski, M. (2024). Major TOM: Expandable Datasets for Earth Observation. IGARSS 2024.
153
+
154
+ [2] Radford, A., et al. (2021). Learning Transferable Visual Models From Natural Language Supervision. ICML 2021.
155
+
156
+ [3] Zhai, X., et al. (2023). Sigmoid Loss for Language-Image Pre-Training. ICCV 2023.
157
+
158
+ [4] Li, Z., et al. (2025). FarSLIP: Discovering Effective CLIP Adaptation for Fine-Grained Remote Sensing Understanding. arXiv 2025.
159
+
160
+ [5] Klemmer, K. et al. (2025). SatCLIP: Global, General-Purpose Location Embeddings with Satellite Imagery. AAAI 2025.
161
+
162
+ [6] Czerkawski, M., Kluczek, M., & Bojanowski, J. S. (2024). Global and Dense Embeddings of Earth: Major TOM Floating in the Latent Space. arXiv preprint arXiv:2412.05600.
Tutorial_zh.md ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 教程:EarthExplorer 地球探索者
2
+
3
+ ## 背景介绍
4
+
5
+ ### 这个项目是做什么的?
6
+ EarthExplorer 是一个可以通过**自然语言**,**图像**,或**地理位置**搜索卫星图像的工具。简单来说,你可以输入像“a satellite image of glacier”或“a satellite image of city with a coastline”这样的描述,系统就会在地球上找到符合你描述的地点,并将它们在地图上展示出来。EarthExplorer 可以让用户足不出户地,以多种方式探索地球上的每一个角落,在地理科学领域有广泛的应用价值。例如,地质学家们可以用这个工具来快速寻找冰川的分布;生物学家可以快速进行森林覆盖的制图,建筑学家们可以研究世界不同地区的城市发展结构。
7
+
8
+
9
+ ## 它是如何工作的?(核心原理)
10
+
11
+ ### 卫星影像数据集
12
+ 我们使用了欧空局(ESA)发布的 **MajorTOM** (Major TOM: Expandable Datasets for Earth Observation) 数据集 [1]。具体来说,我们使用的是 [Core-S2L2A](https://modelscope.cn/datasets/Major-TOM/Core-S2L2A) 这个子集。
13
+
14
+ | 数据集 | 影像来源 | 嵌入数量 | 传感器类型 |
15
+ | :--- | :--- | :--- | :--- |
16
+ | MajorTOM-Core-S2L2A | Sentinel-2 Level 2A | 2,245,886 | 多光谱 |
17
+
18
+ MajorTOM Core-S2L2A 包含了全球覆盖的 Sentinel-2 多光谱影像(10m 分辨率);我们利用 SigLIP 等模型将该数据集的 RGB 波段处理成了嵌入。这为我们节省了大量时间,因为我们不需要自己去处理这些原始图像!此外,图像嵌入(一串数字)的存储空间远小于原始图像,计算效率也更高!
19
+
20
+ 为了让 EarthExplorer 响应迅速,我们创建了一个更小、更有代表性的数据集版本。
21
+
22
+ Core-S2L2A 中的原始卫星图像尺寸很大(1068x1068 像素),但 AI 模型需要较小的输入尺寸(384x384 或 224x224 像素)。
23
+ 1. **裁剪**:为了简化,对每个原尺寸图像,我们仅选取大图正中心的 384x384 或 224x224 像素区域所生成的嵌入。
24
+ 2. **随机采样**:我们根据 MajorTOM 的网格编码系统,均匀采样了 **1%** 的数据(约 22000 张图像)。这样既能保证全球覆盖,又可以在很短的时间内检索出结果。
25
+
26
+ <div align="center">
27
+ <img src="images/samples.png" width="50%" />
28
+ <br>
29
+ <em>图 1:我们采样的卫星图像嵌入的地理分布。</em>
30
+ </div>
31
+
32
+ ### 检索模型
33
+ 图像检索核心技术是一种叫做 **CLIP (Contrastive Language-Image Pre-training)** [2] 的人工智能模型,我们使用的是它的改进版本 **SigLIP (Sigmoid Language-Image Pre-training)** [3], **FarSLIP (Fine-grained Aligned Remote Sensing Language Image Pretraining)** [4], 和 **SatCLIP (Satellite Location-Image Pretraining)** [5]。
34
+
35
+ 想象一下教小孩子识物。你给他们看一张冰川的照片,并说“冰川”。在看了很多冰川的照片并听到这个词后,孩子就学会了将冰川的样子和“冰川”这个词联系起来。
36
+
37
+ SigLIP/FarSLIP/SatCLIP 的工作原理类似,但规模要大得多。它在学习了数百万个图片-文字对或图片-地理位置对,从而理解了图像和文本/地理位置之间的关系。
38
+ - 它使用图片编码器将**图像**转换成一种数学表示(一串数字),我们称之为**嵌入 (Embedding)**。
39
+ - 它也使用文本/地理位置编码器将**文本**或**地理位置(经纬度坐标)**转换成类似的数学表示(嵌入)。
40
+
41
+ 神奇之处在于,如果一张图片和一段文字描述或经纬度是匹配的,它们转换后的数学表示就会非常接近。如果不匹配,它们就会相距很远。
42
+
43
+ <div align="center">
44
+ <img src="images/CLIP.png" width="40%" />
45
+ <br>
46
+ <em>图 2:CLIP 类模型如何连接图像和文本/位置。</em>
47
+ </div>
48
+
49
+ 我们用到的三个模型的模型结构和训练数据是:
50
+ | 模型 | 编码器类型 | 训练数据来源 |
51
+ | :--- | :--- | :--- |
52
+ | SigLIP | 图像编码器+文本编码器 | 互联网上的自然图像-文本对 |
53
+ | FarSLIP | 图像编码器+文本编码器 | 卫星图像-文本对 |
54
+ | SatCLIP | 图像编码器+位置编码器 | 卫星图像-地理位置对 |
55
+
56
+ <div align="center">
57
+ <img src="images/embedding.png" width="30%" />
58
+ <br>
59
+ <em>图 3:将卫星图像转换成嵌入向量。</em>
60
+ </div>
61
+
62
+ 在 EarthExplorer 中:
63
+ 1. 我们将全球均匀采样的两万多张卫星图像,分别使用 SigLIP, FarSLIP, 和 SatCLIP 的图像编码器,转换成了这种数学“嵌入”。
64
+ 2. 当你输入一个查询,这个查询可以是文本(例如“a satellite image of glacier”),图像(一张冰川的图像),或地理位置(-89, 120),我们将你的查询也使用对应的编码器转换成嵌入。
65
+ 3. 然后,我们将你的查询嵌入与所有卫星图像的嵌入进行比较,将相似度在地图上可视化,并展示最相似的5张图像。
66
+
67
+
68
+ ## 系统架构
69
+
70
+ <div align="center">
71
+ <img src="images/framework_zh.png" width="70%" />
72
+ <br>
73
+ <em>图 4:基于魔搭创空间的 EarthExplorer 系统架构。</em>
74
+ </div>
75
+
76
+ 我们基于魔搭平台进行部署:模型、嵌入数据集、以及原始影像数据集都托管在魔搭上。我们将 APP 部署在 [xGPU](https://www.modelscope.cn/brand/view/xGPU) 环境下,使得用户可以获得灵活调度的免费 GPU 资源,加快检索速度。
77
+
78
+ ### 原始影像是如何存的?
79
+
80
+ MajorTOM Core-S2L2A 的原始影像体量很大(约 23TB),以 **Parquet 分片(shard)** 的方式存储:
81
+
82
+ - **分片存储**:数据被拆成很多个远端 Parquet 文件(分片),每个分片只包含一部分影像样本。
83
+ - **列式存储**:每个影像的不同字段/波段(例如 B04/B03/B02、thumbnail)存成不同的列,需要什么就读什么。
84
+ - **元数据索引**:我们额外维护一份很小的索引表,把 `product_id → (parquet_url, parquet_row)` 对应起来,告诉系统“这个 id 的影像在哪个分片、在分片里的哪个位置”。
85
+
86
+ 这样,当用户只需要查看检索结果的少量影像时,系统可以通过 **HTTP Range 请求**只下载 Parquet 文件中“那一小段字节”(对应目标行/行组 + 指定列的数据),而不是下载整个 23TB 数据集,从而实现秒级取图。
87
+
88
+ ### 当你使用这个 App 时
89
+
90
+ 1. **输入查询**:你可以输入文字、上传图片、输入经纬度;也可以在地图上点击一个位置,直接把该点经纬度作为查询。
91
+ 2. **计算相似度**:App 将你的查询编码成一个“嵌入向量”,并与嵌入数据集中每一张卫星图像的嵌入计算相似度分数。
92
+ 3. **展示检索结果**:系统过滤掉相似度较低的结果,把相似度最高的地点(以及分数)显示在地图上;你可以用滑动条调整阈值。
93
+ 4. **按需下载原图**:对最相似的前 5 张影像,系统用 `product_id` 查询元数据索引定位到远端 `parquet_url` 和行位置,然后通过 HTTP Range 只拉取对应缩略图数据,在前端快速展示原始影像。
94
+
95
+
96
+ ## 示例
97
+ <div align="center">
98
+ <img src="images/Text_Search.jpg" width="99%" />
99
+ <br>
100
+ <em>图 5:以文搜图示例。</em>
101
+ </div>
102
+ <br>
103
+
104
+ <div align="center">
105
+ <img src="images/Image_Search_Amazon.jpg" width="99%" />
106
+ <br>
107
+ <em>图 6:以图搜图示例。</em>
108
+ </div>
109
+
110
+ <br>
111
+ <div align="center">
112
+ <img src="images/Location_Search_Amazon.jpg" width="99%" />
113
+ <br>
114
+ <em>图 7:以点搜图示例。</em>
115
+ </div>
116
+
117
+
118
+ ## 局限性
119
+
120
+ 虽然 EarthExplorer 有很大的应用潜力,但它也有一些局限性。SigLIP 模型主要是通过互联网上的“自然图像”(如人物、猫狗、汽车、日常用品的照片)训练的,而不是专门针对卫星图像训练的。这种训练数据和应用时数据的偏差,使得模型可能难以理解特定的科学术语或在普通网络照片中不常见的独特地理特征。而 FarSLIP 模型对非典型遥感地物的语言描述,例如 'an image of face' 的检索效果不佳。
121
+
122
+ 未来的工作可以使用其他专门针对地球观测数据训练的 AI 模型来提高检索的准确性。
123
+
124
+ ## 未来工作
125
+ - 结合时间序列影像,实现全球变化监测
126
+ - 添加不同地球基础模型,对比不同模型的检索性能
127
+
128
+ ## 致谢
129
+ 我们感谢以下开源项目和数据集,它们使 EarthExplorer 得以实现:
130
+
131
+ **模型:**
132
+ - [SigLIP](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384) - 用于图像-文本对齐的视觉Transformer模型
133
+ - [FarSLIP](https://github.com/NJU-LHRS/FarSLIP) - 细粒度卫星图像-文本预训练模型
134
+ - [SatCLIP](https://github.com/microsoft/satclip) - 卫星位置-图像预训练模型
135
+
136
+ **数据集:**
137
+ - [MajorTOM](https://github.com/ESA-PhiLab/MajorTOM) - 欧洲航天局(ESA)的可扩展地球观测数据集
138
+
139
+ 我们感谢开发和分享这些资源的研究社区和组织。
140
+
141
+ ## 贡献者
142
+ - [郑祎杰](https://voyagerxvoyagerx.github.io/)
143
+ - [伍炜杰](https://github.com/go-bananas-wwj)
144
+ - [吴冰玥](https://brynn-wu.github.io/Brynn-Wu)
145
+
146
+ ## 引用
147
+ [1] Francis, A., & Czerkawski, M. (2024). Major TOM: Expandable Datasets for Earth Observation. IGARSS 2024.
148
+
149
+ [2] Radford, A., et al. (2021). Learning Transferable Visual Models From Natural Language Supervision. ICML 2021.
150
+
151
+ [3] Zhai, X., et al. (2023). Sigmoid Loss for Language-Image Pre-Training. ICCV 2023.
152
+
153
+ [4] Li, Z., et al. (2025). FarSLIP: Discovering Effective CLIP Adaptation for Fine-Grained Remote Sensing Understanding. arXiv 2025.
154
+
155
+ [5] Klemmer, K. et al. (2025). SatCLIP: Global, General-Purpose Location Embeddings with Satellite Imagery. AAAI 2025.
156
+
157
+ [6] Czerkawski, M., Kluczek, M., & Bojanowski, J. S. (2024). Global and Dense Embeddings of Earth: Major TOM Floating in the Latent Space. arXiv preprint arXiv:2412.05600.
app.py ADDED
@@ -0,0 +1,792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import time
4
+ import os
5
+ import tempfile
6
+ import zipfile
7
+ import numpy as np
8
+ import pandas as pd
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+
11
+ # Import custom modules
12
+ from models.siglip_model import SigLIPModel
13
+ from models.satclip_model import SatCLIPModel
14
+ from models.farslip_model import FarSLIPModel
15
+ from models.load_config import load_and_process_config
16
+ from visualize import format_results_for_gallery, plot_top5_overview, plot_location_distribution, plot_global_map_static, plot_geographic_distribution
17
+ from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image
18
+ from PIL import Image as PILImage
19
+ from PIL import ImageDraw, ImageFont
20
+
21
# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")

# Load and process configuration
config = load_and_process_config()

# Initialize Models
print("Initializing models...")
models = {}

def _init_model(name, model_cls, config_key, config_kwargs):
    """Instantiate one retrieval model into `models`, reading paths from
    `config` when that section exists; a failure is logged, not fatal."""
    try:
        if config and config_key in config:
            kwargs = {kw: config[config_key].get(kw) for kw in config_kwargs}
            models[name] = model_cls(device=device, **kwargs)
        else:
            models[name] = model_cls(device=device)
    except Exception as e:
        # One broken checkpoint should not take down the whole app
        print(f"Failed to load {name}: {e}")

# Consolidates the three previously copy-pasted try/except init blocks
_init_model('SigLIP', SigLIPModel, 'siglip',
            ['ckpt_path', 'tokenizer_path', 'embedding_path'])
_init_model('SatCLIP', SatCLIPModel, 'satclip',
            ['ckpt_path', 'embedding_path'])
_init_model('FarSLIP', FarSLIPModel, 'farslip',
            ['ckpt_path', 'model_name', 'embedding_path'])
72
+
73
def get_active_model(model_name):
    """Look up a loaded model by name.

    Returns:
        tuple: ``(model, None)`` when loaded, ``(None, error_message)`` otherwise.
    """
    if model_name in models:
        return models[model_name], None
    return None, f"Model {model_name} not loaded."
77
+
78
def combine_images(img1, img2):
    """Stack two PIL images vertically on a white canvas, scaling both to a
    common width. Either argument may be None, in which case the other image
    is returned unchanged."""
    if img1 is None:
        return img2
    if img2 is None:
        return img1

    target_w = max(img1.size[0], img2.size[0])

    def _scaled_height(img):
        # preserve aspect ratio when widening to target_w
        w, h = img.size
        return int(h * target_w / w)

    top = img1.resize((target_w, _scaled_height(img1)))
    bottom = img2.resize((target_w, _scaled_height(img2)))

    canvas = PILImage.new('RGB', (target_w, top.size[1] + bottom.size[1]), (255, 255, 255))
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, top.size[1]))
    return canvas
97
+
98
def create_text_image(text, size=(384, 384)):
    """Render a comma-separated text query onto a light-grey placeholder image.

    Args:
        text (str): Query text; comma-separated parts are drawn on separate lines.
        size (tuple): (width, height) of the generated image.

    Returns:
        PIL.Image.Image: The rendered placeholder image.
    """
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    d = ImageDraw.Draw(img)

    # Try to load a scalable font; fall back to PIL's built-in bitmap font
    try:
        # Try to find a font that supports larger size
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except OSError:
        # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit;
        # ImageFont.truetype raises OSError when the font file cannot be read
        font = ImageFont.load_default()

    # Wrap text simply: one comma-separated chunk per line
    margin = 20
    offset = 100
    for line in text.split(','):
        d.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50

    d.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img
118
+
119
def fetch_top_k_images(top_indices, probs, df_embed, query_text=None):
    """
    Fetches top-k images using actual dataset download (ModelScope) via download_and_process_image.
    """
    results = []

    # Download candidates concurrently; ranking is restored by the final sort
    with ThreadPoolExecutor(max_workers=5) as executor:
        future_to_idx = {
            executor.submit(
                download_and_process_image,
                df_embed.iloc[idx]['product_id'],
                df_source=df_embed,
                verbose=False,
            ): idx
            for idx in top_indices
        }

        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                img_384, img_full = future.result()
                row = df_embed.iloc[idx]

                if img_384 is None:
                    # Fallback to Esri if download fails
                    print(f"Download failed for idx {idx}, falling back to Esri...")
                    img_384 = get_esri_satellite_image(
                        row['centre_lat'], row['centre_lon'],
                        score=probs[idx], rank=0, query=query_text,
                    )
                    img_full = img_384

                results.append({
                    'image_384': img_384,
                    'image_full': img_full,
                    'score': probs[idx],
                    'lat': row['centre_lat'],
                    'lon': row['centre_lon'],
                    'id': row['product_id'],
                })
            except Exception as e:
                print(f"Error fetching image for idx {idx}: {e}")

    # Sort results by score descending (since futures complete in random order)
    results.sort(key=lambda x: x['score'], reverse=True)
    return results
163
+
164
def get_all_results_metadata(model, filtered_indices, probs):
    """Build score-sorted metadata records for every retained result.

    Args:
        model: Object exposing a ``df_embed`` dataframe with columns
            ``product_id``, ``centre_lat`` and ``centre_lon``.
        filtered_indices: Indices into ``model.df_embed`` that survived
            thresholding.
        probs: Similarity score per dataframe row.

    Returns:
        list[dict]: Records with keys ``id``/``lat``/``lon``/``score``,
        ordered by descending score; empty list when nothing survived.
    """
    if len(filtered_indices) == 0:
        return []

    # Order the surviving indices from best to worst score
    order_desc = np.argsort(probs[filtered_indices])[::-1]
    ranked_indices = filtered_indices[order_desc]

    records = model.df_embed.iloc[ranked_indices].copy()
    records['score'] = probs[ranked_indices]

    records = records.rename(
        columns={'product_id': 'id', 'centre_lat': 'lat', 'centre_lon': 'lon'}
    )

    return records[['id', 'lat', 'lon', 'score']].to_dict('records')
182
+
183
def search_text(query, threshold, model_name):
    """Run the full text-query retrieval pipeline, streaming UI updates.

    Generator wired to 7 Gradio outputs; every ``yield`` is a 7-tuple of
    (interactive map, gallery items, status text, results figure,
    downloadable-figures state, filtered-dataframe state, static map image).

    Args:
        query: Free-text description to embed and search with.
        threshold: Slider value in per-mille (‰); threshold/1000 is the
            fraction of the archive kept as matches.
        model_name: Registry key of the embedding model ("SigLIP"/"FarSLIP").
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if not query:
        yield None, None, "Please enter a query.", None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Text
        yield None, None, "Encoding text...", None, None, None, None
        t0 = time.time()
        text_features = model.encode_text(query)
        timings['Encoding'] = time.time() - t0

        if text_features is None:
            yield None, None, "Model does not support text encoding or is not initialized.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding text... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        # Slider is in per-mille, so divide by 1000 to get a fraction.
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to "{query}" ({model_name})')

        # 3. Download Images — only the 10 best matches are fetched.
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:10]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text=query)
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        # threshold/100.0 feeds generate_status_msg's `{threshold*100:.0f}‰`
        # format, so the displayed number equals the raw slider value in ‰.
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
248
def search_image(image_input, threshold, model_name):
    """Run the image-query retrieval pipeline, streaming UI updates.

    Mirrors search_text() but encodes an uploaded PIL image and downloads
    only the top-6 matches. Each ``yield`` is a 7-tuple for the same Gradio
    outputs: (interactive map, gallery items, status text, results figure,
    downloadable-figures state, filtered-dataframe state, static map image).

    Args:
        image_input: PIL image from the Gradio image widget (None if unset).
        threshold: Slider value in per-mille (‰).
        model_name: Registry key of the embedding model.
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if image_input is None:
        yield None, None, "Please upload an image.", None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Image
        yield None, None, "Encoding image...", None, None, None, None
        t0 = time.time()
        image_features = model.encode_image(image_input)
        timings['Encoding'] = time.time() - t0

        if image_features is None:
            yield None, None, "Model does not support image encoding.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding image... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        # Slider is in per-mille, so divide by 1000 to get a fraction.
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Input Image ({model_name})')

        # 3. Download Images — only the 6 best matches are fetched.
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:6]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text="Image Query")
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        # Cap the downloadable text report at 50 rows.
        results_txt = format_results_to_text(all_results[:50])

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
308
+
309
def search_location(lat, lon, threshold):
    """Run the location-query retrieval pipeline (SatCLIP), streaming UI updates.

    Encodes (lat, lon) with SatCLIP's location encoder, retrieves the most
    similar tiles, and yields 7-tuples for the same Gradio outputs as
    search_text()/search_image().

    Args:
        lat, lon: Query coordinates in degrees (WGS84).
        threshold: Slider value in per-mille (‰).
    """
    model_name = "SatCLIP"
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Location
        yield None, None, "Encoding location...", None, None, None, None
        t0 = time.time()
        loc_features = model.encode_location(float(lat), float(lon))
        timings['Encoding'] = time.time() - t0

        if loc_features is None:
            yield None, None, "Location encoding failed.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding location... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        # Slider is per-mille: use /1000.0 like the text/image handlers so the
        # match set agrees with the plotted distribution (was /100.0, which
        # kept 10x more matches than the map displayed).
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        # 3. Generate Distribution Map (not timed)
        yield None, None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map...", None, None, None, None
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Location ({lat}, {lon})')

        # 4. Download Images — only the 6 best matches are fetched.
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_6_indices = top_indices[:6]
        results = fetch_top_k_images(top_6_indices, probs, df_embed, query_text=f"Loc: {lat},{lon}")

        # Query tile: show the catalogued tile nearest to the query point,
        # falling back to a grey placeholder if the download fails.
        query_tile = None
        try:
            lats = pd.to_numeric(df_embed['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df_embed['centre_lon'], errors='coerce')
            dists = (lats - float(lat))**2 + (lons - float(lon))**2
            nearest_idx = dists.idxmin()
            pid = df_embed.loc[nearest_idx, 'product_id']
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            print(f"Error fetching nearest MajorTOM image: {e}")
        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")
        timings['Download'] = time.time() - t0

        # 5. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 6. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
388
+
389
def generate_status_msg(count, threshold, results):
    """Build the status text: match count in ‰ plus the top-3 result summaries."""
    lines = [f"Found {count} matches in top {threshold*100:.0f}‰.\n\nTop {len(results)} similar images:\n"]
    # Summarise at most the first three hits.
    for rank, res in enumerate(results[:3], start=1):
        lines.append(f"{rank}. Product ID: {res['id']}, Location: ({res['lat']:.4f}, {res['lon']:.4f}), Score: {res['score']:.4f}\n")
    return "".join(lines)
394
+
395
def get_initial_plot():
    """Render the initial global map (FarSLIP preferred, SigLIP as fallback)."""
    img, df_vis = None, None
    for name in ('FarSLIP', 'SigLIP'):
        if name in models and models[name].df_embed is not None:
            img, df_vis = plot_global_map_static(models[name].df_embed)
            break
    # Show the static map and hide the interactive plot on first load.
    return gr.update(value=img, visible=True), [img], df_vis, gr.update(visible=False)
405
+
406
def handle_map_click(evt: gr.SelectData, df_vis):
    """Convert a pixel click on the static world map into lat/lon coordinates.

    The static map is a fixed-size matplotlib render; the hard-coded pixel
    dimensions and margins below must match the figure produced by
    plot_global_map_static — TODO confirm they stay in sync if that figure
    changes.

    Args:
        evt: Gradio select event; evt.index holds the (x, y) pixel position.
        df_vis: DataFrame of plotted points with 'centre_lat'/'centre_lon'/
            'product_id', used to snap the click to the nearest catalogued
            tile (None disables snapping).

    Returns:
        (lat, lon, product_id, status_message); the first three are None when
        the click is invalid or falls outside the drawn map area.
    """
    if evt is None:
        return None, None, None, "No point selected."

    try:
        x, y = evt.index[0], evt.index[1]

        # Image dimensions (New) — full rendered image size in pixels.
        img_width = 4000
        img_height = 2000

        # Scaled Margins (Proportional to 4000x2000)
        left_margin = 110
        right_margin = 110
        top_margin = 100
        bottom_margin = 67

        plot_width = img_width - left_margin - right_margin
        plot_height = img_height - top_margin - bottom_margin

        # Adjust for aspect ratio preservation: the equirectangular world map
        # (2:1) is letterboxed inside the axes area.
        map_aspect = 360.0 / 180.0  # 2.0
        plot_aspect = plot_width / plot_height

        if plot_aspect > map_aspect:
            actual_map_width = plot_height * map_aspect
            actual_map_height = plot_height
            h_offset = (plot_width - actual_map_width) / 2
            v_offset = 0
        else:
            actual_map_width = plot_width
            actual_map_height = plot_width / map_aspect
            h_offset = 0
            v_offset = (plot_height - actual_map_height) / 2

        # Calculate relative position within the plot area
        x_in_plot = x - left_margin
        y_in_plot = y - top_margin

        # Check if click is within the actual map bounds
        if (x_in_plot < h_offset or x_in_plot > h_offset + actual_map_width or
            y_in_plot < v_offset or y_in_plot > v_offset + actual_map_height):
            return None, None, None, "Click outside map area. Please click on the map."

        # Calculate relative position within the map (0 to 1)
        x_rel = (x_in_plot - h_offset) / actual_map_width
        y_rel = (y_in_plot - v_offset) / actual_map_height

        # Clamp to [0, 1]
        x_rel = max(0, min(1, x_rel))
        y_rel = max(0, min(1, y_rel))

        # Convert to geographic coordinates (y axis points down, so latitude
        # decreases as y grows).
        lon = x_rel * 360 - 180
        lat = 90 - y_rel * 180

        # Find nearest point in df_vis if available
        pid = ""
        if df_vis is not None:
            dists = (df_vis['centre_lat'] - lat)**2 + (df_vis['centre_lon'] - lon)**2
            min_idx = dists.idxmin()
            nearest_row = df_vis.loc[min_idx]

            # Snap only when within 5 degrees (25 = 5² in squared-degree units).
            if dists[min_idx] < 25:
                lat = nearest_row['centre_lat']
                lon = nearest_row['centre_lon']
                pid = nearest_row['product_id']

    except Exception as e:
        print(f"Error handling click: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"Error: {e}"

    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"
481
+
482
def download_image_by_location(lat, lon, pid, model_name):
    """Download and return the MajorTOM tile at the given coordinates.

    Resolves the nearest catalogued product id when *pid* is empty, then
    fetches the tile via download_and_process_image.

    Returns:
        (image, status_message) — image is None on any failure.
    """
    if lat is None or lon is None:
        return None, "Please specify coordinates first."

    model, error = get_active_model(model_name)
    if error:
        return None, error

    try:
        # Normalise coordinates so the status messages format cleanly.
        lat, lon = float(lat), float(lon)

        # No product id supplied: fall back to the nearest catalogued tile.
        if not pid:
            df = model.df_embed
            lat_col = pd.to_numeric(df['centre_lat'], errors='coerce')
            lon_col = pd.to_numeric(df['centre_lon'], errors='coerce')
            sq_dist = (lat_col - lat) ** 2 + (lon_col - lon) ** 2
            pid = df.loc[sq_dist.idxmin(), 'product_id']

        img_384, _ = download_and_process_image(pid, df_source=model.df_embed, verbose=True)

        if img_384 is None:
            return None, f"Failed to download image for location ({lat:.4f}, {lon:.4f})"

        return img_384, f"Downloaded image at ({lat:.4f}, {lon:.4f})"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
517
+
518
def reset_to_global_map():
    """Regenerate the global map view (FarSLIP preferred, SigLIP as fallback)."""
    img, df_vis = None, None
    for name in ('FarSLIP', 'SigLIP'):
        if name in models and models[name].df_embed is not None:
            img, df_vis = plot_global_map_static(models[name].df_embed)
            break

    return gr.update(value=img, visible=True), [img], df_vis
528
+
529
def format_results_to_text(results):
    """Render the retrieval results as a plain-text report for download."""
    if not results:
        return "No results found."

    parts = [f"Top {len(results)} Retrieval Results\n", "=" * 30 + "\n\n"]
    for rank, res in enumerate(results, start=1):
        parts.append(f"Rank: {rank}\n")
        parts.append(f"Product ID: {res['id']}\n")
        parts.append(f"Location: Latitude {res['lat']:.6f}, Longitude {res['lon']:.6f}\n")
        parts.append(f"Similarity Score: {res['score']:.6f}\n")
        parts.append("-" * 30 + "\n")
    return "".join(parts)
542
+
543
def save_plot(figs):
    """Persist the current results to a downloadable temp file.

    Accepts either a single PIL image (saved as .png), a list/tuple of
    [map_image, results_image, results_text] (bundled into a .zip), or a
    Plotly-like figure exposing ``write_html`` (saved as .html).

    Returns:
        Path of the written temp file, or None on failure/empty input.
    """
    if figs is None:
        return None
    try:
        from io import BytesIO  # local: only needed when assembling the zip

        # Single image (initial state): save directly as PNG.
        if isinstance(figs, PILImage.Image):
            fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
            os.close(fd)
            figs.save(path)
            return path

        # List/tuple of [map_img, results_img, results_txt]
        if isinstance(figs, (list, tuple)):
            # A lone image in a list is still just a PNG download.
            if len(figs) == 1 and isinstance(figs[0], PILImage.Image):
                fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
                os.close(fd)
                figs[0].save(path)
                return path

            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)

            with zipfile.ZipFile(zip_path, 'w') as zipf:
                # Write images straight into the archive via in-memory
                # buffers. (Previously fixed-name files in gettempdir() were
                # used, which collide between concurrent requests and leak.)
                if figs[0] is not None:
                    buf = BytesIO()
                    figs[0].save(buf, format='PNG')
                    zipf.writestr('map_distribution.png', buf.getvalue())

                if len(figs) > 1 and figs[1] is not None:
                    buf = BytesIO()
                    figs[1].save(buf, format='PNG')
                    zipf.writestr('retrieval_results.png', buf.getvalue())

                if len(figs) > 2 and figs[2] is not None:
                    # writestr encodes str payloads as UTF-8.
                    zipf.writestr('results.txt', figs[2])

            return zip_path

        # Fallback: assume a Plotly figure and export interactive HTML.
        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)
        figs.write_html(path)
        return path
    except Exception as e:
        print(f"Error saving: {e}")
        return None
599
+
600
# Gradio Blocks Interface.
# Layout: left column holds the three search tabs plus shared controls;
# right column holds the map, the results figure and a zoomable gallery.
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    # Fixed typo in the intro copy: "a simple a click" -> "a simple click".
    gr.HTML("""
    <div style="font-size: 1.2em;">
    EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
    </div>

    <div style="display: flex; gap: 0.2em; align-items: center; justify-content: center;">
    <a href="https://www.modelscope.cn/studios/VoyagerX/EarthExplorer"><img src="https://img.shields.io/badge/Open in ModelScope.cn-xGPU-624aff"></a>
    <a href="https://www.modelscope.ai/studios/VoyagerX/EarthExplorer"><img src="https://img.shields.io/badge/Open in ModelScope.ai-CPU-624aff"></a>
    <a href="https://huggingface.co/spaces/ML4Sustain/EarthExplorer"><img src="https://img.shields.io/badge/Open in HF Space-CPU-FFD21E"></a>
    <a href="https://modelscope.cn/studios/VoyagerX/EarthExplorer/file/view/master/Tutorial.md?status=1"> <img src="https://img.shields.io/badge/Tutorial-📖-007bff"> </a>
    <a href="https://www.modelscope.cn/learn/3958"> <img src="https://img.shields.io/badge/中文教程-📖-007bff"> </a>
    </div>

    """)

    with gr.Row():
        # Left column: the three query modes plus shared controls.
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Text Search") as tab_text:
                    model_selector_text = gr.Dropdown(choices=["SigLIP", "FarSLIP"], value="FarSLIP", label="Model")
                    query_input = gr.Textbox(label="Query", placeholder="e.g., rainforest, glacier")

                    gr.Examples(
                        examples=[
                            ["a satellite image of a river around a city"],
                            ["a satellite image of a rainforest"],
                            ["a satellite image of a slum"],
                            ["a satellite image of a glacier"],
                            ["a satellite image of snow covered mountains"]
                        ],
                        inputs=[query_input],
                        label="Text Examples"
                    )

                    search_btn = gr.Button("Search by Text", variant="primary")

                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(choices=["SigLIP", "FarSLIP", "SatCLIP"], value="FarSLIP", label="Model")

                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")

                    gr.Examples(
                        examples=[
                            ["./examples/example1.png"],
                            ["./examples/example2.png"],
                            ["./examples/example3.png"]
                        ],
                        inputs=[image_input],
                        label="Image Examples"
                    )

                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)

                    # Hidden field auto-filled by map clicks (handle_map_click).
                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")

                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")

                    search_img_btn = gr.Button("Search by Image", variant="primary")

                with gr.TabItem("Location Search") as tab_location:
                    gr.Markdown("Search using **SatCLIP** location encoder.")

                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)

                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")

                    gr.Examples(
                        examples=[
                            [30.32, 120.15],
                            [40.7128, -74.0060],
                            [24.65, 46.71],
                            [-3.4653, -62.2159],
                            [64.4, 16.8]
                        ],
                        inputs=[lat_input, lon_input],
                        label="Location Examples"
                    )

                    search_loc_btn = gr.Button("Search by Location", variant="primary")

            # Controls shared by all three search modes.
            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (‰)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)

        # Right column: map and retrieval visualisations.
        with gr.Column(scale=6):
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True
            )
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")

    # Cross-callback state: current downloadable figures and the dataframe
    # backing the plotted map (used by handle_map_click for snapping).
    current_fig = gr.State()
    map_data_state = gr.State()

    # Initial Load
    demo.load(fn=get_initial_plot, outputs=[plot_map, current_fig, map_data_state, plot_map_interactive])

    # Reset Map Buttons
    btn_reset_map_img.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    btn_reset_map_loc.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    # Map Click Event - updates Image Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[img_lat, img_lon, img_pid, img_click_status]
    )

    # Map Click Event - also updates Location Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[lat_input, lon_input, loc_pid, loc_click_status]
    )

    # Download Image by Geolocation
    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status]
    )

    # Search Event (Text)
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Search Event (Image)
    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Search Event (Location)
    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Save Event
    save_btn.click(
        fn=save_plot,
        inputs=[current_fig],
        outputs=[download_file]
    )

    # Tab Selection Events: switching tabs always restores the static map view.
    def show_static_map():
        return gr.update(visible=True), gr.update(visible=False)

    tab_text.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_image.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_location.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
configs/huggingface.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ siglip:
2
+ ckpt_path: "hf"
3
+ model_name: "ViT-SO400M-14-SigLIP-384"
4
+ tokenizer_path: "hf"
5
+ embedding_path: "hf://ML4Sustain/EarthEmbeddings/uniform_sample_250k/siglip/SigLIP_grid_sample_center_384x384_243k.parquet"
6
+ farslip:
7
+ ckpt_path: "hf"
8
+ model_name: "ViT-B-16"
9
+ embedding_path: "hf://ML4Sustain/EarthEmbeddings/uniform_sample_250k/farslip/FarSLIP_grid_sample_center_384x384_243k.parquet"
10
+ satclip:
11
+ ckpt_path: "hf"
12
+ embedding_path: "hf://ML4Sustain/EarthEmbeddings/uniform_sample_250k/satclip/SatCLIP_grid_sample_center_384x384_243k.parquet"
countries.geo.json ADDED
The diff for this file is too large to render. See raw diff
 
data_utils.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fsspec
2
+ import pyarrow.parquet as pq
3
+ import numpy as np
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from rasterio.io import MemoryFile
7
+ import matplotlib.pyplot as plt
8
+ import cartopy.crs as ccrs
9
+ import cartopy.io.img_tiles as cimgt
10
+ from matplotlib.patches import Rectangle
11
+ import math
12
+ from matplotlib.figure import Figure
13
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
14
+
15
+
16
def crop_center(img_array, cropx, cropy):
    """Crop a (H, W, C) array to cropy x cropx around its centre."""
    height, width, _ = img_array.shape
    left = width // 2 - cropx // 2
    top = height // 2 - cropy // 2
    return img_array[top:top + cropy, left:left + cropx]
21
+
22
def read_tif_bytes(tif_bytes):
    """Decode an in-memory GeoTIFF byte string into a squeezed numpy array."""
    with MemoryFile(tif_bytes) as mem, mem.open(driver='GTiff') as src:
        bands = src.read()
    return bands.squeeze()
26
+
27
def read_row_memory(row_dict, columns=None):
    """Fetch selected columns of one remote parquet row group into memory.

    Args:
        row_dict: Mapping with 'parquet_url' (remote parquet file) and
            'parquet_row' (row-group index within that file).
        columns: Column names to read; defaults to ["thumbnail"]. The
            'thumbnail' column is decoded as a PIL image; every other column
            is treated as GeoTIFF bytes and decoded to a numpy array.

    Returns:
        Dict mapping each requested column name to its decoded value.
    """
    # Avoid the shared mutable default argument (was columns=["thumbnail"]).
    if columns is None:
        columns = ["thumbnail"]

    url = row_dict['parquet_url']
    row_idx = row_dict['parquet_row']

    # Readahead caching with large blocks keeps the number of HTTP range
    # requests low when pulling a single row group.
    fs_options = {
        "cache_type": "readahead",
        "block_size": 5 * 1024 * 1024
    }

    with fsspec.open(url, mode='rb', **fs_options) as f:
        with pq.ParquetFile(f) as pf:
            table = pf.read_row_group(row_idx, columns=columns)

    row_output = {}
    for col in columns:
        col_data = table[col][0].as_py()

        if col != 'thumbnail':
            row_output[col] = read_tif_bytes(col_data)
        else:
            # Thumbnails are stored as already-encoded image bytes.
            row_output[col] = Image.open(BytesIO(col_data))

    return row_output
51
+
52
def download_and_process_image(product_id, df_source=None, verbose=True):
    """Fetch a tile by product id and convert it to RGB PIL images.

    Looks up *product_id* in *df_source* (needs 'product_id' plus the
    'parquet_url'/'parquet_row' metadata columns), downloads the B04/B03/B02
    bands, normalises them and returns ``(img_384, img_full)``: a 384x384
    centre crop (or resize when the tile is smaller) and the full-resolution
    image. Returns (None, None) on any failure.
    """
    if df_source is None:
        if verbose: print("❌ Error: No DataFrame provided.")
        return None, None

    row_subset = df_source[df_source['product_id'] == product_id]
    if len(row_subset) == 0:
        if verbose: print(f"❌ Error: Product ID {product_id} not found in DataFrame.")
        return None, None

    row_dict = row_subset.iloc[0].to_dict()

    # Rewrite Hugging Face URLs to the ModelScope mirror (HF may be
    # unreachable from the deployment environment).
    if 'parquet_url' in row_dict:
        url = row_dict['parquet_url']
        if 'huggingface.co' in url:
            row_dict['parquet_url'] = url.replace('https://huggingface.co', 'https://modelscope.cn').replace('resolve/main', 'resolve/master')
        elif 'hf-mirror.com' in url:
            row_dict['parquet_url'] = url.replace('https://hf-mirror.com', 'https://modelscope.cn').replace('resolve/main', 'resolve/master')
    else:
        if verbose: print("❌ Error: 'parquet_url' missing in metadata.")
        return None, None

    if verbose: print(f"⬇️ Fetching data for {product_id} from {row_dict['parquet_url']}...")

    try:
        rgb_bands = ['B04', 'B03', 'B02']  # red, green, blue
        bands_data = read_row_memory(row_dict, columns=rgb_bands)

        if not all(b in bands_data for b in rgb_bands):
            if verbose: print(f"❌ Error: Missing bands in fetched data for {product_id}")
            return None, None

        rgb_img = np.stack([bands_data[b] for b in rgb_bands], axis=-1)

        if verbose:
            print(f"Raw RGB stats: Min={rgb_img.min()}, Max={rgb_img.max()}, Mean={rgb_img.mean()}, Dtype={rgb_img.dtype}")

        # Normalise: assumes raw DN = reflectance * 10000 (TODO confirm for
        # all datasets); the 2.5 gain brightens typical scenes before
        # clipping to [0, 1]. (A dead `if rgb_img.max() <= 255: pass` branch
        # that never handled pre-scaled data was removed here.)
        rgb_norm = (2.5 * (rgb_img.astype(float) / 10000.0)).clip(0, 1)
        rgb_uint8 = (rgb_norm * 255).astype(np.uint8)

        if verbose:
            print(f"Processed RGB stats: Min={rgb_uint8.min()}, Max={rgb_uint8.max()}, Mean={rgb_uint8.mean()}")

        img_full = Image.fromarray(rgb_uint8)

        if rgb_uint8.shape[0] >= 384 and rgb_uint8.shape[1] >= 384:
            img_384 = Image.fromarray(crop_center(rgb_uint8, 384, 384))
        else:
            if verbose: print(f"⚠️ Image too small {rgb_uint8.shape}, resizing to 384x384.")
            img_384 = img_full.resize((384, 384))

        if verbose: print(f"✅ Successfully processed {product_id}")
        return img_384, img_full

    except Exception as e:
        if verbose: print(f"❌ Error processing {product_id}: {e}")
        import traceback
        traceback.print_exc()
        return None, None
116
+
117
# Define Esri Imagery Class: cartopy tile source for Esri's World_Imagery
# basemap (used as a fallback when the real dataset tile cannot be fetched).
class EsriImagery(cimgt.GoogleTiles):
    def _image_url(self, tile):
        # tile is a (x, y, zoom) triple supplied by cartopy's tiler.
        x, y, z = tile
        return f'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}'
122
+
123
+ from PIL import Image, ImageDraw, ImageFont
124
+
125
def get_placeholder_image(text="Image Unavailable", size=(384, 384)):
    """Return a grey placeholder image with *text* drawn on it.

    Used wherever a real satellite tile or basemap could not be fetched.

    Args:
        text: Message to render (may contain newlines).
        size: (width, height) of the placeholder in pixels.
    """
    img = Image.new('RGB', size, color=(200, 200, 200))
    d = ImageDraw.Draw(img)
    try:
        # PIL's bundled bitmap font; no external font files required.
        font = ImageFont.load_default()
    except Exception:  # was a bare except: don't mask KeyboardInterrupt/SystemExit
        font = None

    # Draw text in center (rough approximation)
    # For better centering we would need font metrics, but simple is fine here
    d.text((20, size[1]//2), text, fill=(0, 0, 0), font=font)
    return img
138
+
139
def get_esri_satellite_image(lat, lon, score=None, rank=None, query=None):
    """
    Generates a satellite image visualization using Esri World Imagery via Cartopy.
    Matches the style of the provided notebook.
    Uses OO Matplotlib API for thread safety.

    Args:
        lat, lon: Centre of the view in degrees (WGS84).
        score, rank, query: Optional annotations rendered into the title.

    Returns:
        A PIL image of the rendered map, or a grey placeholder on failure.
    """
    try:
        imagery = EsriImagery()

        # Create figure using OO API (the pyplot state machine is not
        # safe to use from worker threads).
        fig = Figure(figsize=(5, 5), dpi=100)
        canvas = FigureCanvasAgg(fig)
        ax = fig.add_subplot(1, 1, 1, projection=imagery.crs)

        # Set extent to approx 10km x 10km around the point
        extent_deg = 0.05
        ax.set_extent([lon - extent_deg, lon + extent_deg, lat - extent_deg, lat + extent_deg], crs=ccrs.PlateCarree())

        # Add the imagery
        ax.add_image(imagery, 14)

        # Add a marker for the center
        ax.plot(lon, lat, marker='+', color='yellow', markersize=12, markeredgewidth=2, transform=ccrs.PlateCarree())

        # Add Bounding Box (3840m x 3840m)
        box_size_m = 384 * 10  # 3840m

        # Convert meters to degrees (approx)
        # 1 deg lat = 111320m
        # 1 deg lon = 111320m * cos(lat)
        # NOTE(review): cos(lat) -> 0 near the poles makes dlon blow up;
        # assumes inputs stay away from ±90° — TODO confirm.
        dlat = (box_size_m / 111320)
        dlon = (box_size_m / (111320 * math.cos(math.radians(lat))))

        # Bottom-Left corner
        rect_lon = lon - dlon / 2
        rect_lat = lat - dlat / 2

        # Add Rectangle
        rect = Rectangle((rect_lon, rect_lat), dlon, dlat,
                         linewidth=2, edgecolor='red', facecolor='none', transform=ccrs.PlateCarree())
        ax.add_patch(rect)

        # Title
        title_parts = []
        if query: title_parts.append(f"{query}")
        if rank is not None: title_parts.append(f"Rank {rank}")
        if score is not None: title_parts.append(f"Score: {score:.4f}")

        ax.set_title("\n".join(title_parts), fontsize=10)

        # Save to buffer
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)

        return Image.open(buf)

    except Exception as e:
        # Suppress full traceback for network errors to avoid log spam
        error_msg = str(e)
        if "Connection reset by peer" in error_msg or "Network is unreachable" in error_msg or "urlopen error" in error_msg:
            print(f"⚠️ Network warning: Could not fetch Esri satellite map for ({lat:.4f}, {lon:.4f}). Server might be offline.")
        else:
            print(f"Error generating Esri image for {lat}, {lon}: {e}")
            # Only print traceback for non-network errors
            # import traceback
            # traceback.print_exc()

        # Return a placeholder image with text
        return get_placeholder_image(f"Map Unavailable\n({lat:.2f}, {lon:.2f})")
209
+
210
def get_esri_satellite_image_url(lat, lon, zoom=14):
    """Stub: would return the Esri World Imagery tile URL for a location.

    The tile-coordinate math was never implemented; callers rely on
    get_esri_satellite_image() for a rendered static map instead. Kept for
    API compatibility — always returns None, as the original did.
    """
    # NOTE(review): the original body only instantiated EsriImagery inside a
    # bare try/except and fell through; the dead code and the bare except
    # have been removed.
    return None
embedding_datasets/grid_sample_center_22k_FarSLIP_384x384.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3555e0279742daa7ee27ba5587a8234f791966ce4411ef804455ee03af52e1aa
3
+ size 23547770
embedding_datasets/grid_sample_center_22k_SatCLIP_384x384.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76484097dea1f0fc65e4f2c8d3e825ec3ccda8914da83e3a65aabd86a4f59ec2
3
+ size 25158503
embedding_datasets/grid_sample_center_22k_SigLIP_384x384.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a34d949f704f8f4d9d963f28dbe547c341591645cf86d191587a3cc0a866855f
3
+ size 50178408
embedding_datasets/grid_sample_metadata.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:808fde21fdae5ef2dc8183c7e8017b286dc2d2419ed64e6058358291cbeef06c
3
+ size 1999889
embedding_datasets/zhejiang_sample_center_2k_FarSLIP_384x384.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4bc51828dd58d45c62d3168557870e8db6659c0c52e9661865326cafb11c88b
3
+ size 2088911
embedding_datasets/zhejiang_sample_center_2k_SatCLIP_384x384.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81bdf991b1d6a100108d0cad730bf79b1e9f558f261f2d4fb18c3f68c9ff2796
3
+ size 2719357
embedding_datasets/zhejiang_sample_center_2k_SigLIP_384x384.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47ab37984d86b9b15949448f36d52022c24c40e7ff3bc65f44510ff08d0cbe81
3
+ size 4381379
embedding_datasets/zhejiang_sample_metadata.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb7c46e7985c05cb010e4fd308489865271b43b250844b8d42a3fe8263d01a78
3
+ size 159438
examples/example1.png ADDED

Git LFS Details

  • SHA256: 07dd836c4dfe700657f163afdae9ebf2685f83dca1417078b3147c8c31f598a9
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
examples/example2.png ADDED

Git LFS Details

  • SHA256: e52a44517c028cb6b9828c37c974991fb20122f6cdba951e809ac66b7c591552
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
examples/example3.png ADDED

Git LFS Details

  • SHA256: d63b587c17943eb1e60f511def466696c1a12a323f0f67dff99da7631e2e48aa
  • Pointer size: 131 Bytes
  • Size of remote file: 507 kB
images/CLIP.png ADDED

Git LFS Details

  • SHA256: b6005b0baf8fa09c54d102f75f437a1f3445f5f2fa512ac78513b77809e83363
  • Pointer size: 131 Bytes
  • Size of remote file: 192 kB
images/Image_Search_Amazon.jpg ADDED

Git LFS Details

  • SHA256: d3dda8f0b849b76048cd8f2013232a31cd3d30af248eb86bee0fe724513e4f58
  • Pointer size: 131 Bytes
  • Size of remote file: 791 kB
images/Image_Search_Middle_East.jpg ADDED

Git LFS Details

  • SHA256: 7ac5769228c0869bf43cc4c9a2a202159937c44f8b872ec15b1b9c93159b414e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
images/Location_Search_Amazon.jpg ADDED

Git LFS Details

  • SHA256: 8614bae60fd581e464effe269d9dd5fb74c389136716219f8c0db72ad3560f53
  • Pointer size: 131 Bytes
  • Size of remote file: 905 kB
images/Location_Search_Hangzhou.jpg ADDED

Git LFS Details

  • SHA256: edd2aeb10fff28156716a99a1fa354bc461a6dad13ea4fcfcf768e539549251f
  • Pointer size: 131 Bytes
  • Size of remote file: 963 kB
images/Text_Search.jpg ADDED

Git LFS Details

  • SHA256: fd8b3ae6ba248fb713d06f93d14560814cd925b35edad9227711cf2e4933a901
  • Pointer size: 131 Bytes
  • Size of remote file: 942 kB
images/embedding.png ADDED

Git LFS Details

  • SHA256: e64060fe753c5f322b75645ddb1ff5380a41f5581e62136f3b67befda77abcd0
  • Pointer size: 130 Bytes
  • Size of remote file: 91.4 kB
images/framework_en.png ADDED

Git LFS Details

  • SHA256: 7b92ac54e73b446641c84d004b71a013ed85d08d308978880e76928d654bd89e
  • Pointer size: 131 Bytes
  • Size of remote file: 443 kB
images/framework_zh.png ADDED

Git LFS Details

  • SHA256: 32c7e6756a68db25e1a6d30c8567da9a5fedf6f6600d507b26b3f98e5eb34a86
  • Pointer size: 131 Bytes
  • Size of remote file: 535 kB
images/samples.png ADDED

Git LFS Details

  • SHA256: 122e7e4c21b01fc14325ce794d5286c0e1abbd6ae3c42cf102907c7e209df65e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.78 MB
models/FarSLIP/.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **/logs/
2
+ **/wandb/
3
+ models/
4
+ features/
5
+ results/
6
+ src/open_clip_train/config.py
7
+ src/open_clip_train/output_samples/
8
+ **/results_retrieval/
9
+ **/results_classification/
10
+ checkpoints/
11
+
12
+ tests/data/
13
+ *.pt
14
+
15
+ # Byte-compiled / optimized / DLL files
16
+ __pycache__/
17
+ *.py[cod]
18
+ *$py.class
19
+
20
+ # C extensions
21
+ *.so
22
+
23
+ # Distribution / packaging
24
+ .Python
25
+ build/
26
+ develop-eggs/
27
+ dist/
28
+ downloads/
29
+ eggs/
30
+ .eggs/
31
+ lib/
32
+ lib64/
33
+ parts/
34
+ sdist/
35
+ var/
36
+ wheels/
37
+ pip-wheel-metadata/
38
+ share/python-wheels/
39
+ *.egg-info/
40
+ .installed.cfg
41
+ *.egg
42
+ MANIFEST
43
+
44
+ # PyInstaller
45
+ # Usually these files are written by a python script from a template
46
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
47
+ *.manifest
48
+ *.spec
49
+
50
+ # Installer logs
51
+ pip-log.txt
52
+ pip-delete-this-directory.txt
53
+
54
+ # Unit test / coverage reports
55
+ htmlcov/
56
+ .tox/
57
+ .nox/
58
+ .coverage
59
+ .coverage.*
60
+ .cache
61
+ nosetests.xml
62
+ coverage.xml
63
+ *.cover
64
+ *.py,cover
65
+ .hypothesis/
66
+ .pytest_cache/
67
+
68
+ # Translations
69
+ *.mo
70
+ *.pot
71
+
72
+ # Django stuff:
73
+ *.log
74
+ local_settings.py
75
+ db.sqlite3
76
+ db.sqlite3-journal
77
+
78
+ # Flask stuff:
79
+ instance/
80
+ .webassets-cache
81
+
82
+ # Scrapy stuff:
83
+ .scrapy
84
+
85
+ # Sphinx documentation
86
+ docs/_build/
87
+
88
+ # PyBuilder
89
+ target/
90
+
91
+ # Jupyter Notebook
92
+ .ipynb_checkpoints
93
+
94
+ # IPython
95
+ profile_default/
96
+ ipython_config.py
97
+
98
+ # pyenv
99
+ .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
109
+ __pypackages__/
110
+
111
+ # Celery stuff
112
+ celerybeat-schedule
113
+ celerybeat.pid
114
+
115
+ # SageMath parsed files
116
+ *.sage.py
117
+
118
+ # Environments
119
+ .env
120
+ .venv
121
+ env/
122
+ venv/
123
+ ENV/
124
+ env.bak/
125
+ venv.bak/
126
+
127
+ # Spyder project settings
128
+ .spyderproject
129
+ .spyproject
130
+
131
+ # Rope project settings
132
+ .ropeproject
133
+
134
+ # mkdocs documentation
135
+ /site
136
+
137
+ # mypy
138
+ .mypy_cache/
139
+ .dmypy.json
140
+ dmypy.json
141
+
142
+ # Pyre type checker
143
+ .pyre/
144
+ sync.sh
145
+ gpu1sync.sh
146
+ .idea
147
+ *.pdf
148
+ **/._*
149
+ **/*DS_*
150
+ **.jsonl
151
+ src/sbatch
152
+ src/misc
153
+ .vscode
154
+ src/debug
155
+ core.*
156
+
157
+ *.out
158
+
159
+ # Allow
160
+ !src/evaluation/misc/results_dbs/*