ML4RS-Anonymous commited on
Commit
eb1aec4
·
verified ·
1 Parent(s): a582d21

Upload all files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. MajorTOM/MajorTOMDataset.py +64 -0
  3. MajorTOM/__init__.py +5 -0
  4. MajorTOM/embedder/MajorTOM_Embedder.py +191 -0
  5. MajorTOM/embedder/__init__.py +2 -0
  6. MajorTOM/embedder/__pycache__/MajorTOM_Embedder.cpython-311.pyc +0 -0
  7. MajorTOM/embedder/__pycache__/__init__.cpython-311.pyc +0 -0
  8. MajorTOM/embedder/__pycache__/grid_cell_fragment.cpython-311.pyc +0 -0
  9. MajorTOM/embedder/grid_cell_fragment.py +164 -0
  10. MajorTOM/embedder/models/DINOv2_S2RGB.py +91 -0
  11. MajorTOM/embedder/models/SSL4EO_S1RTC.py +125 -0
  12. MajorTOM/embedder/models/SSL4EO_S2L1C.py +97 -0
  13. MajorTOM/embedder/models/SigLIP_S2RGB.py +65 -0
  14. MajorTOM/embedder/models/__init__.py +4 -0
  15. MajorTOM/embedder/models/__pycache__/DINOv2_S2RGB.cpython-311.pyc +0 -0
  16. MajorTOM/embedder/models/__pycache__/SSL4EO_S1RTC.cpython-311.pyc +0 -0
  17. MajorTOM/embedder/models/__pycache__/SSL4EO_S2L1C.cpython-311.pyc +0 -0
  18. MajorTOM/embedder/models/__pycache__/SigLIP_S2RGB.cpython-311.pyc +0 -0
  19. MajorTOM/embedder/models/__pycache__/__init__.cpython-311.pyc +0 -0
  20. MajorTOM/extras/coverage-example.png +3 -0
  21. MajorTOM/extras/coverage_vis.py +149 -0
  22. MajorTOM/extras/extract-sample-from-raw-S2.ipynb +0 -0
  23. MajorTOM/extras/thumbnail_dem.py +77 -0
  24. MajorTOM/extras/thumbnail_s1rtc.py +80 -0
  25. MajorTOM/extras/thumbnail_s2.py +68 -0
  26. MajorTOM/grid.py +284 -0
  27. MajorTOM/metadata_helpers.py +159 -0
  28. MajorTOM/sample_helpers.py +20 -0
  29. app.py +799 -0
  30. compute_embeddings.py +606 -0
  31. configs/huggingface.yaml +15 -0
  32. countries.geo.json +0 -0
  33. data_utils.py +223 -0
  34. examples/example1.png +3 -0
  35. examples/example2.png +3 -0
  36. examples/example3.png +3 -0
  37. logs/compute_embeddings_dinov2.log +170 -0
  38. logs/compute_embeddings_farslip.log +150 -0
  39. logs/compute_embeddings_satclip.log +182 -0
  40. logs/compute_embeddings_siglip.log +200 -0
  41. models/FarSLIP/.gitignore +160 -0
  42. models/FarSLIP/LICENSE +21 -0
  43. models/FarSLIP/README.md +237 -0
  44. models/FarSLIP/__init__.py +1 -0
  45. models/FarSLIP/open_clip/__init__.py +18 -0
  46. models/FarSLIP/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  47. models/FarSLIP/open_clip/coca_model.py +582 -0
  48. models/FarSLIP/open_clip/constants.py +11 -0
  49. models/FarSLIP/open_clip/convert.py +206 -0
  50. models/FarSLIP/open_clip/factory.py +610 -0
.gitattributes CHANGED
@@ -39,3 +39,9 @@ EarthEmbeddingExplorer/examples/example3.png filter=lfs diff=lfs merge=lfs -text
39
  EarthEmbeddingExplorer/MajorTOM/extras/coverage-example.png filter=lfs diff=lfs merge=lfs -text
40
  EarthEmbeddingExplorer/models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
41
  EarthEmbeddingExplorer/models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
39
  EarthEmbeddingExplorer/MajorTOM/extras/coverage-example.png filter=lfs diff=lfs merge=lfs -text
40
  EarthEmbeddingExplorer/models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
41
  EarthEmbeddingExplorer/models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
42
+ examples/example1.png filter=lfs diff=lfs merge=lfs -text
43
+ examples/example2.png filter=lfs diff=lfs merge=lfs -text
44
+ examples/example3.png filter=lfs diff=lfs merge=lfs -text
45
+ MajorTOM/extras/coverage-example.png filter=lfs diff=lfs merge=lfs -text
46
+ models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
47
+ models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
MajorTOM/MajorTOMDataset.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ from pathlib import Path
6
+ import rasterio as rio
7
+ from PIL import Image
8
+ import torchvision.transforms as transforms
9
+
10
class MajorTOM(Dataset):
    """MajorTOM Dataset (https://huggingface.co/Major-TOM)

    Args:
        df ((geo)pandas.DataFrame): Metadata dataframe
        local_dir (string): Root directory of the local dataset version
        tif_bands (list): A list of tif file names to be read
        png_bands (list): A list of png file names to be read
    """

    def __init__(self,
                 df,
                 local_dir = None,
                 tif_bands=['B04','B03','B02'],
                 png_bands=['thumbnail'],
                 tif_transforms=[transforms.ToTensor()],
                 png_transforms=[transforms.ToTensor()]
                 ):
        super().__init__()
        self.df = df
        # Accept either a plain string or a ready-made Path for the dataset root.
        if isinstance(local_dir, str):
            local_dir = Path(local_dir)
        self.local_dir = local_dir
        # A single band may be passed as a bare string; wrap it into a list.
        self.tif_bands = [tif_bands] if isinstance(tif_bands, str) else tif_bands
        self.png_bands = [png_bands] if isinstance(png_bands, str) else png_bands
        # Compose the per-format transform pipelines once up front.
        self.tif_transforms = None if tif_transforms is None else transforms.Compose(tif_transforms)
        self.png_transforms = None if png_transforms is None else transforms.Compose(png_transforms)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample: every configured tif/png band plus its metadata row."""
        meta = self.df.iloc[idx]

        # Samples live under <root>/<grid row>/<grid cell>/<product id>/
        grid_cell = meta.grid_cell
        row = grid_cell.split('_')[0]
        sample_dir = self.local_dir / row / grid_cell / f"{meta.product_id}"

        out_dict = {'meta': meta}

        for band in self.tif_bands:
            with rio.open(sample_dir / f'{band}.tif') as src:
                data = src.read()
            if self.tif_transforms is not None:
                data = self.tif_transforms(data)
            out_dict[band] = data

        for band in self.png_bands:
            data = Image.open(sample_dir / f'{band}.png')
            if self.png_transforms is not None:
                data = self.png_transforms(data)
            out_dict[band] = data

        return out_dict
MajorTOM/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .sample_helpers import *
2
+ from .metadata_helpers import *
3
+ from .MajorTOMDataset import *
4
+ from .grid import *
5
+ from .embedder import *
MajorTOM/embedder/MajorTOM_Embedder.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import geopandas as gpd
3
+ import hashlib
4
+ from rasterio.io import MemoryFile
5
+
6
+ from .grid_cell_fragment import *
7
+ from .models import *
8
+ import cv2
9
+
10
class MajorTOM_Embedder(torch.nn.Module):
    """
    MajorTOM Embedder class that applies a model to geospatial image fragments,
    computes embeddings, and returns metadata for each fragment.

    The input raster is fragmented into tiles sized to the embedder's expected
    input, each tile is embedded, and one GeoDataFrame row (spatial metadata +
    embedding) is produced per tile.

    Attributes:
        embedder: A model that generates embeddings for image fragments.
        frag_params: Dictionary containing fragmentation parameters such as the
            target overlap and border shift.
        column_types: Dictionary specifying data types for the output GeoDataFrame columns.
    """

    def __init__(self, embedder, target_overlap=0.1, border_shift=True):
        """
        Initializes the MajorTOM Embedder with the given parameters.

        Args:
            embedder (torch.nn.Module): A model that generates embeddings for image fragments.
            target_overlap (float): The target overlap between image fragments. Default is 0.1.
            border_shift (bool): Whether to shift the borders of fragments to avoid edge artifacts. Default is True.
        """
        super().__init__()

        # Model
        self.embedder = embedder

        # Fragmentation settings (fragment size follows the embedder's input size).
        # Fix: dropped the dead `params =` alias from the original assignment.
        self.frag_params = {
            'fragment_size' : self.embedder.size[0],
            'target_overlap' : target_overlap,
            'border_shift' : border_shift
        }

        # Data types for the output dataframe (commented columns need no conversion)
        self.column_types = {
            #'unique_id' :,
            #'embedding' : ,
            #'timestamp' : ,
            #'product_id' : ,
            #'grid_cell' : ,
            'grid_row_u' : 'int16',
            'grid_col_r' : 'int16',
            'centre_lat' : 'float32',
            'centre_lon' : 'float32',
            #'utm_footprint' : ,
            #'utm_crs' : ,
            #'pixel_bbox' : ,
        }

    def bands(self):
        """
        Returns the set of input bands in the correct order.

        Returns:
            list: List of input bands used by the embedder.
        """
        return self.embedder.bands

    def size(self):
        """
        Returns the input image size.

        Returns:
            tuple: Tuple representing the image size (height, width).
        """
        return self.embedder.size

    def calculate_checksum(self, geometry, timestamp, product_id, embedding):
        """
        Calculates a checksum for the given geometry, timestamp, product ID, and embedding.

        Args:
            geometry (shapely.geometry): The geometry object representing the fragment's footprint.
            timestamp (str): Timestamp of the data.
            product_id (str): Product identifier.
            embedding (np.ndarray): The embedding of the image fragment.

        Returns:
            str: A SHA256 checksum of the concatenated input parameters.
        """
        combined = f"{geometry}_{timestamp}_{product_id}_{embedding}"
        checksum = hashlib.sha256(combined.encode()).hexdigest()
        return checksum

    def _read_image(self, row):
        """
        Reads and stacks the image bands for a given row, upsampling any band
        whose resolution does not match the largest one, and returns the image
        data, footprint, and CRS.

        Args:
            row (pandas.Series): The input row containing the image bands.

        Returns:
            torch.Tensor: A channels-last (H, W, C) tensor of the stacked bands.
            shapely.geometry: The footprint of the image.
            rasterio.crs.CRS: The CRS of the image.
        """

        # Read each band from its in-memory GeoTIFF payload
        img = []
        for band in self.embedder.bands:
            with MemoryFile(row[band][0].as_py()) as mem_f:
                with mem_f.open(driver='GTiff') as f:
                    crs = f.crs
                    footprint = box(*f.bounds)
                    img.append(f.read()[0])

        # optional upsampling to the largest band resolution
        shapes = [layer.shape for layer in img]
        if any([el!=shapes[0] for el in shapes]): # if any resolution mismatch
            h, w = max([el[0] for el in shapes]), max([el[1] for el in shapes]) # maximum size
            for layer_idx, layer in enumerate(img):
                if layer.shape != (h,w):
                    # Fix: cv2.resize takes dsize as (width, height); the original
                    # passed (h, w), transposing the target for non-square bands.
                    img[layer_idx] = cv2.resize(layer, (w, h), interpolation=cv2.INTER_NEAREST)
        img = torch.from_numpy(np.stack(img,-1).astype(np.float32))

        return img, footprint, crs


    def forward(self, row, row_meta, device='cuda'):
        """
        Forward pass of the model: Reads the image, fragments it, computes embeddings
        for each fragment, and returns a GeoDataFrame with the spatial metadata and
        embeddings.

        Args:
            row (pandas.Series): The input row containing the image data.
            row_meta (pandas.Series): Metadata associated with the row (e.g., timestamp, product_id).
            device (str): The device to run the model on ('cpu' or 'cuda'). Default is 'cuda'.

        Returns:
            geopandas.GeoDataFrame: A GeoDataFrame containing metadata and embeddings for each fragment.
        """
        # Read file
        img, footprint, crs = self._read_image(row)

        # Fragment the sample
        fragments, xys = fragment_fn(img, **self.frag_params, return_indices=True, verbose=False)

        nrows, ncols, c, h, w = fragments.shape
        # Apply the model
        with torch.no_grad():
            embeddings = self.embedder(fragments.reshape(-1,c,h,w).to(device)).view(nrows, ncols, -1)

        # Row-level metadata is identical for every fragment: extract it once
        # instead of re-reading it inside the nested loop.
        timestamp = row_meta.timestamp.item()
        product_id = row_meta.product_id.item()
        grid_cell = row_meta.grid_cell.item()
        grid_row_u = row_meta.grid_row_u.item()
        grid_col_r = row_meta.grid_col_r.item()

        # The source CRS is fixed for the whole image, so the WGS84 transformer
        # is loop-invariant and can be built once.
        transformer = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True)

        df_rows = []

        # Pack rows for geoparquet
        for r_idx in range(nrows):
            for c_idx in range(ncols):
                embedding = embeddings[r_idx, c_idx].cpu().numpy()
                # spatial features per fragment
                x_offset,y_offset=xys[r_idx,c_idx].int().tolist()
                # NOTE(review): the offsets are paired with (h, w) here; this is
                # only consistent because fragments are square (h == w) — confirm
                # before ever introducing non-square fragments.
                pixel_bbox = [x_offset, y_offset, x_offset + h,y_offset + w] # in pixels
                utm_footprint = crop_footprint(footprint, *img.shape[:2], pixel_bbox)
                # main footprint is in WGS84 (needs to be consistent across parquet)
                geometry = transform(transformer.transform, utm_footprint) # WGS84
                centre_lon, centre_lat = geometry.centroid.coords[0]

                row_dict = {
                    'unique_id' : self.calculate_checksum(geometry, timestamp, product_id, embedding),
                    'embedding' : embedding,
                    'timestamp' : timestamp,
                    'product_id' : product_id,
                    'grid_cell' : grid_cell,
                    'grid_row_u' : grid_row_u,
                    'grid_col_r' : grid_col_r,
                    'geometry' : geometry,
                    'centre_lat' : centre_lat,
                    'centre_lon' : centre_lon,
                    'utm_footprint' : utm_footprint.wkt,
                    'utm_crs' : crs.to_string(),
                    'pixel_bbox' : pixel_bbox,
                }
                df_rows.append(row_dict)

        return gpd.GeoDataFrame(df_rows).astype(self.column_types)
MajorTOM/embedder/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .MajorTOM_Embedder import *
2
+ from .grid_cell_fragment import *
MajorTOM/embedder/__pycache__/MajorTOM_Embedder.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
MajorTOM/embedder/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (244 Bytes). View file
 
MajorTOM/embedder/__pycache__/grid_cell_fragment.cpython-311.pyc ADDED
Binary file (8.37 kB). View file
 
MajorTOM/embedder/grid_cell_fragment.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+ from shapely.ops import transform
5
+ from pyproj import CRS, Transformer
6
+ import geopandas as gpd
7
+ import pandas as pd
8
+ import numpy as np
9
+ from shapely.geometry import Polygon, box
10
+ from rasterio.transform import from_bounds, xy
11
+ #from rasterio.windows import Window, from_bounds
12
+ import rasterio as rio
13
+
14
def crop_footprint(footprint, height, width, crop_bbox):
    """
    Crops the given footprint to the specified pixel bounding box.

    Args:
        footprint (shapely.geometry.Polygon): The original footprint of the image or area.
        height (int): Height of the image (in pixels).
        width (int): Width of the image (in pixels).
        crop_bbox (list): The bounding box to crop the footprint, as
            [col_start, row_start, col_end, row_end] where (col_start, row_start)
            is the top-left corner and (col_end, row_end) the bottom-right.

    Returns:
        shapely.geometry.Polygon: The cropped bounding box in the same CRS as the footprint.
    """
    # Affine transform mapping pixel (col, row) to spatial coordinates.
    # Named `affine` to avoid shadowing the module-level shapely.ops.transform import.
    affine = from_bounds(*footprint.bounds, width, height)

    # Map both corners of the pixel bbox into the footprint's CRS (e.g. UTM)
    min_x, min_y = affine * (crop_bbox[0], crop_bbox[1])  # (col_start, row_start)
    max_x, max_y = affine * (crop_bbox[2], crop_bbox[3])  # (col_end, row_end)

    # NOTE(review): with row 0 at the top of the raster, `min_y` is actually the
    # larger y value; shapely's box() still produces the intended rectangle.
    return box(min_x, min_y, max_x, max_y)
40
+
41
def fragment_unfold(image, fragment_size, overlap):
    """
    Unfold operation extracting overlapping fragments from an image.

    Args:
        image (torch.Tensor or np.ndarray): Input image. Numpy input is
            expected channels-last (height, width, channels); tensor input
            channels-first.
        fragment_size (int or list): Fragment size — one integer for square
            fragments or a pair for non-square ones.
        overlap (int or list): Overlap between adjacent fragments (int or pair).

    Returns:
        torch.Tensor: Fragments of shape (num_fragments, channels, fh, fw).
    """
    # Numpy input arrives channels-last; convert to a channels-first tensor
    if not torch.is_tensor(image):
        image = torch.from_numpy(image).permute(2, 0, 1)
    # unfold expects a leading batch dimension
    if image.dim() < 4:
        image = image.unsqueeze(0)

    b, c, h, w = image.shape

    # Normalise scalar arguments into (height, width) pairs
    if isinstance(fragment_size, int):
        fragment_size = [fragment_size] * 2
    if isinstance(overlap, int):
        overlap = [overlap] * 2

    # The stride between fragment origins is the fragment size minus the overlap
    stride = [size - ov for size, ov in zip(fragment_size, overlap)]

    unfolded = torch.nn.functional.unfold(
        image, fragment_size, dilation=1, padding=0, stride=stride
    )

    # (b, c*fh*fw, L) -> (b, c, fh, fw, L) -> (L, c, fh, fw) for the first batch item
    return unfolded.view(b, c, *fragment_size, -1).permute(0, 4, 1, 2, 3)[0]
78
+
79
def fragment_fn(img,
                fragment_size,
                target_overlap,
                border_shift=True, # determines whether the outer border is shifted to ensure full coverage
                return_indices=False,
                verbose=False
                ):
    """
    Fragment an image into smaller patches with a specified fragment size and overlap.

    This function handles different scenarios based on image size, fragment size, and overlap,
    and creates fragments from the input image accordingly. It also supports shifting the outer
    border of fragments to ensure full coverage of the image.

    Args:
        img (np.ndarray or torch.Tensor): The input image, channels-last (height, width, channels).
        fragment_size (int or list): The size of the fragments. Can be a single integer (square) or a list of two integers (non-square).
        target_overlap (float): The target overlap between adjacent fragments, as a fraction of the fragment size.
        border_shift (bool): Whether to shift the border of fragments to ensure full coverage of the image. Default is True.
        return_indices (bool): If True, the function will also return the indices (offsets) for each fragment. Default is False.
        verbose (bool): If True, the function will print additional details about the overlap. Default is False.

    Returns:
        torch.Tensor or tuple:
            - If `return_indices` is False, a tensor containing the image fragments,
              shaped (rows, cols, channels, fh, fw).
            - If `return_indices` is True, a tuple of the fragments and their (row, col) offsets.
    """

    h,w,c=img.shape

    assert h==w # SQUARE IMAGES SUPPORT ONLY

    # Fix: work on a tensor throughout — the .permute() calls below are
    # tensor-only, so numpy inputs previously crashed with AttributeError.
    if not torch.is_tensor(img):
        img = torch.from_numpy(img)

    hf, wf = fragment_size, fragment_size
    ho, wo = target_overlap*hf, target_overlap*wf

    assert h >= hf and w >= wf # reject Scenario 1

    # Scenario 2: a single fragment covers the whole image
    if h == hf or w == wf:
        # Fix: convert channels-last to channels-first before shaping the output
        # (the original reinterpreted (H, W, C) memory as (C, H, W) via .view,
        # scrambling channels for C > 1). reshape, not view, since the permuted
        # tensor is non-contiguous.
        full = img.permute(2, 0, 1).reshape(1, 1, c, h, w)
        if not return_indices:
            return full
        # Fix: the original ignored return_indices here, so callers unpacking
        # (fragments, offsets) crashed. The single fragment starts at (0, 0).
        return full, torch.zeros(1, 1, 2)

    # Scenario 3 & 4

    # determine number of segments between the centers of outermost fragments
    h_n = max(1, int(np.round((h-hf)/(hf-ho))))
    w_n = max(1, int(np.round((w-wf)/(wf-wo))))

    # adjust practical overlap (divide the distance between the centers of outermost fragments by the true number of segments)
    aho = int(np.ceil(hf-(h-hf)/(h_n)))
    awo = int(np.ceil(wf-(w-wf)/(w_n)))

    # compute fragments (might not exactly fill the outermost border)
    topleft = fragment_unfold(img.permute(2,0,1),fragment_size=(hf,wf), overlap=(aho,awo)).view(1+h_n, 1+w_n, c, hf, wf)

    full = topleft

    if border_shift:

        if h > hf+h_n*(hf-aho) or w > wf+w_n*(wf-awo):
            # Re-fragment the last row/column flush with the image border.
            # Fix: the slices are channels-last, so they must be permuted to
            # channels-first before fragment_unfold (tensor inputs are taken
            # as channels-first there), matching the top-left call above.
            bottomleft = fragment_unfold(img[-hf:,:,:].permute(2,0,1),fragment_size=(hf,wf), overlap=(aho,awo)).view(1,1+w_n,c,hf,wf)
            topright = fragment_unfold(img[:,-wf:,:].permute(2,0,1),fragment_size=(hf,wf), overlap=(aho,awo)).view(1+h_n,1,c,hf,wf)

            # Shift last row and col to the border of the original
            full[:,-1,None] = topright
            full[-1] = bottomleft

    if verbose:
        print('Target Overlap: {} pixels. Feasible Overlap: {} pixels.'.format(ho,aho))

    if not return_indices:
        return full
    else:
        # Per-fragment (row, col) pixel offsets of the fragment origin
        offset=-1*torch.ones(*full.shape[:2],2)
        for ridx in range(full.shape[0]):
            for cidx in range(full.shape[1]):
                offset[ridx,cidx,1] = cidx * (hf-aho)
                offset[ridx,cidx,0] = ridx * (wf-awo)

                if border_shift:
                    # the outermost row/col sits flush with the image border
                    offset[ridx,-1,1] = h-hf
                    offset[-1,cidx,0] = w-wf

        return full,offset
MajorTOM/embedder/models/DINOv2_S2RGB.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoImageProcessor, AutoModel
3
+
4
class DINOv2_S2RGB_Embedder(torch.nn.Module):
    """
    Embedding wrapper pairing DINOv2 with Sentinel-2 RGB data.

    Raw Sentinel-2 reflectances are scaled to true-colour values in [0, 1]
    (divide by 10,000, boost by 2.5, clip), run through the DINOv2 image
    processor and backbone ('facebook/dinov2-base'), and the resulting token
    sequence is mean-pooled into a single embedding vector per image.

    Attributes:
        processor (AutoImageProcessor): DINOv2 preprocessing pipeline.
        model (AutoModel): Pre-trained DINOv2 backbone used for feature extraction.
        bands (list): Sentinel-2 bands forming the RGB input (B04, B03, B02).
        size (tuple): Expected model input size as (height, width).
    """

    def __init__(self):
        """
        Loads the pre-trained 'facebook/dinov2-base' processor and backbone,
        and records the Sentinel-2 RGB band order and expected input size.
        """
        super().__init__()

        # Hugging Face processor + backbone
        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.model = AutoModel.from_pretrained('facebook/dinov2-base')

        # Sentinel-2 true-colour bands: red, green, blue
        self.bands = ['B04', 'B03', 'B02']

        # Input resolution taken from the processor's centre-crop settings
        crop = self.processor.crop_size
        self.size = crop['height'], crop['width']

    def normalize(self, input):
        """
        Maps raw Sentinel-2 reflectance to true-colour values in [0, 1].

        Reflectances are divided by 10,000 and boosted by a factor of 2.5,
        then clipped to the valid range.

        Args:
            input (torch.Tensor): The raw Sentinel-2 image tensor to be normalized.

        Returns:
            torch.Tensor: The normalized true-color image.
        """
        return (2.5 * (input / 1e4)).clip(0,1)

    def forward(self, input):
        """
        Generates an embedding for the input Sentinel-2 RGB image.

        The image is normalized, converted to pixel values by the DINOv2
        processor, passed through the backbone, and the token sequence is
        averaged into a fixed-size vector.

        Args:
            input (torch.Tensor): The input Sentinel-2 image tensor with shape [C, H, W], where C=3.

        Returns:
            torch.Tensor: The mean-pooled embedding, moved to the CPU.
        """
        pixels = self.processor(self.normalize(input), return_tensors="pt")['pixel_values']
        hidden = self.model(pixels.to(self.model.device)).last_hidden_state
        return hidden.mean(dim=1).cpu()
MajorTOM/embedder/models/SSL4EO_S1RTC.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchgeo.models import ResNet50_Weights
3
+ import timm
4
+ import numpy as np
5
+
6
class SSL4EO_S1RTC_Embedder(torch.nn.Module):
    """
    SSL4EO Embedder for Sentinel-1 data using a pre-trained model.

    Uses a ResNet50 backbone pre-trained with SSL4EO (Self-Supervised Learning
    for Earth Observation) on Sentinel-1 radar (SAR) data; the classification
    head is replaced with an identity so the network emits embeddings.

    Project Code:
        https://github.com/zhu-xlab/SSL4EO-S12

    Publication:
        https://arxiv.org/abs/2211.07044
    """

    def __init__(self, s1_mean=[-12.54847273, -20.19237134], s1_std=[5.25697717,5.91150917]):
        """
        Initializes the embedder: stores the normalization statistics and loads
        the pre-trained backbone.

        Args:
            s1_mean (list, optional): Per-band (VV, VH) mean of the dB-scaled
                Sentinel-1 data. Default is SSL4EO's values.
            s1_std (list, optional): Per-band (VV, VH) standard deviation of the
                dB-scaled Sentinel-1 data. Default is SSL4EO's values.

        Attributes:
            s1_mean (torch.FloatTensor): Mean values for normalization.
            s1_std (torch.FloatTensor): Standard deviation values for normalization.
            model (torch.nn.Module): The ResNet50 model initialized with pre-trained weights.
            bands (list): Sentinel-1 bands used for input data (VV, VH).
            size (tuple): The input size expected by the model (224x224 pixels).
        """
        super().__init__()

        self.s1_mean = torch.FloatTensor(s1_mean)
        self.s1_std = torch.FloatTensor(s1_std)

        # load model
        self.model = self.init_model()
        self.bands = ['vv','vh']
        self.size = 224,224

    def init_model(self):
        """
        Initializes the ResNet50 model with SSL4EO pre-trained Sentinel-1 weights
        (`ResNet50_Weights.SENTINEL1_ALL_MOCO`) and replaces the fully connected
        head with an identity so the forward pass returns the pooled embedding.

        Returns:
            torch.nn.Module: The initialized ResNet50 model.
        """
        weights = ResNet50_Weights.SENTINEL1_ALL_MOCO
        model = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
        model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
        model.fc=torch.nn.Identity()

        return model

    def normalize(self, img,scale=1.0):
        """
        Normalizes dB-scaled Sentinel-1 data to the range [0, scale].

        Values are mapped linearly from each band's [mean - 2*std, mean + 2*std]
        interval onto [0, scale] and clipped.

        Args:
            img (torch.Tensor): dB-scaled Sentinel-1 image, channels-first, with
                the channel count matching `s1_mean`/`s1_std`.
            scale (float, optional): Upper bound of the normalized range. Default is 1.0.

        Returns:
            torch.Tensor: The normalized and clipped image (float).
        """
        min_value = (self.s1_mean - 2 * self.s1_std).to(img.device)
        max_value = (self.s1_mean + 2 * self.s1_std).to(img.device)
        # broadcast the per-band statistics over the spatial dimensions
        img = (img - min_value[:,None,None]) / (max_value - min_value)[:,None,None] * scale
        img = img.clip(0,scale).float()

        return img

    def preprocess(self, input):
        """
        Converts linear-scale Sentinel-1 backscatter to decibels and normalizes it.

        Args:
            input (torch.Tensor): The input Sentinel-1 image (e.g., VV or VH polarization)
                in linear scale.

        Returns:
            torch.Tensor: The preprocessed and normalized image in dB scale.
        """
        # Convert the input from linear scale to decibel (dB) scale.
        # Fix: Tensor.log10() takes no arguments, so the original
        # `input.log10(input.clip(...))` raised a TypeError at runtime.
        dB_input = 10 * torch.log10(input.clip(min=1e-10)) # Clip to prevent log(0)

        # Normalize the dB-scaled image
        return self.normalize(dB_input)

    def forward(self, input):
        """
        Forward pass: preprocesses the input and runs it through the backbone.

        Args:
            input (torch.Tensor): Linear-scale Sentinel-1 image (e.g., shape: [C, H, W]).

        Returns:
            torch.Tensor: The output embedding from the model.
        """
        return self.model(self.preprocess(input))
MajorTOM/embedder/models/SSL4EO_S2L1C.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchgeo.models import ResNet50_Weights
3
+ import timm
4
+
5
class SSL4EO_S2L1C_Embedder(torch.nn.Module):
    """
    SSL4EO embedder for Sentinel-2 L1C data.

    Wraps a ResNet50 backbone pre-trained with the SSL4EO (Self-Supervised
    Learning for Earth Observation) scheme — DINO weights for all Sentinel-2
    bands, distributed via torchgeo — and exposes it as a feature extractor.

    Project Code:
        https://github.com/zhu-xlab/SSL4EO-S12

    Publication:
        https://arxiv.org/abs/2211.07044
    """

    def __init__(self):
        """
        Load the pre-trained backbone and record the expected input layout.

        Attributes:
            model (torch.nn.Module): ResNet50 with SSL4EO Sentinel-2 weights
                (`ResNet50_Weights.SENTINEL2_ALL_DINO` from torchgeo).
            bands (list): The 13 Sentinel-2 L1C bands the model consumes.
            size (tuple): Expected spatial input size, (224, 224) pixels.
        """
        super().__init__()

        # Pre-trained SSL4EO ResNet50 backbone
        self.model = self.init_model()

        # All 13 Sentinel-2 L1C spectral bands, in order
        self.bands = [
            'B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07',
            'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'
        ]

        # Spatial input size expected by the backbone
        self.size = 224, 224

    def init_model(self):
        """
        Build a ResNet50 via timm and load the SSL4EO Sentinel-2 DINO weights.

        The classification head (`fc`) is replaced by an identity so the
        network returns the pooled backbone features directly as embeddings.

        Returns:
            torch.nn.Module: Ready-to-use feature extractor.
        """
        pretrained = ResNet50_Weights.SENTINEL2_ALL_DINO
        backbone = timm.create_model('resnet50', in_chans=pretrained.meta['in_chans'])
        backbone.load_state_dict(pretrained.get_state_dict(progress=True), strict=False)

        # Strip the classifier; embeddings come straight from the pooled features
        backbone.fc = torch.nn.Identity()
        return backbone

    def preprocess(self, input):
        """
        Scale Sentinel-2 values into the range the model expects.

        Divides the pixel values by 10,000 (the Sentinel-2 reflectance
        scaling factor).

        Args:
            input (torch.Tensor): Raw Sentinel-2 values (e.g., shape: [C, H, W]).

        Returns:
            torch.Tensor: Input divided by 10,000.
        """
        return input / 1e4

    def forward(self, input):
        """
        Embed a Sentinel-2 L1C image.

        Args:
            input (torch.Tensor): Raw Sentinel-2 image (e.g., shape: [C, H, W]).

        Returns:
            torch.Tensor: The backbone embedding.
        """
        scaled = self.preprocess(input)
        return self.model(scaled)
MajorTOM/embedder/models/SigLIP_S2RGB.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from open_clip import create_model_from_pretrained, get_tokenizer
2
+ import torch
3
+
4
class SigLIP_S2RGB_Embedder(torch.nn.Module):
    """
    Embedding wrapper for SigLIP applied to Sentinel-2 RGB data.

    Sentinel-2 reflectances are first converted to a True-Colour image and
    then passed through the pre-trained SigLIP vision tower
    (ViT-SO400M-14-SigLIP-384, loaded via open_clip), which produces one
    embedding vector per image.

    Preprocessing:
        - Sentinel-2 bands are divided by 10,000 to scale reflectance values.
        - Values are multiplied by 2.5 and clipped to [0, 1] (True-Colour).
        - The checkpoint's final transform (tensor normalization) is applied.

    Model:
        - Takes a 384x384 RGB input and produces an embedding vector.
    """

    def __init__(self):
        super().__init__()

        # Pre-trained SigLIP checkpoint plus its companion transform pipeline
        self.model, self.preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
        # Sentinel-2 RGB bands (B04 - Red, B03 - Green, B02 - Blue)
        self.bands = ['B04', 'B03', 'B02']
        # Input resolution taken from the checkpoint's first (resize) transform
        self.size = self.preprocess.transforms[0].size

    def normalize(self, input):
        """
        Convert raw Sentinel-2 reflectances into a True-Colour image.

        Divides by 10,000 (reflectance scaling) and multiplies by 2.5
        (brightness boost), clipping the result to [0, 1].

        Args:
            input (torch.Tensor or np.ndarray): Raw Sentinel-2 RGB values.

        Returns:
            torch.Tensor: True-Colour image clipped to [0, 1].
        """
        scaled = input / 1e4
        return (2.5 * scaled).clip(0, 1)

    def forward(self, input):
        """
        Embed a Sentinel-2 RGB image with the SigLIP vision tower.

        Args:
            input (torch.Tensor): A Sentinel-2 image, typically (C, H, W)
                with C=3 (RGB) at the model's native 384x384 resolution.

        Returns:
            torch.Tensor: The image embedding.
        """
        true_colour = self.normalize(input)

        # Apply only the final transform of the pipeline (tensor
        # normalization); resizing is not re-applied here.
        model_input = self.preprocess.transforms[-1](true_colour)

        return self.model.encode_image(model_input)
MajorTOM/embedder/models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .SigLIP_S2RGB import *
2
+ from .DINOv2_S2RGB import *
3
+ from .SSL4EO_S2L1C import *
4
+ from .SSL4EO_S1RTC import *
MajorTOM/embedder/models/__pycache__/DINOv2_S2RGB.cpython-311.pyc ADDED
Binary file (5.58 kB). View file
 
MajorTOM/embedder/models/__pycache__/SSL4EO_S1RTC.cpython-311.pyc ADDED
Binary file (7.02 kB). View file
 
MajorTOM/embedder/models/__pycache__/SSL4EO_S2L1C.cpython-311.pyc ADDED
Binary file (4.75 kB). View file
 
MajorTOM/embedder/models/__pycache__/SigLIP_S2RGB.cpython-311.pyc ADDED
Binary file (3.72 kB). View file
 
MajorTOM/embedder/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (308 Bytes). View file
 
MajorTOM/extras/coverage-example.png ADDED

Git LFS Details

  • SHA256: a2ed4c9e1b6516b07b803cdced733213d3db3692665c119814fb495089231627
  • Pointer size: 132 Bytes
  • Size of remote file: 2.97 MB
MajorTOM/extras/coverage_vis.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from mpl_toolkits.basemap import Basemap
5
+ import PIL
6
+
7
def get_mask(df):
    """
    Build a binary coverage mask from a Major TOM metadata dataframe.

    Cells present in the dataframe (and whose 'nodata' fraction is not above
    0.5) are set to 255 on a fixed 2004x4008 canvas; everything else stays 0.

    Args:
        df: Major TOM metadata dataframe with 'nodata', 'grid_row_u' and
            'grid_col_r' columns.

    Returns:
        PIL.Image.Image: 8-bit grayscale mask of shape (2004, 4008).
    """
    height, width = 2004, 4008
    mask = np.zeros((height, width), dtype=np.uint8)
    row_offset = -1002
    col_offset = -2004

    # keep only cells that are not mostly empty
    valid = ~(df['nodata'].values > 0.5)

    # grid rows count from the south; image rows count from the top
    rows = height - (np.array(df['grid_row_u']) - row_offset) - 1
    cols = np.array(df['grid_col_r']) - col_offset

    mask[rows[valid], cols[valid]] = 255

    return PIL.Image.fromarray(mask)
27
+
28
def fig2img(fig):
    """Render a Matplotlib figure into an in-memory PIL Image."""
    import io
    # The buffer is intentionally left open: PIL loads lazily from it.
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    return PIL.Image.open(buffer)
36
+
37
def light_basemap():
    """
    Render the global sinusoidal basemap with bright colours.

    Returns:
        PIL.Image.Image: 48x24-inch, 167-dpi map with green land,
        light-blue water and grey country/coast outlines.
    """
    land, water, outline = "#9eba9b", "#CCDDFF", "#666666"

    with plt.ioff():
        fig, ax = plt.subplots(figsize=(48, 24), dpi=167)

        m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
        m.fillcontinents(color=land, lake_color=water)
        m.drawmapboundary(fill_color=water)
        m.drawcountries(color=outline, linewidth=1)
        m.drawcoastlines(color=outline, linewidth=1)

        # strip axes and margins so the map fills the whole canvas
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
                            hspace=0, wspace=0)
        plt.margins(0, 0)

    return fig2img(fig)
57
+
58
def dark_basemap():
    """
    Render the global sinusoidal basemap with dark colours.

    Returns:
        PIL.Image.Image: 48x24-inch, 167-dpi map with uniform dark-grey
        land/water fill and black country/coast outlines.
    """
    fill, outline = "#242424", "#000000"

    with plt.ioff():
        fig, ax = plt.subplots(figsize=(48, 24), dpi=167)

        m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
        m.fillcontinents(color=fill, lake_color=fill)
        m.drawmapboundary(fill_color=fill)
        m.drawcountries(color=outline, linewidth=1)
        m.drawcoastlines(color=outline, linewidth=1)

        # strip axes and margins so the map fills the whole canvas
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
                            hspace=0, wspace=0)
        plt.margins(0, 0)

    return fig2img(fig)
78
+
79
def get_coveragemap(input, input2=None):
    """
    Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation

    Optionally, input2 can be provided and then, the map plots a map with extra colours indicating cells available only in input (green) or only input2 (blue)

    Args:
        input: Major TOM metadata dataframe (or precomputed mask image).
        input2: Optional second dataframe/mask to overlay for comparison.

    Returns:
        PIL.Image.Image: The rendered coverage map.
    """

    if input2 is None:
        return single_coveragemap(input)
    else:
        cmap1 = single_coveragemap(input)
        cmap2 = single_coveragemap(input2)

        # arrays for mixing (drop the alpha channel)
        inp1_arr = np.array(cmap1)[...,:3]
        inp2_arr = np.array(cmap2)[...,:3]

        # NOTE: order matters below — common_arr must be computed from the
        # unmodified arrays before inp1_arr/inp2_arr have channels zeroed.
        # Pixels where both maps agree keep only the red channel.
        common_arr = inp1_arr*(inp1_arr.sum(-1) == inp2_arr.sum(-1))[:,:,None]
        common_arr[:,:,(1,2)] = 0
        inp1_arr[:,:,(0,2)] = 0 # Green - indicates presence of S2 only
        inp2_arr[:,:,(0,1)] = 0 # Blue - indicates presence of DEM only

        # additive mix of the three single-channel layers
        return PIL.Image.fromarray(((common_arr + inp1_arr + inp2_arr)).astype(np.uint8))
102
+
103
+
104
def single_coveragemap(input):
    """
    Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation

    Args:
        input: Major TOM metadata dataframe, or an already-computed PIL mask.

    Returns:
        PIL.Image.Image: Dark basemap with covered cells shown in bright colours.
    """

    # compute mask if df is provided
    if isinstance(input, pd.DataFrame):
        mask = get_mask(input)
    else:
        mask = input

    basemap = light_basemap()
    basemap_d = dark_basemap()

    # pixels whose red channel is 255 lie outside the sinusoidal projection
    outside_earth = np.array(basemap.convert('RGBA'))[:, :, 0] == 255
    outside_earth = PIL.Image.fromarray(outside_earth)

    # nearest-neighbour keeps the mask strictly binary after resizing
    mask = mask.resize(basemap.size, PIL.Image.NEAREST)

    # covered cells become opaque on the light basemap
    basemap.putalpha(mask)

    # Mask outside of earth
    basemap.paste(outside_earth, (0,0), outside_earth)

    # composite the bright (covered) layer onto the dark basemap
    basemap_d.paste(basemap, (0,0), basemap)

    return basemap_d
131
+
132
if __name__ == '__main__':
    # Example: build a coverage figure for the hosted Core-S2L2A metadata.
    DATASET_NAME = 'Major-TOM/Core-S2L2A'
    meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
    df = pd.read_parquet(meta_path)

    # This is how you make a coverage figure!
    coverage_img = get_coveragemap(df)

    coverage_img.save('coverage-example.png', format='PNG')

    # and this is how you can create an overlap figure for 2 datasets!
    DATASET_NAME = 'Major-TOM/Core-DEM'
    meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
    dem_df = pd.read_parquet(meta_path)

    coverage_img = get_coveragemap(df,dem_df)

    coverage_img.save('overlap-coverage-example.png', format='PNG')
+ coverage_img.save('overlap-coverage-example.png', format='PNG')
MajorTOM/extras/extract-sample-from-raw-S2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
MajorTOM/extras/thumbnail_dem.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
3
+
4
+ Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
5
+ """
6
+
7
+ from rasterio.io import MemoryFile
8
+ from PIL import Image
9
+ import numpy as np
10
+ import os
11
+ from pathlib import Path
12
+ import rasterio as rio
13
+ from matplotlib.colors import LightSource
14
+
15
def get_grayscale(x):
    """
    Normalized grayscale visualisation.

    Linearly rescales `x` to the full 0-255 uint8 range. A constant input
    (zero dynamic range, e.g. a perfectly flat DEM tile) is returned as an
    all-zero image instead of dividing by zero.

    Args:
        x (np.ndarray): Input array (e.g. a DEM tile).

    Returns:
        np.ndarray: uint8 array with the same shape as `x`.
    """
    # shift to zero minimum
    x_n = x - x.min()

    # guard against division by zero for flat (constant) inputs,
    # which previously produced NaNs cast to uint8
    peak = x_n.max()
    if peak == 0:
        return np.zeros(x.shape, dtype=np.uint8)

    return np.uint8(x_n / peak * 255)
25
+
26
def get_hillshade(x, azdeg=315, altdeg=45, ve=1):
    """
    Hillshade visualisation for a DEM array.

    Illuminates the surface from the given azimuth/altitude using
    matplotlib's LightSource and returns the result as a uint8 image.

    Args:
        x (np.ndarray): Elevation array.
        azdeg (float): Azimuth of the light source in degrees.
        altdeg (float): Altitude of the light source in degrees.
        ve (float): Vertical exaggeration applied before shading.

    Returns:
        np.ndarray: uint8 hillshade image.
    """
    light = LightSource(azdeg=azdeg, altdeg=altdeg)
    shaded = light.hillshade(x, vert_exag=ve)
    return np.uint8(255 * shaded)
33
+
34
def dem_thumbnail(dem, dem_NODATA = -32768.0, hillshade=True):
    """
    Create a thumbnail visualisation for a DEM array.

    Renders either a hillshade (default) or a normalized grayscale image.

    Args:
        dem (np.ndarray): Elevation array.
        dem_NODATA (float): NODATA value of the DEM (default is -32768.0).
            NOTE(review): currently accepted but not used by either
            rendering path — confirm whether nodata masking was intended.
        hillshade (bool): If True render a hillshade, otherwise grayscale.

    Returns:
        np.ndarray: uint8 thumbnail array.
    """
    render = get_hillshade if hillshade else get_grayscale
    return render(dem)
44
+
45
+
46
def dem_thumbnail_from_datarow(datarow):
    """
    Create a thumbnail PIL Image from a raw parquet datarow.

    Decodes the GeoTIFF bytes stored in the 'DEM' column (and its nodata
    value) and renders a single-channel ('L' mode) thumbnail.

    Args:
        datarow: A row read from one of the Major TOM data parquet files.

    Returns:
        PIL.Image.Image: Grayscale thumbnail.
    """
    with MemoryFile(datarow['DEM'][0].as_py()) as mem_f, mem_f.open(driver='GTiff') as f:
        dem = f.read().squeeze()
        dem_NODATA = f.nodata

    img = dem_thumbnail(dem, dem_NODATA)

    return Image.fromarray(img, 'L')
61
+
62
if __name__ == '__main__':
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    # Example: stream one row group from the hosted Core-DEM dataset
    # and render its thumbnail.
    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-DEM/resolve/main/images/part_01001.parquet"
    with open_parquet_file(url) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1)

    print('[example run] computing the thumbnail...')
    thumbnail = dem_thumbnail_from_datarow(first_row_group)

    thumbnail_fname = 'example_thumbnail.png'
    thumbnail.save(thumbnail_fname, format = 'PNG')
    print('[example run] saved as "{}"'.format(thumbnail_fname))
MajorTOM/extras/thumbnail_s1rtc.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
3
+
4
+ Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
5
+ """
6
+
7
+ from rasterio.io import MemoryFile
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
def s1rtc_thumbnail(vv, vh, vv_NODATA = -32768.0, vh_NODATA = -32768.0):
    """
    Build a false-colour RGB thumbnail from Sentinel-1 RTC vv/vh arrays.

    Each polarization is converted to dB, rescaled to 0-255 using only its
    valid (non-NODATA) pixels, and composed as R=vv, G=normalized vv+vh,
    B=vh. NODATA pixels are rendered as 0.

    Args:
        vv (np.ndarray): VV backscatter array (linear scale).
        vh (np.ndarray): VH backscatter array (linear scale).
        vv_NODATA (float): NODATA value of vv (default is -32768.0).
        vh_NODATA (float): NODATA value of vh (default is -32768.0).

    Returns:
        np.ndarray: HxWx3 uint8 thumbnail.
    """
    # BUGFIX: work on copies — the original implementation overwrote
    # negative/NODATA pixels in the caller's arrays in place.
    vv = np.asarray(vv).copy()
    vh = np.asarray(vh).copy()

    # valid data masks
    vv_mask = vv != vv_NODATA
    vh_mask = vh != vh_NODATA

    # remove invalid (negative) values before the log op
    vv[vv<0] = vv[vv>=0].min()
    vh[vh<0] = vh[vh>=0].min()

    # linear backscatter -> decibels
    vv_dB = 10*np.log10(vv)
    vh_dB = 10*np.log10(vh)

    # scale to 0-255 using min/max of valid pixels only
    vv_dB = (vv_dB - vv_dB[vv_mask].min()) / (vv_dB[vv_mask].max() - vv_dB[vv_mask].min()) * 255
    vh_dB = (vh_dB - vh_dB[vh_mask].min()) / (vh_dB[vh_mask].max() - vh_dB[vh_mask].min()) * 255

    # represent nodata as 0
    vv_dB[vv_mask==0] = 0
    vh_dB[vh_mask==0] = 0

    # false colour composite
    return np.stack([vv_dB,
                     255*(vv_dB+vh_dB)/np.max(vv_dB+vh_dB),
                     vh_dB
                    ],-1).astype(np.uint8)
43
+
44
def s1rtc_thumbnail_from_datarow(datarow):
    """
    Create a false-colour thumbnail PIL Image from a raw parquet datarow.

    Decodes the GeoTIFF bytes in the 'vv' and 'vh' columns (together with
    their nodata values) and renders them with `s1rtc_thumbnail`.

    Args:
        datarow: A row read from one of the Major TOM data parquet files.

    Returns:
        PIL.Image.Image: RGB thumbnail.
    """
    def _read_band(column):
        # decode one GeoTIFF-encoded band and its nodata value
        with MemoryFile(datarow[column][0].as_py()) as mem_f:
            with mem_f.open(driver='GTiff') as f:
                return f.read().squeeze(), f.nodata

    vv, vv_NODATA = _read_band('vv')
    vh, vh_NODATA = _read_band('vh')

    img = s1rtc_thumbnail(vv, vh, vv_NODATA=vv_NODATA, vh_NODATA=vh_NODATA)

    return Image.fromarray(img)
64
+
65
if __name__ == '__main__':
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    # Example: stream one row group from the hosted Core-S1RTC dataset
    # and render its false-colour thumbnail.
    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-S1RTC/resolve/main/images/part_00001.parquet"
    with open_parquet_file(url) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1)

    print('[example run] computing the thumbnail...')
    thumbnail = s1rtc_thumbnail_from_datarow(first_row_group)

    thumbnail_fname = 'example_thumbnail.png'
    thumbnail.save(thumbnail_fname, format = 'PNG')
    print('[example run] saved as "{}"'.format(thumbnail_fname))
+ print('[example run] saved as "{}"'.format(thumbnail_fname))
MajorTOM/extras/thumbnail_s2.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
3
+
4
+ Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
5
+ """
6
+
7
+ from rasterio.io import MemoryFile
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
def s2l2a_thumbnail(B04, B03, B02, gain=1.3, gamma=0.6):
    """
    Create a true-colour thumbnail from Sentinel-2 L2A band arrays.

    Reflectances are scaled by 1/10,000, gamma-corrected, gain-boosted,
    clipped to [0, 1] and quantized to uint8.

    Args:
        B04 (np.ndarray): Red band array.
        B03 (np.ndarray): Green band array.
        B02 (np.ndarray): Blue band array.
        gain (float): Brightness multiplier.
        gamma (float): Gamma exponent applied to scaled reflectance.

    Returns:
        np.ndarray: HxWx3 uint8 thumbnail.
    """
    # stack bands into an RGB cube
    rgb = np.stack([B04, B03, B02], -1)

    # apply gain & gamma
    rgb = gain * ((rgb / 10_000) ** gamma)

    return (rgb.clip(0, 1) * 255).astype(np.uint8)
25
+
26
def s2l2a_thumbnail_from_datarow(datarow):
    """
    Create a true-colour thumbnail PIL Image from a raw parquet datarow.

    Decodes the GeoTIFF bytes stored in the 'B04' (red), 'B03' (green) and
    'B02' (blue) columns and renders them with `s2l2a_thumbnail`.

    Args:
        datarow: A row read from one of the Major TOM data parquet files.

    Returns:
        PIL.Image.Image: RGB thumbnail.
    """
    def _read_band(column):
        # decode one GeoTIFF-encoded band
        with MemoryFile(datarow[column][0].as_py()) as mem_f:
            with mem_f.open(driver='GTiff') as f:
                return f.read().squeeze()

    B04 = _read_band('B04')  # red
    B03 = _read_band('B03')  # green
    B02 = _read_band('B02')  # blue

    return Image.fromarray(s2l2a_thumbnail(B04, B03, B02))
54
+
55
if __name__ == '__main__':
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    # Example: stream only the RGB band columns of one row group from the
    # hosted Core-S2L2A dataset and render its thumbnail.
    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-S2L2A/resolve/main/images/part_01000.parquet"
    with open_parquet_file(url, columns = ["B04", "B03", "B02"]) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1, columns = ["B04", "B03", "B02"])

    print('[example run] computing the thumbnail...')
    thumbnail = s2l2a_thumbnail_from_datarow(first_row_group)

    thumbnail.save('example_thumbnail.png', format = 'PNG')
MajorTOM/grid.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+ import pandas as pd
4
+ import geopandas as gpd
5
+ from shapely.geometry import LineString, Polygon
6
+ from tqdm import tqdm
7
+ import re
8
+
9
+
10
+
11
class Grid():
    """
    Major TOM global sampling grid.

    Latitude rows are spaced ~`dist` km apart from pole to pole; each row is
    subdivided along its circle of latitude into points ~`dist` km apart.
    Rows are labelled '0U','1U',... going north from the equator and
    '1D','2D',... going south; columns '0R','1R',... going east from the
    Greenwich meridian and '1L','2L',... going west. A grid point marks the
    bottom-left corner of its cell.
    """

    # Earth's equatorial radius in km
    RADIUS_EQUATOR = 6378.137 # km

    def __init__(self,dist,latitude_range=(-85,85),longitude_range=(-180,180),utm_definition='bottomleft'):
        """
        Args:
            dist (float): Approximate grid spacing in km.
            latitude_range (tuple): Inclusive latitude bounds of the grid.
            longitude_range (tuple): Inclusive longitude bounds of the grid.
            utm_definition (str): Where a cell's UTM zone is evaluated:
                'bottomleft' (the grid point itself) or 'center'.
        """
        self.dist = dist
        self.latitude_range = latitude_range
        self.longitude_range = longitude_range
        self.utm_definition = utm_definition
        self.rows,self.lats = self.get_rows()
        self.points, self.points_by_row = self.get_points()

    def get_rows(self):
        """
        Compute the row labels and latitudes of the grid.

        Returns:
            tuple: (rows, latitudes) as numpy arrays, sorted south to north
            and bounded to `latitude_range`.
        """

        # Define set of latitudes to use, based on the grid distance
        arc_pole_to_pole = math.pi * self.RADIUS_EQUATOR
        num_divisions_in_hemisphere = math.ceil(arc_pole_to_pole / self.dist)

        latitudes = np.linspace(-90, 90, num_divisions_in_hemisphere+1)[:-1]
        latitudes = np.mod(latitudes, 180) - 90

        # order should be from south to north
        latitudes = np.sort(latitudes)

        # index of the first non-negative latitude (the equator row)
        zeroth_row = np.searchsorted(latitudes,0)

        # From 0U-NU and 1D-ND
        rows = [None] * len(latitudes)
        rows[zeroth_row:] = [f'{i}U' for i in range(len(latitudes)-zeroth_row)]
        rows[:zeroth_row] = [f'{abs(i-zeroth_row)}D' for i in range(zeroth_row)]

        # bound to range
        idxs = (latitudes>=self.latitude_range[0]) * (latitudes<=self.latitude_range[1])
        rows,latitudes = np.array(rows), np.array(latitudes)
        rows,latitudes = rows[idxs],latitudes[idxs]

        return rows,latitudes

    def get_circumference_at_latitude(self,lat):
        """
        Circumference (km) of the Earth's cross-section at latitude `lat`.
        """

        # Circumference of the cross-section of a sphere at a given latitude

        radius_at_lat = self.RADIUS_EQUATOR * math.cos(lat * math.pi / 180)
        circumference = 2 * math.pi * radius_at_lat

        return circumference

    def subdivide_circumference(self,lat,return_cols=False):
        """
        Longitudes subdividing the circle of latitude `lat` into ~`dist`-km
        arcs; optionally also returns the corresponding column labels.
        """
        # Provide a list of longitudes that subdivide the circumference of the earth at a given latitude
        # into equal parts as close as possible to dist

        circumference = self.get_circumference_at_latitude(lat)
        num_divisions = math.ceil(circumference / self.dist)
        longitudes = np.linspace(-180,180, num_divisions+1)[:-1]
        longitudes = np.mod(longitudes, 360) - 180
        longitudes = np.sort(longitudes)


        if return_cols:
            # label columns 0R..NR east of longitude 0 and 1L..NL west of it
            cols = [None] * len(longitudes)
            zeroth_idx = np.where(longitudes==0)[0][0]
            cols[zeroth_idx:] = [f'{i}R' for i in range(len(longitudes)-zeroth_idx)]
            cols[:zeroth_idx] = [f'{abs(i-zeroth_idx)}L' for i in range(zeroth_idx)]
            return np.array(cols),np.array(longitudes)

        return np.array(longitudes)

    def get_points(self):
        """
        Materialize every grid point as a GeoDataFrame row.

        Returns:
            tuple: (points, points_by_row) where `points` is one combined
            GeoDataFrame and `points_by_row` is a list with one
            GeoDataFrame per latitude row.
        """

        r_idx = 0
        points_by_row = [None]*len(self.rows)
        for r,lat in zip(self.rows,self.lats):
            point_names,grid_row_names,grid_col_names,grid_row_idx,grid_col_idx,grid_lats,grid_lons,utm_zones,epsgs = [],[],[],[],[],[],[],[],[]
            cols,lons = self.subdivide_circumference(lat,return_cols=True)

            cols,lons = self.filter_longitude(cols,lons)
            c_idx = 0
            for c,lon in zip(cols,lons):
                point_names.append(f'{r}_{c}')
                grid_row_names.append(r)
                grid_col_names.append(c)
                grid_row_idx.append(r_idx)
                grid_col_idx.append(c_idx)
                grid_lats.append(lat)
                grid_lons.append(lon)
                if self.utm_definition == 'bottomleft':
                    utm_zones.append(get_utm_zone_from_latlng([lat,lon]))
                elif self.utm_definition == 'center':
                    # shift by half a cell (dist is km; ~111,120 m per degree)
                    center_lat = lat + (1000*self.dist/2)/111_120
                    center_lon = lon + (1000*self.dist/2)/(111_120*math.cos(center_lat*math.pi/180))
                    utm_zones.append(get_utm_zone_from_latlng([center_lat,center_lon]))
                else:
                    raise ValueError(f'Invalid utm_definition {self.utm_definition}')
                epsgs.append(f'EPSG:{utm_zones[-1]}')

                c_idx += 1
            points_by_row[r_idx] = gpd.GeoDataFrame({
                'name':point_names,
                'row':grid_row_names,
                'col':grid_col_names,
                'row_idx':grid_row_idx,
                'col_idx':grid_col_idx,
                'utm_zone':utm_zones,
                'epsg':epsgs
            },geometry=gpd.points_from_xy(grid_lons,grid_lats))
            r_idx += 1
        points = gpd.GeoDataFrame(pd.concat(points_by_row))
        # points.reset_index(inplace=True,drop=True)
        return points, points_by_row

    def group_points_by_row(self):
        """
        Split `self.points` into one GeoDataFrame per grid row.
        """
        # Make list of different gdfs for each row
        points_by_row = [None]*len(self.rows)
        for i,row in enumerate(self.rows):
            points_by_row[i] = self.points[self.points.row==row]
        return points_by_row

    def filter_longitude(self,cols,lons):
        """
        Keep only columns/longitudes inside `self.longitude_range`.
        """
        idxs = (lons>=self.longitude_range[0]) * (lons<=self.longitude_range[1])
        cols,lons = cols[idxs],lons[idxs]
        return cols,lons

    def latlon2rowcol(self,lats,lons,return_idx=False,integer=False):
        """
        Convert latitude and longitude to row and column number from the grid

        Args:
            lats, lons: Sequences of coordinates to look up.
            return_idx (bool): Also return the `self.points` table index
                for each (row, col) pair.
            integer (bool): Return signed integer row/col numbers
                (U/R positive, D/L negative) instead of string labels.

        Returns:
            list: [rows, cols] plus optionally the index list.
        """
        # Always take bottom left corner of grid cell
        rows = np.searchsorted(self.lats,lats)-1

        # Get the possible points of the grid cells at the given latitude
        possible_points = [self.points_by_row[row] for row in rows]

        # For each point, find the rightmost point that is still to the left of the given longitude
        cols = [poss_points.iloc[np.searchsorted(poss_points.geometry.x,lon)-1].col for poss_points,lon in zip(possible_points,lons)]
        rows = self.rows[rows].tolist()

        outputs = [rows, cols]
        if return_idx:
            # Get the table index for self.points with each row,col pair in rows, cols
            idx = [self.points[(self.points.row==row) & (self.points.col==col)].index.values[0] for row,col in zip(rows,cols)]
            outputs.append(idx)

        # return raw numbers
        if integer:
            outputs[0] = [int(el[:-1]) if el[-1] == 'U' else -int(el[:-1]) for el in outputs[0]]
            outputs[1] = [int(el[:-1]) if el[-1] == 'R' else -int(el[:-1]) for el in outputs[1]]

        return outputs

    def rowcol2latlon(self,rows,cols):
        """
        Convert row/col labels back to the (lat, lon) of each grid point.
        """
        point_geoms = [self.points.loc[(self.points.row==row) & (self.points.col==col),'geometry'].values[0] for row,col in zip(rows,cols)]
        lats = [point.y for point in point_geoms]
        lons = [point.x for point in point_geoms]
        return lats,lons

    def get_bounded_footprint(self,point,buffer_ratio=0):
        """
        Polygon footprint of the grid cell for `point`, bounded by the
        neighbouring grid points and optionally buffered.

        Args:
            point: A row of `self.points` (grid point = bottom-left corner).
            buffer_ratio (float): Fraction of the cell's width/height to
                expand the bounding box by on each side.

        Returns:
            shapely.geometry.Polygon: The (possibly buffered) cell footprint.
        """
        # Gets the polygon footprint of the grid cell for a given point, bounded by the other grid points' cells.
        # Grid point defined as bottom-left corner of polygon. Buffer ratio is the ratio of the grid cell's width/height to buffer by.

        bottom,left = point.geometry.y,point.geometry.x
        row_idx = point.row_idx
        col_idx = point.col_idx
        next_row_idx = row_idx+1
        next_col_idx = col_idx+1

        if next_row_idx >= len(self.lats): # If at top row, use difference between top and second-to-top row for height
            height = (self.lats[row_idx] - self.lats[row_idx-1])
            top = self.lats[row_idx] + height
        else:
            top = self.lats[next_row_idx]

        max_col = len(self.points_by_row[row_idx].col_idx)-1
        if next_col_idx > max_col: # If at rightmost column, use difference between rightmost and second-to-rightmost column for width
            width = (self.points_by_row[row_idx].iloc[col_idx].geometry.x - self.points_by_row[row_idx].iloc[col_idx-1].geometry.x)
            right = self.points_by_row[row_idx].iloc[col_idx].geometry.x + width
        else:
            right = self.points_by_row[row_idx].iloc[next_col_idx].geometry.x

        # Buffer the polygon by the ratio of the grid cell's width/height
        width = right - left
        height = top - bottom

        buffer_horizontal = width * buffer_ratio
        buffer_vertical = height * buffer_ratio

        new_left = left - buffer_horizontal
        new_right = right + buffer_horizontal

        new_bottom = bottom - buffer_vertical
        new_top = top + buffer_vertical

        bbox = Polygon([(new_left,new_bottom),(new_left,new_top),(new_right,new_top),(new_right,new_bottom)])

        return bbox
+
206
def get_utm_zone_from_latlng(latlng):
    """
    Get the UTM zone from a latlng list and return the corresponding EPSG code.

    Parameters
    ----------
    latlng : List[Union[int, float]]
        The latlng list to get the UTM zone from, as [latitude, longitude].

    Returns
    -------
    str
        The EPSG code for the UTM zone ("326xx" north, "327xx" south).
    """
    assert isinstance(latlng, (list, tuple)), "latlng must be in the form of a list or tuple."

    latitude = latlng[0]
    longitude = latlng[1]

    # standard 6-degree UTM band, wrapping at the antimeridian
    zone_number = (math.floor((longitude + 180) / 6)) % 60 + 1

    # Special zones for Svalbard and Norway
    if 56.0 <= latitude < 64.0 and 3.0 <= longitude < 12.0:
        zone_number = 32
    elif 72.0 <= latitude < 84.0:
        if 0.0 <= longitude < 9.0:
            zone_number = 31
        elif 9.0 <= longitude < 21.0:
            zone_number = 33
        elif 21.0 <= longitude < 33.0:
            zone_number = 35
        elif 33.0 <= longitude < 42.0:
            zone_number = 37

    # Determine the hemisphere and construct the EPSG code
    prefix = "327" if latitude < 0 else "326"
    epsg_code = f"{prefix}{zone_number:02d}"

    # sanity-check the resulting code is a valid UTM EPSG (zones 01-60)
    if not re.match(r"32[6-7](0[1-9]|[1-5][0-9]|60)",epsg_code):
        print(f"latlng: {latlng}, epsg_code: {epsg_code}")
        raise ValueError(f"out of bound latlng resulted in incorrect EPSG code for the point")

    return epsg_code
250
+
251
+
252
if __name__ == '__main__':

    # sanity checks for the UTM zone helper (incl. Norway/Svalbard cases)
    assert get_utm_zone_from_latlng([-1,-174.34]) == "32701"
    assert get_utm_zone_from_latlng([48,-4]) == "32630"
    assert get_utm_zone_from_latlng([78,13]) == "32633"
    assert get_utm_zone_from_latlng([-34,19.7]) == "32734"
    assert get_utm_zone_from_latlng([-36,175.7]) == "32760"


    # build a 100 km grid and round-trip random points through it
    dist = 100
    grid = Grid(dist)

    np.random.seed(0)
    test_lons = np.random.uniform(-20,20,size=(1000)) % 180 # Checks edge-case of crossing 180th meridian
    test_lats = np.random.uniform(-20,68,size=(1000))

    test_rows,test_cols = grid.latlon2rowcol(test_lats,test_lons)
    test_lats2,test_lons2 = grid.rowcol2latlon(test_rows,test_cols)

    print(test_lons[:10])
    print(test_lats[:10])
    print(test_rows[:10])
    print(test_cols[:10])

    # Make line segments from the points to their corresponding grid points
    lines = []
    for i in range(len(test_lats)):
        lines.append([(test_lons[i],test_lats[i]),(test_lons2[i],test_lats2[i])])

    lines = gpd.GeoDataFrame(geometry=gpd.GeoSeries([LineString(line) for line in lines]))

    # export the sample->gridpoint segments and the grid for inspection
    lines.to_file(f'testlines_{dist}km.geojson',driver='GeoJSON')
    grid.points.to_file(f'testgrid_{dist}km.geojson',driver='GeoJSON')
+ grid.points.to_file(f'testgrid_{dist}km.geojson',driver='GeoJSON')
MajorTOM/metadata_helpers.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pyarrow.parquet as pq
2
+ import pandas as pd
3
+ import geopandas as gpd
4
+ from pathlib import Path
5
+ import urllib.request
6
+ import fsspec
7
+ from fsspec.parquet import open_parquet_file
8
+ from io import BytesIO
9
+ from PIL import Image
10
+ from rasterio.io import MemoryFile
11
+ from tqdm.notebook import tqdm
12
+ import os
13
+
14
+ from .sample_helpers import *
15
+
16
def metadata_from_url(access_url, local_url):
    """Download a Major-TOM metadata parquet and load it as a GeoDataFrame.

    Args:
        access_url (str): remote URL of the metadata parquet file
        local_url (str or Path): local path the file is downloaded to

    Returns:
        geopandas.GeoDataFrame: metadata with a point geometry at each sample centre
    """
    local_url, _ = urllib.request.urlretrieve(access_url, local_url)
    frame = pq.read_table(local_url).to_pandas()
    frame['timestamp'] = pd.to_datetime(frame.timestamp)
    centres = gpd.points_from_xy(frame.centre_lon, frame.centre_lat)
    return gpd.GeoDataFrame(frame, geometry=centres, crs=frame.crs.iloc[0])
24
+
25
def filter_metadata(df,
                    region=None,
                    daterange=None,
                    cloud_cover=(0,100),
                    nodata=(0, 1.0)
                    ):
    """Filters the Major-TOM dataframe based on several parameters

    Args:
        df (geopandas dataframe): Parent dataframe
        region (shapely geometry object) : Region of interest
        daterange (tuple) : Inclusive range of dates (example format: '2020-01-01')
        cloud_cover (tuple) : Inclusive percentage range (0-100) of cloud cover
        nodata (tuple) : Inclusive fraction (0.0-1.0) of no data allowed in a sample

    Returns:
        df: a filtered dataframe
    """
    # temporal filtering
    if daterange is not None:
        assert isinstance(daterange, (list, tuple)) and len(daterange) == 2
        df = df[df.timestamp >= daterange[0]]
        df = df[df.timestamp <= daterange[1]]

    # spatial filtering (bounding-box query via the spatial index)
    if region is not None:
        idxs = df.sindex.query(region)
        df = df.take(idxs)

    # cloud filtering
    if cloud_cover is not None:
        df = df[df.cloud_cover >= cloud_cover[0]]
        df = df[df.cloud_cover <= cloud_cover[1]]

    # nodata filtering (comment previously mislabelled this as "spatial filtering")
    if nodata is not None:
        df = df[df.nodata >= nodata[0]]
        df = df[df.nodata <= nodata[1]]

    return df
64
+
65
def read_row(row, columns=["thumbnail"]):
    """Reads a row from a Major-TOM dataframe

    Args:
        row (row from geopandas dataframe): The row of metadata
        columns (list): columns to be read from the file

    Returns:
        data (dict): dictionary with returned data from requested columns
        (a PIL.Image is returned directly when only the thumbnail is requested)
    """
    # Fetch only the relevant row group over HTTP rather than the whole parquet.
    with open_parquet_file(row.parquet_url, columns=columns, footer_sample_size=2000000) as f:
        with pq.ParquetFile(f) as pf:
            row_group = pf.read_row_group(row.parquet_row, columns=columns)

    if columns == ["thumbnail"]:
        stream = BytesIO(row_group['thumbnail'][0].as_py())
        return Image.open(stream)
    else:
        row_output = {}
        for col in columns:
            # renamed from 'bytes', which shadowed the builtin type
            payload = row_group[col][0].as_py()

            if col != 'thumbnail':
                # band columns are GeoTIFF-encoded arrays
                row_output[col] = read_tif_bytes(payload)
            else:
                stream = BytesIO(payload)
                row_output[col] = Image.open(stream)

        return row_output
94
+
95
def filter_download(df, local_dir, source_name, by_row = False, verbose = False, tif_columns=None):
    """Downloads and unpacks the data of Major-TOM based on a metadata dataframe

    Args:
        df (geopandas dataframe): Metadata dataframe
        local_dir (str or Path) : Path to the where the data is to be stored locally
        source_name (str) : Name alias of the resulting dataset
        by_row (bool): If True, it will access individual rows of parquet via http - otherwise entire parquets are downloaded temporarily
        verbose (bool) : option for potential internal state printing
        tif_columns (list of str) : Optionally specified columns to be downloaded as .tifs, e.g. ['B04', 'B03', 'B02']

    Returns:
        None
    """
    if isinstance(local_dir, str):
        local_dir = Path(local_dir)

    temp_file = local_dir / 'temp.parquet'

    # identify all parquets that need to be downloaded (group them)
    urls = df.parquet_url.unique()
    if verbose:
        print('Starting download of {} parquet files.'.format(len(urls)))

    for url in tqdm(urls, desc='Downloading and unpacking...', disable=not verbose):
        # identify all relevant rows
        rows = df[df.parquet_url == url].parquet_row.unique()

        remote_file = None
        if not by_row: # (downloads entire parquet)
            # download a temporary file
            temp_path, http_resp = urllib.request.urlretrieve(url, temp_file)
        else:
            # stream the parquet over HTTP without a full download
            remote_file = fsspec.open(url)
            temp_path = remote_file.open()

        # populate the bands
        with pq.ParquetFile(temp_path) as pf:
            for row_idx in rows:
                table = pf.read_row_group(row_idx)

                product_id = table['product_id'][0].as_py()
                grid_cell = table['grid_cell'][0].as_py()
                row = grid_cell.split('_')[0]

                dest = local_dir / Path("{}/{}/{}/{}".format(source_name, row, grid_cell, product_id))
                dest.mkdir(exist_ok=True, parents=True)

                # default: every band column (names starting with 'B') plus the cloud mask
                columns = [col for col in table.column_names if col[0] == 'B'] + ['cloud_mask'] if tif_columns is None else tif_columns
                # tifs
                for col in columns:
                    with open(dest / "{}.tif".format(col), "wb") as out_f:
                        # Write bytes to file
                        out_f.write(table[col][0].as_py())

                # thumbnail (png)
                col = 'thumbnail'
                with open(dest / "{}.png".format(col), "wb") as out_f:
                    # Write bytes to file
                    out_f.write(table[col][0].as_py())
        if not by_row:
            # remove downloaded file
            os.remove(temp_path)
        else:
            # BUGFIX: the original rebound 'f' inside the write loops above, so
            # this close() targeted an already-closed local file and leaked the
            # fsspec handle; distinct variable names avoid the shadowing.
            remote_file.close()
MajorTOM/sample_helpers.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rasterio.io import MemoryFile
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from PIL import Image
5
+ from io import BytesIO
6
+
7
def plot(sample, bands = ['B04', 'B03', 'B02'], scaling=2e3):
    """Display an RGB composite of the given bands from a Major-TOM sample.

    Args:
        sample (dict): mapping of band name -> GeoTIFF bytes
        bands (list of str): band names stacked as image channels (default RGB order)
        scaling (float): divisor used to normalise the pixel values for display

    Returns:
        None (draws onto the current matplotlib axes)
    """
    img = []
    for b in bands:
        img.append(read_tif_bytes(sample[b]))
    # BUGFIX: the divisor was hard-coded to 2e3, silently ignoring 'scaling'
    plt.imshow(np.stack(img, -1)/scaling)
12
+
13
def read_tif_bytes(tif_bytes):
    """Decode in-memory GeoTIFF bytes into a numpy array (singleton dims squeezed)."""
    with MemoryFile(tif_bytes) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            return f.read().squeeze()
17
+
18
+ def read_png_bytes(png_bytes):
19
+ stream = BytesIO(png_bytes)
20
+ return Image.open(stream)
app.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import time
4
+ import os
5
+ import tempfile
6
+ import zipfile
7
+ import numpy as np
8
+ import pandas as pd
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+
11
+ # Import custom modules
12
+ from models.siglip_model import SigLIPModel
13
+ from models.satclip_model import SatCLIPModel
14
+ from models.farslip_model import FarSLIPModel
15
+ from models.dinov2_model import DINOv2Model
16
+ from models.load_config import load_and_process_config
17
+ from visualize import format_results_for_gallery, plot_top5_overview, plot_location_distribution, plot_global_map_static, plot_geographic_distribution
18
+ from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image
19
+ from PIL import Image as PILImage
20
+ from PIL import ImageDraw, ImageFont
21
+
22
+ # Configuration
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ print(f"Running on device: {device}")
25
+
26
+ # Load and process configuration
27
+ config = load_and_process_config()
28
+ print(config)
29
+
30
+ # Initialize Models
31
+ print("Initializing models...")
32
+ models = {}
33
+
34
+ # DINOv2
35
+ try:
36
+ if config and 'dinov2' in config:
37
+ models['DINOv2'] = DINOv2Model(
38
+ ckpt_path=config['dinov2'].get('ckpt_path'),
39
+ embedding_path=config['dinov2'].get('embedding_path'),
40
+ device=device
41
+ )
42
+ else:
43
+ models['DINOv2'] = DINOv2Model(device=device)
44
+ except Exception as e:
45
+ print(f"Failed to load DINOv2: {e}")
46
+
47
+ # SigLIP
48
+ try:
49
+ if config and 'siglip' in config:
50
+ models['SigLIP'] = SigLIPModel(
51
+ ckpt_path=config['siglip'].get('ckpt_path'),
52
+ tokenizer_path=config['siglip'].get('tokenizer_path'),
53
+ embedding_path=config['siglip'].get('embedding_path'),
54
+ device=device
55
+ )
56
+ else:
57
+ models['SigLIP'] = SigLIPModel(device=device)
58
+ except Exception as e:
59
+ print(f"Failed to load SigLIP: {e}")
60
+
61
+ # SatCLIP
62
+ try:
63
+ if config and 'satclip' in config:
64
+ models['SatCLIP'] = SatCLIPModel(
65
+ ckpt_path=config['satclip'].get('ckpt_path'),
66
+ embedding_path=config['satclip'].get('embedding_path'),
67
+ device=device
68
+ )
69
+ else:
70
+ models['SatCLIP'] = SatCLIPModel(device=device)
71
+ except Exception as e:
72
+ print(f"Failed to load SatCLIP: {e}")
73
+
74
+ # FarSLIP
75
+ try:
76
+ if config and 'farslip' in config:
77
+ models['FarSLIP'] = FarSLIPModel(
78
+ ckpt_path=config['farslip'].get('ckpt_path'),
79
+ model_name=config['farslip'].get('model_name'),
80
+ embedding_path=config['farslip'].get('embedding_path'),
81
+ device=device
82
+ )
83
+ else:
84
+ models['FarSLIP'] = FarSLIPModel(device=device)
85
+ except Exception as e:
86
+ print(f"Failed to load FarSLIP: {e}")
87
+
88
def get_active_model(model_name):
    """Look up a loaded model by name; returns (model, None) or (None, error message)."""
    try:
        return models[model_name], None
    except KeyError:
        return None, f"Model {model_name} not loaded."
92
+
93
def combine_images(img1, img2):
    """Stack two PIL images vertically, widening the narrower one to match.

    Either argument may be None, in which case the other is returned unchanged.
    """
    if img1 is None:
        return img2
    if img2 is None:
        return img1

    width = max(img1.size[0], img2.size[0])
    # scale each image's height so both share the common width
    top_h = int(img1.size[1] * width / img1.size[0])
    bottom_h = int(img2.size[1] * width / img2.size[0])
    top = img1.resize((width, top_h))
    bottom = img2.resize((width, bottom_h))

    canvas = PILImage.new('RGB', (width, top_h + bottom_h), (255, 255, 255))
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, top_h))
    return canvas
112
+
113
def create_text_image(text, size=(384, 384)):
    """Render a comma-separated query string onto a light-grey placeholder image.

    Args:
        text (str): query text; each comma-separated part is drawn on its own line
        size (tuple): output image size in pixels

    Returns:
        PIL.Image: the rendered placeholder
    """
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    d = ImageDraw.Draw(img)

    # Try to load a font, fallback to default
    try:
        # Try to find a font that supports larger size
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except OSError:
        # BUGFIX: was a bare 'except:', which also swallowed KeyboardInterrupt
        # and SystemExit; a missing/unreadable font file raises OSError.
        font = ImageFont.load_default()

    # Wrap text simply: one line per comma-separated fragment
    margin = 20
    offset = 100
    for line in text.split(','):
        d.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50

    d.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img
133
+
134
def fetch_top_k_images(top_indices, probs, df_embed, query_text=None):
    """
    Fetches top-k images using actual dataset download (ModelScope) via download_and_process_image.

    Args:
        top_indices (sequence of int): positional indices into df_embed, best first
        probs (array-like): similarity score per dataset row
        df_embed (DataFrame): embedding metadata (product_id, centre_lat, centre_lon)
        query_text (str, optional): query label passed to the Esri fallback renderer

    Returns:
        list of dict: one entry per successfully fetched image with keys
        image_384, image_full, score, lat, lon, id — sorted by score descending.
    """
    results = []

    # We can run this in parallel
    with ThreadPoolExecutor(max_workers=5) as executor:
        future_to_idx = {}
        for i, idx in enumerate(top_indices):
            row = df_embed.iloc[idx]
            pid = row['product_id']

            # Use download_and_process_image to get real data
            future = executor.submit(download_and_process_image, pid, df_source=df_embed, verbose=False)
            future_to_idx[future] = idx

        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                img_384, img_full = future.result()

                if img_384 is None:
                    # Fallback to Esri if download fails
                    print(f"Download failed for idx {idx}, falling back to Esri...")
                    row = df_embed.iloc[idx]
                    img_384 = get_esri_satellite_image(row['centre_lat'], row['centre_lon'], score=probs[idx], rank=0, query=query_text)
                    img_full = img_384

                row = df_embed.iloc[idx]
                results.append({
                    'image_384': img_384,
                    'image_full': img_full,
                    'score': probs[idx],
                    'lat': row['centre_lat'],
                    'lon': row['centre_lon'],
                    'id': row['product_id']
                })
            except Exception as e:
                # a failed future only drops that single image from the results
                print(f"Error fetching image for idx {idx}: {e}")

    # Sort results by score descending (since futures complete in random order)
    results.sort(key=lambda x: x['score'], reverse=True)
    return results
178
+
179
def get_all_results_metadata(model, filtered_indices, probs):
    """Return metadata dicts (id, lat, lon, score) for every filtered hit,
    sorted by similarity score descending."""
    if len(filtered_indices) == 0:
        return []

    # Rank the filtered positions by their score, best first
    order = np.argsort(probs[filtered_indices])[::-1]
    ranked = filtered_indices[order]

    records = model.df_embed.iloc[ranked]
    out = []
    for pos, (_, row) in zip(ranked, records.iterrows()):
        out.append({
            'id': row['product_id'],
            'lat': row['centre_lat'],
            'lon': row['centre_lon'],
            'score': probs[pos],
        })
    return out
197
+
198
def search_text(query, threshold, model_name):
    """Generator backing the text-search tab.

    Yields progressive UI updates as a 7-tuple of
    (map_visibility, gallery, status_text, results_figure, download_payload,
    filtered_df, geo_distribution_map); intermediate yields report progress
    and the final yield carries the full results.

    Args:
        query (str): free-text search query
        threshold (float): top fraction to keep, expressed in per-mille
            (divided by 1000 before being passed to model.search)
        model_name (str): key into the global `models` dict
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if not query:
        yield None, None, "Please enter a query.", None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Text
        yield None, None, "Encoding text...", None, None, None, None
        t0 = time.time()
        text_features = model.encode_text(query)
        timings['Encoding'] = time.time() - t0

        if text_features is None:
            yield None, None, "Model does not support text encoding or is not initialized.", None, None, None, None
            return

        # 2. Search (threshold is per-mille, hence /1000 to get a fraction)
        yield None, None, "Encoding text... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to "{query}" ({model_name})')

        # 3. Download Images (only the top 10 hits are fetched)
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:10]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text=query)
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        # (threshold/100 combined with the *100 inside generate_status_msg
        # displays the raw threshold value in per-mille, matching the /1000 above)
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
262
+
263
def search_image(image_input, threshold, model_name):
    """Generator backing the image-search tab.

    Mirrors search_text but encodes an uploaded image instead of text and
    fetches only the top 6 hits; yields the same 7-tuple of progressive UI
    updates.

    Args:
        image_input (PIL.Image): the query image
        threshold (float): top fraction to keep, expressed in per-mille
            (divided by 1000 before being passed to model.search)
        model_name (str): key into the global `models` dict
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if image_input is None:
        yield None, None, "Please upload an image.", None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Image
        yield None, None, "Encoding image...", None, None, None, None
        t0 = time.time()
        image_features = model.encode_image(image_input)
        timings['Encoding'] = time.time() - t0

        if image_features is None:
            yield None, None, "Model does not support image encoding.", None, None, None, None
            return

        # 2. Search (threshold is per-mille, hence /1000)
        yield None, None, "Encoding image... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Input Image ({model_name})')

        # 3. Download Images (top 6 hits only)
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:6]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text="Image Query")
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        # only the top 50 entries go into the downloadable text report here
        results_txt = format_results_to_text(all_results[:50])

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
323
+
324
def search_location(lat, lon, threshold):
    """Generator backing the location-search tab (always uses SatCLIP).

    Encodes (lat, lon) with SatCLIP, retrieves similar tiles, and also
    fetches the Major-TOM tile nearest the query point for display.
    Yields the same 7-tuple of progressive UI updates as the other searches.

    NOTE(review): model.search here uses top_percent=threshold/100.0 while
    plot_geographic_distribution below uses threshold/1000.0; the text and
    image searches use /1000 for both — confirm which scale is intended.
    """
    model_name = "SatCLIP"
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Location
        yield None, None, "Encoding location...", None, None, None, None
        t0 = time.time()
        loc_features = model.encode_location(float(lat), float(lon))
        timings['Encoding'] = time.time() - t0

        if loc_features is None:
            yield None, None, "Location encoding failed.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding location... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=threshold/100.0)
        timings['Retrieval'] = time.time() - t0

        # 3. Generate Distribution Map (not timed for location distribution)
        yield None, None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map...", None, None, None, None
        df_embed = model.df_embed
        top_10_indices = top_indices[:10]
        top_10_results = []
        for idx in top_10_indices:
            row = df_embed.iloc[idx]
            top_10_results.append({'lat': row['centre_lat'], 'lon': row['centre_lon']})

        # Show geographic distribution (not timed)
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Location ({lat}, {lon})')

        # 4. Download Images (top 6 hits)
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_6_indices = top_indices[:6]
        results = fetch_top_k_images(top_6_indices, probs, df_embed, query_text=f"Loc: {lat},{lon}")

        # Get query tile: the Major-TOM tile nearest the requested coordinates
        query_tile = None
        try:
            lats = pd.to_numeric(df_embed['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df_embed['centre_lon'], errors='coerce')
            # squared-degree distance is sufficient for nearest-neighbour here
            dists = (lats - float(lat))**2 + (lons - float(lon))**2
            nearest_idx = dists.idxmin()
            pid = df_embed.loc[nearest_idx, 'product_id']
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            print(f"Error fetching nearest MajorTOM image: {e}")
        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")
        timings['Download'] = time.time() - t0

        # 5. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 6. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
403
+
404
def generate_status_msg(count, threshold, results):
    """Summarise the match count and the top few hits as a status string."""
    header = f"Found {count} matches in top {threshold*100:.0f}‰.\n\nTop {len(results)} similar images:\n"
    lines = [
        f"{rank}. Product ID: {entry['id']}, Location: ({entry['lat']:.4f}, {entry['lon']:.4f}), Score: {entry['score']:.4f}\n"
        for rank, entry in enumerate(results[:5], start=1)
    ]
    return header + "".join(lines)
409
+
410
def get_initial_plot():
    """Build the initial global-distribution map shown on app load.

    Prefers DINOv2's embedding metadata and falls back to SigLIP's.
    NOTE(review): the fallback assumes 'SigLIP' loaded successfully — if it
    also failed this raises KeyError; confirm whether that is acceptable.
    """
    df_vis = None
    img = None
    if 'DINOv2' in models and models['DINOv2'].df_embed is not None:
        img, df_vis = plot_global_map_static(models['DINOv2'].df_embed)
        # fig = plot_global_map(models['FarSLIP'].df_embed)
    else:
        img, df_vis = plot_global_map_static(models['SigLIP'].df_embed)
    return gr.update(value=img, visible=True), [img], df_vis, gr.update(visible=False)
420
+
421
def handle_map_click(evt: gr.SelectData, df_vis):
    """Convert a pixel click on the static world-map image to lat/lon.

    Inverts the fixed matplotlib layout of the 3000x1500 map image (margins
    scaled down from a 4000x2000 original), then snaps to the nearest dataset
    point in df_vis when one is close enough.

    Returns:
        (lat, lon, product_id, status_message); Nones with a message on failure.
    """
    if evt is None:
        return None, None, None, "No point selected."

    try:
        x, y = evt.index[0], evt.index[1]

        # Image dimensions (New)
        img_width = 3000
        img_height = 1500

        # Scaled Margins (Proportional to 4000x2000)
        left_margin = 110 * 0.75
        right_margin = 110 * 0.75
        top_margin = 100 * 0.75
        bottom_margin = 67 * 0.75

        plot_width = img_width - left_margin - right_margin
        plot_height = img_height - top_margin - bottom_margin

        # Adjust for aspect ratio preservation
        map_aspect = 360.0 / 180.0 # 2.0
        plot_aspect = plot_width / plot_height

        # the map is letterboxed inside the plot area; compute its drawn
        # extent and the offset of the unused border
        if plot_aspect > map_aspect:
            actual_map_width = plot_height * map_aspect
            actual_map_height = plot_height
            h_offset = (plot_width - actual_map_width) / 2
            v_offset = 0
        else:
            actual_map_width = plot_width
            actual_map_height = plot_width / map_aspect
            h_offset = 0
            v_offset = (plot_height - actual_map_height) / 2

        # Calculate relative position within the plot area
        x_in_plot = x - left_margin
        y_in_plot = y - top_margin

        # Check if click is within the actual map bounds
        if (x_in_plot < h_offset or x_in_plot > h_offset + actual_map_width or
            y_in_plot < v_offset or y_in_plot > v_offset + actual_map_height):
            return None, None, None, "Click outside map area. Please click on the map."

        # Calculate relative position within the map (0 to 1)
        x_rel = (x_in_plot - h_offset) / actual_map_width
        y_rel = (y_in_plot - v_offset) / actual_map_height

        # Clamp to [0, 1]
        x_rel = max(0, min(1, x_rel))
        y_rel = max(0, min(1, y_rel))

        # Convert to geographic coordinates (full-globe equirectangular)
        lon = x_rel * 360 - 180
        lat = 90 - y_rel * 180

        # Find nearest point in df_vis if available
        pid = ""
        if df_vis is not None:
            dists = (df_vis['centre_lat'] - lat)**2 + (df_vis['centre_lon'] - lon)**2
            min_idx = dists.idxmin()
            nearest_row = df_vis.loc[min_idx]

            # snap only when within 5 degrees (squared distance < 25)
            if dists[min_idx] < 25:
                lat = nearest_row['centre_lat']
                lon = nearest_row['centre_lon']
                pid = nearest_row['product_id']

    except Exception as e:
        print(f"Error handling click: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"Error: {e}"

    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"
496
+
497
def download_image_by_location(lat, lon, pid, model_name):
    """Download and return the image at the specified location.

    Args:
        lat, lon: target coordinates (anything convertible to float)
        pid: optional product_id; when falsy, the nearest dataset point is used
        model_name (str): key into the global `models` dict (supplies df_embed)

    Returns:
        (image or None, status message)
    """
    if lat is None or lon is None:
        return None, "Please specify coordinates first."

    model, error = get_active_model(model_name)
    if error:
        return None, error

    try:
        # Convert to float to ensure proper formatting
        lat = float(lat)
        lon = float(lon)

        # Find Product ID if not provided:
        # nearest-neighbour in squared-degree space over the embedding table
        if not pid:
            df = model.df_embed
            lats = pd.to_numeric(df['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df['centre_lon'], errors='coerce')
            dists = (lats - lat)**2 + (lons - lon)**2
            nearest_idx = dists.idxmin()
            pid = df.loc[nearest_idx, 'product_id']

        # Download image
        img_384, _ = download_and_process_image(pid, df_source=model.df_embed, verbose=True)

        if img_384 is None:
            return None, f"Failed to download image for location ({lat:.4f}, {lon:.4f})"

        return img_384, f"Downloaded image at ({lat:.4f}, {lon:.4f})"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
532
+
533
def reset_to_global_map():
    """Reset the map to the initial global distribution view.

    Uses the same model preference as get_initial_plot: DINOv2 first,
    otherwise SigLIP (assumed loaded).
    """
    img = None
    df_vis = None
    if 'DINOv2' in models and models['DINOv2'].df_embed is not None:
        img, df_vis = plot_global_map_static(models['DINOv2'].df_embed)
    else:
        img, df_vis = plot_global_map_static(models['SigLIP'].df_embed)

    return gr.update(value=img, visible=True), [img], df_vis
543
+
544
def format_results_to_text(results):
    """Render the ranked result metadata as a plain-text report."""
    if not results:
        return "No results found."

    divider = "-" * 30 + "\n"
    parts = [f"Top {len(results)} Retrieval Results\n", "=" * 30 + "\n\n"]
    for rank, entry in enumerate(results, start=1):
        parts.append(f"Rank: {rank}\n")
        parts.append(f"Product ID: {entry['id']}\n")
        parts.append(f"Location: Latitude {entry['lat']:.6f}, Longitude {entry['lon']:.6f}\n")
        parts.append(f"Similarity Score: {entry['score']:.6f}\n")
        parts.append(divider)
    return "".join(parts)
557
+
558
def save_plot(figs):
    """Persist the current results for download.

    Accepts either a single PIL image (saved as .png), a list/tuple of
    [map_image, results_image, results_text] (bundled into a .zip), or an
    object with a write_html method, e.g. a Plotly figure (saved as .html).

    Returns:
        str: path to the written temp file, or None when there is nothing
        to save or an error occurred.
    """
    if figs is None:
        return None
    try:
        # If it's a single image (initial state), save as png
        if isinstance(figs, PILImage.Image):
            fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
            os.close(fd)
            figs.save(path)
            return path

        # If it's a list/tuple of images [map_img, results_img]
        if isinstance(figs, (list, tuple)):
            # If only one image in list, save as PNG
            if len(figs) == 1 and isinstance(figs[0], PILImage.Image):
                fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
                os.close(fd)
                figs[0].save(path)
                return path

            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)

            with zipfile.ZipFile(zip_path, 'w') as zipf:
                # Save Map
                if figs[0] is not None:
                    map_path = os.path.join(tempfile.gettempdir(), 'map_distribution.png')
                    figs[0].save(map_path)
                    zipf.write(map_path, arcname='map_distribution.png')

                # Save Results
                if len(figs) > 1 and figs[1] is not None:
                    res_path = os.path.join(tempfile.gettempdir(), 'retrieval_results.png')
                    figs[1].save(res_path)
                    zipf.write(res_path, arcname='retrieval_results.png')

                # Save Results Text
                if len(figs) > 2 and figs[2] is not None:
                    txt_path = os.path.join(tempfile.gettempdir(), 'results.txt')
                    with open(txt_path, 'w', encoding='utf-8') as f:
                        f.write(figs[2])
                    zipf.write(txt_path, arcname='results.txt')

            return zip_path

        # Fallback for Plotly figure (if any)
        # Create a temporary file
        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)

        # Write to the temporary file
        figs.write_html(path)
        return path
    except Exception as e:
        print(f"Error saving: {e}")
        return None
614
+
615
# Gradio Blocks Interface
# Three search modes (text / image / location) share one map, one result
# gallery and one download button; the wiring below connects the widgets to
# the search_* / map helpers defined earlier in this file.
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    gr.HTML("""
    <div style="font-size: 1.2em;">
    EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple a click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
    </div>

    """)

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Text Search") as tab_text:
                    model_selector_text = gr.Dropdown(choices=["SigLIP", "FarSLIP"], value="FarSLIP", label="Model")
                    query_input = gr.Textbox(label="Query", placeholder="e.g., rainforest, glacier")

                    gr.Examples(
                        examples=[
                            ["a satellite image of a river around a city"],
                            ["a satellite image of a rainforest"],
                            ["a satellite image of a slum"],
                            ["a satellite image of a glacier"],
                            ["a satellite image of snow covered mountains"]
                        ],
                        inputs=[query_input],
                        label="Text Examples"
                    )

                    search_btn = gr.Button("Search by Text", variant="primary")

                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(choices=["SigLIP", "FarSLIP", "SatCLIP", "DINOv2"], value="FarSLIP", label="Model")

                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")

                    gr.Examples(
                        examples=[
                            ["./examples/example1.png"],
                            ["./examples/example2.png"],
                            ["./examples/example3.png"]
                        ],
                        inputs=[image_input],
                        label="Image Examples"
                    )

                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)

                    # Hidden field: filled by map clicks, consumed by the
                    # geolocation image download.
                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")

                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")

                    search_img_btn = gr.Button("Search by Image", variant="primary")

                with gr.TabItem("Location Search") as tab_location:
                    gr.Markdown("Search using **SatCLIP** location encoder.")

                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)

                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")

                    gr.Examples(
                        examples=[
                            [30.32, 120.15],
                            [40.7128, -74.0060],
                            [24.65, 46.71],
                            [-3.4653, -62.2159],
                            [64.4, 16.8]
                        ],
                        inputs=[lat_input, lon_input],
                        label="Location Examples"
                    )

                    search_loc_btn = gr.Button("Search by Location", variant="primary")

            # Shared controls for all three search tabs.
            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (‰)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)

        with gr.Column(scale=6):
            # Static map (clickable) shown by default; the interactive Plot
            # variant is only revealed by search handlers.
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True
            )
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")

    # current_fig feeds save_plot; map_data_state maps click pixels to
    # geographic coordinates.
    current_fig = gr.State()
    map_data_state = gr.State()

    # Initial Load
    demo.load(fn=get_initial_plot, outputs=[plot_map, current_fig, map_data_state, plot_map_interactive])

    # Reset Map Buttons
    btn_reset_map_img.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    btn_reset_map_loc.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    # Map Click Event - updates Image Search coordinates.
    # NOTE: both .select handlers fire on every click, so one click keeps
    # the Image Search and Location Search coordinate fields in sync.
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[img_lat, img_lon, img_pid, img_click_status]
    )

    # Map Click Event - also updates Location Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[lat_input, lon_input, loc_pid, loc_click_status]
    )

    # Download Image by Geolocation
    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status]
    )

    # Search Event (Text)
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Search Event (Image)
    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Search Event (Location)
    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Save Event
    save_btn.click(
        fn=save_plot,
        inputs=[current_fig],
        outputs=[download_file]
    )

    # Tab Selection Events
    def show_static_map():
        """Switch back to the clickable static map when changing tabs."""
        return gr.update(visible=True), gr.update(visible=False)

    tab_text.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_image.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_location.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
compute_embeddings.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #!/usr/bin/env python3
3
+ """
4
+ Compute Embeddings for Major-TOM Sentinel-2 Images
5
+
6
+ This script generates embeddings for Sentinel-2 imagery using various models:
7
+ - DINOv2: Vision Transformer trained with self-supervised learning
8
+ - SigLIP: Vision-Language model with sigmoid loss
9
+ - FarSLIP: Remote sensing fine-tuned CLIP
10
+ - SatCLIP: Satellite imagery CLIP with location awareness
11
+
12
+ Usage:
13
+ python compute_embeddings.py --model dinov2 --device cuda:1
14
+ python compute_embeddings.py --model siglip --device cuda:5
15
+ python compute_embeddings.py --model satclip --device cuda:3
16
+ python compute_embeddings.py --model farslip --device cuda:4
17
+
18
+ Author: Generated by Copilot
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import argparse
24
+ import logging
25
+ from pathlib import Path
26
+ from datetime import datetime
27
+
28
+ import numpy as np
29
+ import pandas as pd
30
+ import torch
31
+ from PIL import Image
32
+ from tqdm.auto import tqdm
33
+
34
+ # Add project root to path
35
+ PROJECT_ROOT = Path(__file__).parent.absolute()
36
+ if str(PROJECT_ROOT) not in sys.path:
37
+ sys.path.insert(0, str(PROJECT_ROOT))
38
+
39
+ from models.load_config import load_and_process_config
40
+
41
+
42
# =============================================================================
# Configuration
# =============================================================================
# NOTE(review): absolute /data1 paths are machine-specific; adjust per host.
METADATA_PATH = Path("/data1/zyj/Core-S2L2A-249k/Core_S2L2A_249k_crop_384x384_metadata.parquet")
IMAGE_PARQUET_DIR = Path("/data1/zyj/Core-S2L2A-249k/images")
OUTPUT_BASE_DIR = Path("/data1/zyj/EarthEmbeddings/Core-S2L2A-249k")

# Columns to remove from output (bulky or redundant per-image metadata)
COLUMNS_TO_REMOVE = ['cloud_cover', 'nodata', 'geometry_wkt', 'bands', 'image_shape', 'image_dtype']

# Columns to rename in the output parquet
COLUMNS_RENAME = {'crs': 'utm_crs'}

# Pixel bbox for center 384x384 crop from 1068x1068 original
# (1068 - 384) / 2 = 342
PIXEL_BBOX = [342, 342, 726, 726]  # [x_min, y_min, x_max, y_max], pixel units

# Output parquet file per model
MODEL_OUTPUT_PATHS = {
    'dinov2': OUTPUT_BASE_DIR / 'dinov2' / 'DINOv2_crop_384x384.parquet',
    'siglip': OUTPUT_BASE_DIR / 'siglip' / 'SigLIP_crop_384x384.parquet',
    'farslip': OUTPUT_BASE_DIR / 'farslip' / 'FarSLIP_crop_384x384.parquet',
    'satclip': OUTPUT_BASE_DIR / 'satclip' / 'SatCLIP_crop_384x384.parquet',
}

# Default batch sizes per model (overridable via --batch-size)
BATCH_SIZES = {
    'dinov2': 64,
    'siglip': 64,
    'farslip': 64,
    'satclip': 128,
}
74
+
75
+
76
+ # =============================================================================
77
+ # Setup Logging
78
+ # =============================================================================
79
def setup_logging(model_name: str) -> logging.Logger:
    """Configure root logging to both a per-model file and stdout.

    Args:
        model_name: Model identifier; used to name the log file
            (logs/compute_embeddings_<model>.log under the project root).

    Returns:
        Logger for this module.
    """
    log_dir = PROJECT_ROOT / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"compute_embeddings_{model_name}.log"

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout)
        ],
        # basicConfig is a silent no-op when the root logger already has
        # handlers (e.g. an imported library configured logging first, or
        # this is called twice); force=True replaces any existing handlers
        # so the file handler is guaranteed to be attached.
        force=True,
    )
    return logging.getLogger(__name__)
94
+
95
+
96
+ # =============================================================================
97
+ # Image Preprocessing Functions
98
+ # =============================================================================
99
def decode_image_bytes(row) -> np.ndarray:
    """Reconstruct the raw image array stored in a parquet row.

    Args:
        row: pandas Series exposing 'image_bytes' (raw buffer),
            'image_shape' (sequence of ints) and 'image_dtype' (dtype name).

    Returns:
        np.ndarray restored to the recorded shape/dtype — for this dataset
        an (H, W, 12) uint16 array.
    """
    target_shape = tuple(int(dim) for dim in row['image_shape'])
    raw = np.frombuffer(row['image_bytes'], dtype=np.dtype(row['image_dtype']))
    return raw.reshape(target_shape)
113
+
114
+
115
def extract_rgb_image(img_array: np.ndarray, clip_max: float = 4000.0) -> Image.Image:
    """Build an 8-bit RGB preview from a 12-band Sentinel-2 array.

    Band layout: [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12];
    true colour uses R=B04 (index 3), G=B03 (index 2), B=B02 (index 1).

    Args:
        img_array: (H, W, 12) reflectance array.
        clip_max: Reflectance value mapped to full brightness.

    Returns:
        PIL.Image: RGB rendering of the scene.
    """
    # Reorder channels to R, G, B and work in float for the scaling step.
    true_color = img_array[:, :, [3, 2, 1]].astype(np.float32)

    # Scale reflectance into [0, 1], clipping bright outliers at clip_max.
    scaled = np.clip(true_color / clip_max, 0, 1)

    # Quantize to 8 bit and hand off to PIL.
    return Image.fromarray((scaled * 255).astype(np.uint8))
139
+
140
+
141
+ # =============================================================================
142
+ # Model Loading Functions
143
+ # =============================================================================
144
def load_model(model_name: str, device: str, config: dict):
    """
    Load the specified model.

    Args:
        model_name: One of 'dinov2', 'siglip', 'farslip', 'satclip'
        device: Device string like 'cuda:0' or 'cpu'
        config: Configuration dictionary from local.yaml

    Returns:
        Model instance

    Raises:
        ValueError: If model_name is not one of the supported identifiers.
    """
    logger = logging.getLogger(__name__)

    # Imports are kept inside each branch so only the selected backend's
    # dependencies need to be installed.
    if model_name == 'dinov2':
        from models.dinov2_model import DINOv2Model
        model_config = config.get('dinov2', {})
        model = DINOv2Model(
            ckpt_path=model_config.get('ckpt_path', '/data1/zyj/checkpoints/dinov2-large'),
            model_name='facebook/dinov2-large',
            embedding_path=None,  # We're generating, not loading
            device=device
        )
        logger.info(f"DINOv2 model loaded on {device}")
        return model

    elif model_name == 'siglip':
        from models.siglip_model import SigLIPModel
        model_config = config.get('siglip', {})
        model = SigLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin'),
            model_name='ViT-SO400M-14-SigLIP-384',
            tokenizer_path=model_config.get('tokenizer_path', './checkpoints/ViT-SO400M-14-SigLIP-384'),
            embedding_path=None,
            device=device
        )
        # Disable embedding loading since we set path to None
        # NOTE(review): only SigLIPModel needs this explicit reset — presumably
        # its constructor eagerly loads embeddings; confirm in models/siglip_model.
        model.df_embed = None
        model.image_embeddings = None
        logger.info(f"SigLIP model loaded on {device}")
        return model

    elif model_name == 'farslip':
        from models.farslip_model import FarSLIPModel
        model_config = config.get('farslip', {})
        model = FarSLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt'),
            model_name='ViT-B-16',
            embedding_path=None,
            device=device
        )
        logger.info(f"FarSLIP model loaded on {device}")
        return model

    elif model_name == 'satclip':
        from models.satclip_ms_model import SatCLIPMSModel
        model_config = config.get('satclip', {})
        model = SatCLIPMSModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/SatCLIP/satclip-vit16-l40.ckpt'),
            embedding_path=None,
            device=device
        )
        logger.info(f"SatCLIP-MS model loaded on {device}")
        return model

    else:
        raise ValueError(f"Unknown model: {model_name}")
211
+
212
+
213
+ # =============================================================================
214
+ # Embedding Computation Functions
215
+ # =============================================================================
216
def compute_embedding_single(model, model_name: str, img_array: np.ndarray) -> np.ndarray:
    """Embed one (H, W, 12) image with the given model.

    RGB models (dinov2/siglip/farslip) receive an 8-bit RGB rendering of
    the array; satclip consumes the multi-spectral array directly.

    Args:
        model: Model instance exposing encode_image.
        model_name: Model identifier.
        img_array: numpy array of shape (H, W, 12).

    Returns:
        1-D numpy embedding vector, or None when the model produced nothing
        or *model_name* is unrecognised.
    """
    if model_name in ('dinov2', 'siglip', 'farslip'):
        # RGB-input models: render true colour first.
        feature = model.encode_image(extract_rgb_image(img_array))
    elif model_name == 'satclip':
        # SatCLIP takes the raw multi-spectral stack.
        feature = model.encode_image(img_array, is_multispectral=True)
    else:
        return None

    if feature is None:
        return None
    return feature.cpu().numpy().flatten()
244
+
245
+
246
def compute_embedding_batch(model, model_name: str, img_arrays: list) -> list:
    """
    Compute embeddings for a batch of images.
    Falls back to single-image processing if batch method unavailable.

    Args:
        model: Model instance
        model_name: Model identifier
        img_arrays: List of numpy arrays of shape (H, W, 12)

    Returns:
        List of 1D embedding vectors (numpy arrays), None for failed items
    """
    count = len(img_arrays)

    # Pick per-backend preprocessing and encode kwargs up front so the
    # batch-then-fallback logic below is shared by all models.
    if model_name in ('dinov2', 'siglip', 'farslip'):
        inputs = [extract_rgb_image(arr) for arr in img_arrays]
        encode_kwargs = {}
    elif model_name == 'satclip':
        inputs = img_arrays
        encode_kwargs = {'is_multispectral': True}
    else:
        return [None] * count

    # Preferred path: one batched forward pass.
    if hasattr(model, 'encode_images'):
        try:
            features = model.encode_images(inputs, **encode_kwargs)
            if features is not None:
                return [features[j].cpu().numpy().flatten() for j in range(len(features))]
        except Exception:
            pass  # degrade gracefully to one-by-one encoding

    # Fallback path: encode each item independently; a failure only
    # blanks that single slot.
    embeddings = []
    for item in inputs:
        try:
            feature = model.encode_image(item, **encode_kwargs)
            embeddings.append(feature.cpu().numpy().flatten() if feature is not None else None)
        except Exception:
            embeddings.append(None)
    return embeddings
311
+
312
+ # def process_parquet_file(
313
+ # file_path: Path,
314
+ # model,
315
+ # model_name: str,
316
+ # batch_size: int = 64
317
+ # ) -> pd.DataFrame:
318
+ # """
319
+ # Process a single parquet file and generate embeddings.
320
+
321
+ # Args:
322
+ # file_path: Path to input parquet file
323
+ # model: Model instance
324
+ # model_name: Model identifier
325
+ # batch_size: Batch size for processing
326
+
327
+ # Returns:
328
+ # DataFrame with embeddings
329
+ # """
330
+ # logger = logging.getLogger(__name__)
331
+
332
+ # # Load data
333
+ # df = pd.read_parquet(file_path)
334
+
335
+ # embeddings_list = []
336
+ # valid_indices = []
337
+
338
+ # # Process in batches (for future batch optimization)
339
+ # for idx, row in df.iterrows():
340
+ # try:
341
+ # # Decode image
342
+ # img_array = decode_image_bytes(row)
343
+
344
+ # # Compute embedding
345
+ # embedding = compute_embedding_single(model, model_name, img_array)
346
+
347
+ # if embedding is not None:
348
+ # embeddings_list.append(embedding)
349
+ # valid_indices.append(idx)
350
+
351
+ # except Exception as e:
352
+ # logger.warning(f"Error processing row {idx}: {e}")
353
+ # continue
354
+
355
+ # if not embeddings_list:
356
+ # logger.warning(f"No valid embeddings for {file_path.name}")
357
+ # return None
358
+
359
+ # # Build result DataFrame
360
+ # result_df = df.loc[valid_indices].copy()
361
+
362
+ # # Remove unwanted columns
363
+ # cols_to_drop = [c for c in COLUMNS_TO_REMOVE if c in result_df.columns]
364
+ # if cols_to_drop:
365
+ # result_df = result_df.drop(columns=cols_to_drop)
366
+
367
+ # # Remove image_bytes (large binary data)
368
+ # if 'image_bytes' in result_df.columns:
369
+ # result_df = result_df.drop(columns=['image_bytes'])
370
+
371
+ # # Remove geometry column (binary)
372
+ # if 'geometry' in result_df.columns:
373
+ # result_df = result_df.drop(columns=['geometry'])
374
+
375
+ # # Rename columns
376
+ # result_df = result_df.rename(columns=COLUMNS_RENAME)
377
+
378
+ # # Add pixel_bbox
379
+ # result_df['pixel_bbox'] = [PIXEL_BBOX] * len(result_df)
380
+
381
+ # # Add embedding
382
+ # result_df['embedding'] = embeddings_list
383
+
384
+ # return result_df
385
+
386
def process_parquet_file(
    file_path: Path,
    model,
    model_name: str,
    batch_size: int = 64
) -> pd.DataFrame:
    """
    Process a single parquet file and generate embeddings using batch processing.

    Rows that fail to decode or embed are dropped from the output silently
    (after a logged warning), so the returned frame may be shorter than the
    input file.

    Args:
        file_path: Path to input parquet file
        model: Model instance
        model_name: Model identifier
        batch_size: Batch size for processing

    Returns:
        DataFrame with embeddings, or None if no row produced an embedding.
    """
    logger = logging.getLogger(__name__)

    # Load data
    df = pd.read_parquet(file_path)
    n_rows = len(df)

    # Positional slots so batch results can be written back by index.
    embeddings_list = [None] * n_rows
    valid_mask = [False] * n_rows

    # Process in batches
    for batch_start in range(0, n_rows, batch_size):
        batch_end = min(batch_start + batch_size, n_rows)
        batch_indices = list(range(batch_start, batch_end))

        # Decode images for this batch
        batch_arrays = []
        batch_valid_indices = []

        for idx in batch_indices:
            try:
                row = df.iloc[idx]
                img_array = decode_image_bytes(row)
                batch_arrays.append(img_array)
                batch_valid_indices.append(idx)
            except Exception as e:
                logger.warning(f"Error decoding row {idx}: {e}")
                continue

        if not batch_arrays:
            continue

        # Compute embeddings for this batch
        try:
            batch_embeddings = compute_embedding_batch(model, model_name, batch_arrays)

            # Store results
            for i, idx in enumerate(batch_valid_indices):
                if batch_embeddings[i] is not None:
                    embeddings_list[idx] = batch_embeddings[i]
                    valid_mask[idx] = True

        except Exception as e:
            logger.warning(f"Error computing batch embeddings: {e}")
            # Fall back to single image processing for this batch
            for i, idx in enumerate(batch_valid_indices):
                try:
                    embedding = compute_embedding_single(model, model_name, batch_arrays[i])
                    if embedding is not None:
                        embeddings_list[idx] = embedding
                        valid_mask[idx] = True
                except Exception as inner_e:
                    logger.warning(f"Error processing row {idx}: {inner_e}")
                    continue

    # Filter to valid rows only
    valid_indices = [i for i, v in enumerate(valid_mask) if v]

    if not valid_indices:
        logger.warning(f"No valid embeddings for {file_path.name}")
        return None

    # Build result DataFrame (positional indexing matches the loop above)
    result_df = df.iloc[valid_indices].copy()
    valid_embeddings = [embeddings_list[i] for i in valid_indices]

    # Remove unwanted columns
    cols_to_drop = [c for c in COLUMNS_TO_REMOVE if c in result_df.columns]
    if cols_to_drop:
        result_df = result_df.drop(columns=cols_to_drop)

    # Remove image_bytes (large binary data)
    if 'image_bytes' in result_df.columns:
        result_df = result_df.drop(columns=['image_bytes'])

    # Remove geometry column (binary)
    if 'geometry' in result_df.columns:
        result_df = result_df.drop(columns=['geometry'])

    # Rename columns
    result_df = result_df.rename(columns=COLUMNS_RENAME)

    # Add pixel_bbox (same center-crop window for every row)
    result_df['pixel_bbox'] = [PIXEL_BBOX] * len(result_df)

    # Add embedding (object column of 1-D numpy arrays)
    result_df['embedding'] = valid_embeddings

    return result_df
492
+
493
+ # =============================================================================
494
+ # Main Processing Pipeline
495
+ # =============================================================================
496
def main():
    """Entry point: parse CLI args, load the chosen model, embed every
    batch_*.parquet shard, then merge and save a single output parquet.

    Returns:
        0 on success, 1 if no shard produced any data (used as exit code).
    """
    parser = argparse.ArgumentParser(description='Compute embeddings for Major-TOM images')
    parser.add_argument('--model', type=str, required=True,
                        choices=['dinov2', 'siglip', 'farslip', 'satclip'],
                        help='Model to use for embedding computation')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device to run on (e.g., cuda:0, cuda:1, cpu)')
    parser.add_argument('--batch-size', type=int, default=None,
                        help='Batch size for processing (default: model-specific)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum number of files to process (for testing)')

    args = parser.parse_args()

    # Setup logging
    logger = setup_logging(args.model)

    logger.info("=" * 80)
    logger.info(f"Computing {args.model.upper()} embeddings")
    logger.info(f"Timestamp: {datetime.now().isoformat()}")
    logger.info(f"Device: {args.device}")
    logger.info("=" * 80)

    # Load configuration (model checkpoint/tokenizer paths)
    config = load_and_process_config()
    if config is None:
        logger.warning("No config file found, using default paths")
        config = {}

    # Determine batch size (CLI override wins over per-model default)
    batch_size = args.batch_size or BATCH_SIZES.get(args.model, 64)
    logger.info(f"Batch size: {batch_size}")

    # Get output path
    output_path = MODEL_OUTPUT_PATHS[args.model]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output path: {output_path}")

    # Load model
    logger.info(f"Loading {args.model} model...")
    model = load_model(args.model, args.device, config)

    # Get input files (sorted for deterministic shard order)
    parquet_files = sorted(IMAGE_PARQUET_DIR.glob("batch_*.parquet"))
    if args.max_files:
        parquet_files = parquet_files[:args.max_files]

    logger.info(f"Found {len(parquet_files)} input files")

    # Process files; a failing shard is logged and skipped, not fatal.
    all_results = []
    total_rows = 0

    for file_path in tqdm(parquet_files, desc=f"Processing {args.model}"):
        try:
            result_df = process_parquet_file(file_path, model, args.model, batch_size)

            if result_df is not None:
                all_results.append(result_df)
                total_rows += len(result_df)
                logger.info(f"[{file_path.name}] Processed {len(result_df)} rows")

        except Exception as e:
            logger.error(f"Error processing {file_path.name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Merge and save
    if all_results:
        logger.info("Merging all results...")
        final_df = pd.concat(all_results, ignore_index=True)

        # Validate columns
        logger.info(f"Final columns: {list(final_df.columns)}")

        # Check for removed columns
        removed = [c for c in COLUMNS_TO_REMOVE if c in final_df.columns]
        if removed:
            logger.warning(f"Columns still present that should be removed: {removed}")
        else:
            logger.info("✓ All unwanted columns removed")

        # Check for renamed columns
        if 'utm_crs' in final_df.columns and 'crs' not in final_df.columns:
            logger.info("✓ Column 'crs' renamed to 'utm_crs'")

        # Check for pixel_bbox
        if 'pixel_bbox' in final_df.columns:
            logger.info("✓ Column 'pixel_bbox' added")

        # Save
        logger.info(f"Saving to {output_path}...")
        final_df.to_parquet(output_path, index=False)

        logger.info(f"=" * 80)
        logger.info(f"Processing complete!")
        logger.info(f"Total rows: {len(final_df):,}")
        logger.info(f"Embedding dimension: {len(final_df['embedding'].iloc[0])}")
        logger.info(f"Output file: {output_path}")
        logger.info(f"=" * 80)

    else:
        logger.error("No data processed!")
        return 1

    return 0
603
+
604
+
605
if __name__ == "__main__":
    # Propagate main()'s 0/1 status as the process exit code.
    sys.exit(main())
configs/huggingface.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ siglip:
2
+ ckpt_path: "hf"
3
+ model_name: "ViT-SO400M-14-SigLIP-384"
4
+ tokenizer_path: "hf"
5
  embedding_path: "hf://ML4RS-Anonymous/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet.parquet"  # NOTE(review): doubled ".parquet.parquet" extension — compute_embeddings.py writes "..._384x384.parquet"; confirm the actual repo filename (same doubling appears in the farslip/satclip/dinov2 entries below)
6
+ farslip:
7
+ ckpt_path: "hf"
8
+ model_name: "ViT-B-16"
9
+ embedding_path: "hf://ML4RS-Anonymous/EarthEmbeddings/Core-S2L2A-249k/farslip/FarSLIP_crop_384x384.parquet.parquet"
10
+ satclip:
11
+ ckpt_path: "hf"
12
+ embedding_path: "hf://ML4RS-Anonymous/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet.parquet"
13
+ dinov2:
14
+ ckpt_path: "hf"
15
+ embedding_path: "hf://ML4RS-Anonymous/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet.parquet"
countries.geo.json ADDED
The diff for this file is too large to render. See raw diff
 
data_utils.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fsspec
2
+ import pyarrow.parquet as pq
3
+ import numpy as np
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from rasterio.io import MemoryFile
7
+ import matplotlib.pyplot as plt
8
+ import cartopy.crs as ccrs
9
+ import cartopy.io.img_tiles as cimgt
10
+ from matplotlib.patches import Rectangle
11
+ import math
12
+ from matplotlib.figure import Figure
13
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
14
+
15
+
16
def crop_center(img_array, cropx, cropy):
    """Return the centred ``cropy`` x ``cropx`` window of an (H, W, C) array.

    Uses integer-centre arithmetic, so for odd differences the extra row or
    column falls on the bottom/right side.
    """
    height, width, _channels = img_array.shape
    left = width // 2 - cropx // 2
    top = height // 2 - cropy // 2
    return img_array[top:top + cropy, left:left + cropx]
21
+
22
def read_tif_bytes(tif_bytes):
    """Decode an in-memory GeoTIFF byte string into a squeezed numpy array."""
    # rasterio's MemoryFile lets us open the bytes as a dataset without
    # touching disk; squeeze() drops a singleton band axis for 1-band rasters.
    with MemoryFile(tif_bytes) as memfile:
        with memfile.open(driver='GTiff') as dataset:
            return dataset.read().squeeze()
26
+
27
def read_row_memory(row_dict, columns=None):
    """Fetch one row group of a remote Parquet file and decode its columns.

    Parameters
    ----------
    row_dict : mapping with 'parquet_url' (an fsspec-readable URL) and
        'parquet_row' (the row-group index inside that file).
    columns : list of column names to read; defaults to ``["thumbnail"]``.

    Returns
    -------
    dict mapping each requested column to its decoded value: the
    'thumbnail' column is opened as a PIL image, every other column is
    treated as GeoTIFF bytes and decoded via read_tif_bytes().
    """
    # None sentinel instead of a mutable default list (shared-default pitfall).
    if columns is None:
        columns = ["thumbnail"]

    url = row_dict['parquet_url']
    row_idx = row_dict['parquet_row']

    # Readahead caching reduces the number of small HTTP range requests.
    fs_options = {
        "cache_type": "readahead",
        "block_size": 5 * 1024 * 1024
    }

    with fsspec.open(url, mode='rb', **fs_options) as f:
        with pq.ParquetFile(f) as pf:
            # NOTE(review): 'parquet_row' is used as a *row-group* index,
            # which assumes one row per row group in these files — confirm.
            table = pf.read_row_group(row_idx, columns=columns)

    # The Arrow table is fully materialized, so decoding can happen after
    # the remote file handles are closed.
    row_output = {}
    for col in columns:
        col_data = table[col][0].as_py()

        if col != 'thumbnail':
            row_output[col] = read_tif_bytes(col_data)
        else:
            stream = BytesIO(col_data)
            row_output[col] = Image.open(stream)

    return row_output
51
+
52
def download_and_process_image(product_id, df_source=None, verbose=True):
    """Download Sentinel-2 RGB bands for *product_id* and build display images.

    Looks up the product's metadata row in *df_source*, rewrites the parquet
    URL to the ModelScope mirror, fetches bands B04/B03/B02, and converts
    them to an 8-bit RGB image.

    Parameters
    ----------
    product_id : value matched against ``df_source['product_id']``.
    df_source : pandas DataFrame with at least 'product_id' plus the
        'parquet_url'/'parquet_row' columns read_row_memory() needs.
    verbose : print progress/diagnostic messages when True.

    Returns
    -------
    (img_384, img_full) : a 384x384 centre crop (or a resize when the scene
        is smaller) and the full-size PIL image, or (None, None) on failure.
    """
    if df_source is None:
        if verbose: print("❌ Error: No DataFrame provided.")
        return None, None

    row_subset = df_source[df_source['product_id'] == product_id]
    if len(row_subset) == 0:
        if verbose: print(f"❌ Error: Product ID {product_id} not found in DataFrame.")
        return None, None

    row_dict = row_subset.iloc[0].to_dict()

    # Redirect Hugging Face / hf-mirror hosted parquet files to the
    # ModelScope mirror (same content, different path convention).
    if 'parquet_url' in row_dict:
        url = row_dict['parquet_url']
        if 'huggingface.co' in url:
            row_dict['parquet_url'] = url.replace('https://huggingface.co', 'https://modelscope.cn').replace('resolve/main', 'resolve/master')
        elif 'hf-mirror.com' in url:
            row_dict['parquet_url'] = url.replace('https://hf-mirror.com', 'https://modelscope.cn').replace('resolve/main', 'resolve/master')
    else:
        if verbose: print("❌ Error: 'parquet_url' missing in metadata.")
        return None, None

    if verbose: print(f"⬇️ Fetching data for {product_id} from {row_dict['parquet_url']}...")

    try:
        bands_data = read_row_memory(row_dict, columns=['B04', 'B03', 'B02'])

        if not all(b in bands_data for b in ['B04', 'B03', 'B02']):
            if verbose: print(f"❌ Error: Missing bands in fetched data for {product_id}")
            return None, None

        # Stack to (H, W, 3) in R, G, B order.
        rgb_img = np.stack([bands_data['B04'], bands_data['B03'], bands_data['B02']], axis=-1)

        if verbose:
            print(f"Raw RGB stats: Min={rgb_img.min()}, Max={rgb_img.max()}, Mean={rgb_img.mean()}, Dtype={rgb_img.dtype}")

        # Reflectance scaling: divide by 10000, apply a 2.5x brightness
        # gain, clip to [0, 1], then quantize to 8-bit.
        # (Removed a dead `if rgb_img.max() <= 255: pass` branch here.)
        rgb_norm = (2.5 * (rgb_img.astype(float) / 10000.0)).clip(0, 1)
        rgb_uint8 = (rgb_norm * 255).astype(np.uint8)

        if verbose:
            print(f"Processed RGB stats: Min={rgb_uint8.min()}, Max={rgb_uint8.max()}, Mean={rgb_uint8.mean()}")

        img_full = Image.fromarray(rgb_uint8)

        # Centre-crop to the 384x384 window the embedding models expect;
        # fall back to a resize when the scene is smaller than that.
        if rgb_uint8.shape[0] >= 384 and rgb_uint8.shape[1] >= 384:
            cropped_array = crop_center(rgb_uint8, 384, 384)
            img_384 = Image.fromarray(cropped_array)
        else:
            if verbose: print(f"⚠️ Image too small {rgb_uint8.shape}, resizing to 384x384.")
            img_384 = img_full.resize((384, 384))

        if verbose: print(f"✅ Successfully processed {product_id}")
        return img_384, img_full

    except Exception as e:
        if verbose: print(f"❌ Error processing {product_id}: {e}")
        import traceback
        traceback.print_exc()
        return None, None
116
+
117
# Define Esri Imagery Class
class EsriImagery(cimgt.GoogleTiles):
    """Cartopy tile source that serves Esri World Imagery (satellite) tiles."""

    def _image_url(self, tile):
        # Cartopy hands us (x, y, zoom); Esri's REST endpoint orders the
        # path segments as /{z}/{y}/{x}.
        x, y, z = tile
        return f'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}'
122
+
123
+ from PIL import Image, ImageDraw, ImageFont
124
+
125
def get_placeholder_image(text="Image Unavailable", size=(384, 384)):
    """Return a grey placeholder PIL image with *text* drawn on it.

    Used when a satellite map or product image cannot be fetched so the UI
    always has something to display.
    """
    img = Image.new('RGB', size, color=(200, 200, 200))
    d = ImageDraw.Draw(img)
    try:
        # PIL's bundled bitmap font; no external font files required.
        font = ImageFont.load_default()
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C isn't swallowed;
        # ImageDraw.text() falls back to an internal default when font=None.
        font = None

    # Rough vertical centring only; exact centring would need font metrics,
    # which is overkill for a placeholder.
    d.text((20, size[1]//2), text, fill=(0, 0, 0), font=font)
    return img
138
+
139
def get_esri_satellite_image(lat, lon, score=None, rank=None, query=None):
    """
    Generates a satellite image visualization using Esri World Imagery via Cartopy.
    Matches the style of the provided notebook.
    Uses OO Matplotlib API for thread safety.

    Parameters:
        lat, lon: centre coordinates in degrees (WGS84).
        score: optional similarity score shown in the title.
        rank: optional result rank shown in the title.
        query: optional query string shown in the title.

    Returns:
        A PIL image of the rendered map, or a grey placeholder image when
        tile fetching / rendering fails.
    """
    try:
        imagery = EsriImagery()

        # Create figure using OO API (avoids pyplot's global state, which
        # is not safe across threads).
        fig = Figure(figsize=(5, 5), dpi=100)
        # Attaching the Agg canvas registers it on `fig`; the variable
        # itself is intentionally unused afterwards.
        canvas = FigureCanvasAgg(fig)
        ax = fig.add_subplot(1, 1, 1, projection=imagery.crs)

        # Set extent to approx 10km x 10km around the point
        extent_deg = 0.05
        ax.set_extent([lon - extent_deg, lon + extent_deg, lat - extent_deg, lat + extent_deg], crs=ccrs.PlateCarree())

        # Add the imagery (zoom level 14 tiles)
        ax.add_image(imagery, 14)

        # Add a marker for the center
        ax.plot(lon, lat, marker='+', color='yellow', markersize=12, markeredgewidth=2, transform=ccrs.PlateCarree())

        # Add Bounding Box (3840m x 3840m) — the 384px x 10m/px footprint
        # of the embedded Sentinel-2 crop.
        box_size_m = 384 * 10 # 3840m

        # Convert meters to degrees (approx)
        # 1 deg lat = 111320m
        # 1 deg lon = 111320m * cos(lat)
        dlat = (box_size_m / 111320)
        dlon = (box_size_m / (111320 * math.cos(math.radians(lat))))

        # Bottom-Left corner
        rect_lon = lon - dlon / 2
        rect_lat = lat - dlat / 2

        # Add Rectangle
        rect = Rectangle((rect_lon, rect_lat), dlon, dlat,
                         linewidth=2, edgecolor='red', facecolor='none', transform=ccrs.PlateCarree())
        ax.add_patch(rect)

        # Title: only the parts that were supplied, one per line.
        title_parts = []
        if query: title_parts.append(f"{query}")
        if rank is not None: title_parts.append(f"Rank {rank}")
        if score is not None: title_parts.append(f"Score: {score:.4f}")

        ax.set_title("\n".join(title_parts), fontsize=10)

        # Save to an in-memory PNG buffer and hand it back as a PIL image.
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)

        return Image.open(buf)

    except Exception as e:
        # Suppress full traceback for network errors to avoid log spam
        error_msg = str(e)
        if "Connection reset by peer" in error_msg or "Network is unreachable" in error_msg or "urlopen error" in error_msg:
            print(f"⚠️ Network warning: Could not fetch Esri satellite map for ({lat:.4f}, {lon:.4f}). Server might be offline.")
        else:
            print(f"Error generating Esri image for {lat}, {lon}: {e}")
            # Only print traceback for non-network errors
            # import traceback
            # traceback.print_exc()

        # Return a placeholder image with text
        return get_placeholder_image(f"Map Unavailable\n({lat:.2f}, {lon:.2f})")
209
+
210
def get_esri_satellite_image_url(lat, lon, zoom=14):
    """Placeholder: direct tile-URL computation is not implemented.

    The original body only instantiated EsriImagery inside a try/except and
    fell through, so it always returned None; this version keeps that exact
    external behaviour without the pointless work or the bare except.

    Parameters:
        lat, lon: centre coordinates in degrees (currently unused).
        zoom: slippy-map zoom level (currently unused).

    Returns:
        None, always. Callers should use get_esri_satellite_image() for a
        rendered static map instead.
    """
    # TODO: compute slippy-map tile indices from (lat, lon, zoom) and
    # format the Esri World_Imagery tile URL directly.
    return None
examples/example1.png ADDED

Git LFS Details

  • SHA256: 07dd836c4dfe700657f163afdae9ebf2685f83dca1417078b3147c8c31f598a9
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
examples/example2.png ADDED

Git LFS Details

  • SHA256: e52a44517c028cb6b9828c37c974991fb20122f6cdba951e809ac66b7c591552
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
examples/example3.png ADDED

Git LFS Details

  • SHA256: d63b587c17943eb1e60f511def466696c1a12a323f0f67dff99da7631e2e48aa
  • Pointer size: 131 Bytes
  • Size of remote file: 507 kB
logs/compute_embeddings_dinov2.log ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-02-01 09:07:55,115 [INFO] ================================================================================
2
+ 2026-02-01 09:07:55,115 [INFO] Computing DINOV2 embeddings
3
+ 2026-02-01 09:07:55,115 [INFO] Timestamp: 2026-02-01T09:07:55.115269
4
+ 2026-02-01 09:07:55,115 [INFO] Device: cuda:0
5
+ 2026-02-01 09:07:55,115 [INFO] ================================================================================
6
+ 2026-02-01 09:07:55,116 [INFO] Batch size: 64
7
+ 2026-02-01 09:07:55,116 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet
8
+ 2026-02-01 09:07:55,116 [INFO] Loading dinov2 model...
9
+ 2026-02-01 09:07:58,665 [INFO] DINOv2 model loaded on cuda:0
10
+ 2026-02-01 09:07:58,666 [INFO] Found 1 input files
11
+ 2026-02-01 09:08:48,122 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
12
+ 2026-02-01 09:08:48,122 [INFO] Merging all results...
13
+ 2026-02-01 09:08:48,122 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
14
+ 2026-02-01 09:08:48,122 [INFO] ✓ All unwanted columns removed
15
+ 2026-02-01 09:08:48,122 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
16
+ 2026-02-01 09:08:48,122 [INFO] ✓ Column 'pixel_bbox' added
17
+ 2026-02-01 09:08:48,122 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet...
18
+ 2026-02-01 09:08:48,228 [INFO] ================================================================================
19
+ 2026-02-01 09:08:48,228 [INFO] Processing complete!
20
+ 2026-02-01 09:08:48,228 [INFO] Total rows: 1,996
21
+ 2026-02-01 09:08:48,228 [INFO] Embedding dimension: 1024
22
+ 2026-02-01 09:08:48,228 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet
23
+ 2026-02-01 09:08:48,228 [INFO] ================================================================================
24
+ 2026-02-01 09:43:06,596 [INFO] ================================================================================
25
+ 2026-02-01 09:43:06,596 [INFO] Computing DINOV2 embeddings
26
+ 2026-02-01 09:43:06,596 [INFO] Timestamp: 2026-02-01T09:43:06.596521
27
+ 2026-02-01 09:43:06,596 [INFO] Device: cuda:1
28
+ 2026-02-01 09:43:06,596 [INFO] ================================================================================
29
+ 2026-02-01 09:43:06,597 [INFO] Batch size: 64
30
+ 2026-02-01 09:43:06,597 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet
31
+ 2026-02-01 09:43:06,597 [INFO] Loading dinov2 model...
32
+ 2026-02-01 09:43:08,665 [INFO] DINOv2 model loaded on cuda:1
33
+ 2026-02-01 09:43:08,666 [INFO] Found 125 input files
34
+ 2026-02-01 09:43:59,600 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
35
+ 2026-02-01 09:44:50,531 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
36
+ 2026-02-01 09:45:40,104 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
37
+ 2026-02-01 09:46:31,203 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
38
+ 2026-02-01 09:47:22,240 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
39
+ 2026-02-01 09:48:17,789 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
40
+ 2026-02-01 09:49:12,206 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
41
+ 2026-02-01 09:50:04,633 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
42
+ 2026-02-01 09:51:01,688 [INFO] [batch_0009_384x384.parquet] Processed 1993 rows
43
+ 2026-02-01 09:51:52,258 [INFO] [batch_0010_384x384.parquet] Processed 1993 rows
44
+ 2026-02-01 09:52:43,385 [INFO] [batch_0011_384x384.parquet] Processed 1990 rows
45
+ 2026-02-01 09:53:33,664 [INFO] [batch_0012_384x384.parquet] Processed 1993 rows
46
+ 2026-02-01 09:54:23,450 [INFO] [batch_0013_384x384.parquet] Processed 1993 rows
47
+ 2026-02-01 09:55:14,741 [INFO] [batch_0014_384x384.parquet] Processed 1991 rows
48
+ 2026-02-01 09:56:05,637 [INFO] [batch_0015_384x384.parquet] Processed 1993 rows
49
+ 2026-02-01 09:57:02,579 [INFO] [batch_0016_384x384.parquet] Processed 1990 rows
50
+ 2026-02-01 09:57:59,164 [INFO] [batch_0017_384x384.parquet] Processed 1991 rows
51
+ 2026-02-01 09:58:54,668 [INFO] [batch_0018_384x384.parquet] Processed 1991 rows
52
+ 2026-02-01 09:59:50,748 [INFO] [batch_0019_384x384.parquet] Processed 1996 rows
53
+ 2026-02-01 10:00:44,987 [INFO] [batch_0020_384x384.parquet] Processed 1994 rows
54
+ 2026-02-01 10:01:41,422 [INFO] [batch_0021_384x384.parquet] Processed 1996 rows
55
+ 2026-02-01 10:02:39,884 [INFO] [batch_0022_384x384.parquet] Processed 1995 rows
56
+ 2026-02-01 10:03:41,408 [INFO] [batch_0023_384x384.parquet] Processed 1992 rows
57
+ 2026-02-01 10:04:44,392 [INFO] [batch_0024_384x384.parquet] Processed 1989 rows
58
+ 2026-02-01 10:05:47,970 [INFO] [batch_0025_384x384.parquet] Processed 1993 rows
59
+ 2026-02-01 10:06:47,594 [INFO] [batch_0026_384x384.parquet] Processed 1995 rows
60
+ 2026-02-01 10:07:46,292 [INFO] [batch_0027_384x384.parquet] Processed 1997 rows
61
+ 2026-02-01 10:08:43,976 [INFO] [batch_0028_384x384.parquet] Processed 1991 rows
62
+ 2026-02-01 10:09:43,099 [INFO] [batch_0029_384x384.parquet] Processed 1992 rows
63
+ 2026-02-01 10:10:40,183 [INFO] [batch_0030_384x384.parquet] Processed 1993 rows
64
+ 2026-02-01 10:11:44,485 [INFO] [batch_0031_384x384.parquet] Processed 1988 rows
65
+ 2026-02-01 10:12:39,796 [INFO] [batch_0032_384x384.parquet] Processed 1994 rows
66
+ 2026-02-01 10:13:45,836 [INFO] [batch_0033_384x384.parquet] Processed 1992 rows
67
+ 2026-02-01 10:14:44,908 [INFO] [batch_0034_384x384.parquet] Processed 1994 rows
68
+ 2026-02-01 10:15:44,326 [INFO] [batch_0035_384x384.parquet] Processed 1994 rows
69
+ 2026-02-01 10:16:43,931 [INFO] [batch_0036_384x384.parquet] Processed 1996 rows
70
+ 2026-02-01 10:17:41,513 [INFO] [batch_0037_384x384.parquet] Processed 1995 rows
71
+ 2026-02-01 10:18:39,810 [INFO] [batch_0038_384x384.parquet] Processed 1993 rows
72
+ 2026-02-01 10:19:36,710 [INFO] [batch_0039_384x384.parquet] Processed 1989 rows
73
+ 2026-02-01 10:20:31,841 [INFO] [batch_0040_384x384.parquet] Processed 1990 rows
74
+ 2026-02-01 10:21:29,236 [INFO] [batch_0041_384x384.parquet] Processed 1998 rows
75
+ 2026-02-01 10:22:32,483 [INFO] [batch_0042_384x384.parquet] Processed 1997 rows
76
+ 2026-02-01 10:23:28,852 [INFO] [batch_0043_384x384.parquet] Processed 1988 rows
77
+ 2026-02-01 10:24:24,324 [INFO] [batch_0044_384x384.parquet] Processed 1991 rows
78
+ 2026-02-01 10:25:22,097 [INFO] [batch_0045_384x384.parquet] Processed 1993 rows
79
+ 2026-02-01 10:26:18,196 [INFO] [batch_0046_384x384.parquet] Processed 1994 rows
80
+ 2026-02-01 10:27:34,649 [INFO] [batch_0047_384x384.parquet] Processed 1995 rows
81
+ 2026-02-01 10:28:30,976 [INFO] [batch_0048_384x384.parquet] Processed 1992 rows
82
+ 2026-02-01 10:29:41,715 [INFO] [batch_0049_384x384.parquet] Processed 1996 rows
83
+ 2026-02-01 10:30:45,082 [INFO] [batch_0050_384x384.parquet] Processed 1991 rows
84
+ 2026-02-01 10:31:46,711 [INFO] [batch_0051_384x384.parquet] Processed 1997 rows
85
+ 2026-02-01 10:32:45,127 [INFO] [batch_0052_384x384.parquet] Processed 1993 rows
86
+ 2026-02-01 10:33:48,960 [INFO] [batch_0053_384x384.parquet] Processed 1995 rows
87
+ 2026-02-01 10:35:01,705 [INFO] [batch_0054_384x384.parquet] Processed 1997 rows
88
+ 2026-02-01 10:36:11,677 [INFO] [batch_0055_384x384.parquet] Processed 1995 rows
89
+ 2026-02-01 10:37:17,746 [INFO] [batch_0056_384x384.parquet] Processed 1997 rows
90
+ 2026-02-01 10:38:28,458 [INFO] [batch_0057_384x384.parquet] Processed 1991 rows
91
+ 2026-02-01 10:39:38,673 [INFO] [batch_0058_384x384.parquet] Processed 1994 rows
92
+ 2026-02-01 10:40:48,784 [INFO] [batch_0059_384x384.parquet] Processed 1993 rows
93
+ 2026-02-01 10:41:47,477 [INFO] [batch_0060_384x384.parquet] Processed 1995 rows
94
+ 2026-02-01 10:42:55,595 [INFO] [batch_0061_384x384.parquet] Processed 1995 rows
95
+ 2026-02-01 10:44:08,413 [INFO] [batch_0062_384x384.parquet] Processed 1998 rows
96
+ 2026-02-01 10:45:27,616 [INFO] [batch_0063_384x384.parquet] Processed 1997 rows
97
+ 2026-02-01 10:46:40,936 [INFO] [batch_0064_384x384.parquet] Processed 1992 rows
98
+ 2026-02-01 10:47:38,737 [INFO] [batch_0065_384x384.parquet] Processed 1994 rows
99
+ 2026-02-01 10:48:46,233 [INFO] [batch_0066_384x384.parquet] Processed 1992 rows
100
+ 2026-02-01 10:49:56,228 [INFO] [batch_0067_384x384.parquet] Processed 1993 rows
101
+ 2026-02-01 10:51:12,380 [INFO] [batch_0068_384x384.parquet] Processed 1994 rows
102
+ 2026-02-01 10:52:27,369 [INFO] [batch_0069_384x384.parquet] Processed 1992 rows
103
+ 2026-02-01 10:53:42,056 [INFO] [batch_0070_384x384.parquet] Processed 1997 rows
104
+ 2026-02-01 10:54:50,573 [INFO] [batch_0071_384x384.parquet] Processed 1996 rows
105
+ 2026-02-01 10:56:03,974 [INFO] [batch_0072_384x384.parquet] Processed 1992 rows
106
+ 2026-02-01 10:57:09,742 [INFO] [batch_0073_384x384.parquet] Processed 1995 rows
107
+ 2026-02-01 10:58:22,365 [INFO] [batch_0074_384x384.parquet] Processed 1992 rows
108
+ 2026-02-01 10:59:33,712 [INFO] [batch_0075_384x384.parquet] Processed 1991 rows
109
+ 2026-02-01 11:00:48,387 [INFO] [batch_0076_384x384.parquet] Processed 1998 rows
110
+ 2026-02-01 11:01:47,919 [INFO] [batch_0077_384x384.parquet] Processed 1996 rows
111
+ 2026-02-01 11:03:01,336 [INFO] [batch_0078_384x384.parquet] Processed 1992 rows
112
+ 2026-02-01 11:04:04,437 [INFO] [batch_0079_384x384.parquet] Processed 1995 rows
113
+ 2026-02-01 11:05:15,344 [INFO] [batch_0080_384x384.parquet] Processed 1993 rows
114
+ 2026-02-01 11:06:26,434 [INFO] [batch_0081_384x384.parquet] Processed 1995 rows
115
+ 2026-02-01 11:07:29,500 [INFO] [batch_0082_384x384.parquet] Processed 1989 rows
116
+ 2026-02-01 11:08:41,452 [INFO] [batch_0083_384x384.parquet] Processed 1995 rows
117
+ 2026-02-01 11:09:52,372 [INFO] [batch_0084_384x384.parquet] Processed 1996 rows
118
+ 2026-02-01 11:10:54,102 [INFO] [batch_0085_384x384.parquet] Processed 1997 rows
119
+ 2026-02-01 11:12:05,011 [INFO] [batch_0086_384x384.parquet] Processed 1996 rows
120
+ 2026-02-01 11:13:18,046 [INFO] [batch_0087_384x384.parquet] Processed 1994 rows
121
+ 2026-02-01 11:14:28,554 [INFO] [batch_0088_384x384.parquet] Processed 1992 rows
122
+ 2026-02-01 11:15:30,371 [INFO] [batch_0089_384x384.parquet] Processed 1993 rows
123
+ 2026-02-01 11:16:36,098 [INFO] [batch_0090_384x384.parquet] Processed 1993 rows
124
+ 2026-02-01 11:17:47,559 [INFO] [batch_0091_384x384.parquet] Processed 1995 rows
125
+ 2026-02-01 11:18:59,181 [INFO] [batch_0092_384x384.parquet] Processed 1994 rows
126
+ 2026-02-01 11:20:10,040 [INFO] [batch_0093_384x384.parquet] Processed 1998 rows
127
+ 2026-02-01 11:21:11,780 [INFO] [batch_0094_384x384.parquet] Processed 1993 rows
128
+ 2026-02-01 11:22:13,323 [INFO] [batch_0095_384x384.parquet] Processed 1995 rows
129
+ 2026-02-01 11:23:13,963 [INFO] [batch_0096_384x384.parquet] Processed 1997 rows
130
+ 2026-02-01 11:24:11,380 [INFO] [batch_0097_384x384.parquet] Processed 1990 rows
131
+ 2026-02-01 11:25:16,113 [INFO] [batch_0098_384x384.parquet] Processed 1995 rows
132
+ 2026-02-01 11:26:15,319 [INFO] [batch_0099_384x384.parquet] Processed 1992 rows
133
+ 2026-02-01 11:27:09,846 [INFO] [batch_0100_384x384.parquet] Processed 1993 rows
134
+ 2026-02-01 11:28:13,634 [INFO] [batch_0101_384x384.parquet] Processed 1994 rows
135
+ 2026-02-01 11:29:19,508 [INFO] [batch_0102_384x384.parquet] Processed 1991 rows
136
+ 2026-02-01 11:30:27,321 [INFO] [batch_0103_384x384.parquet] Processed 1990 rows
137
+ 2026-02-01 11:31:38,038 [INFO] [batch_0104_384x384.parquet] Processed 1995 rows
138
+ 2026-02-01 11:32:55,342 [INFO] [batch_0105_384x384.parquet] Processed 1993 rows
139
+ 2026-02-01 11:34:02,868 [INFO] [batch_0106_384x384.parquet] Processed 1988 rows
140
+ 2026-02-01 11:35:08,481 [INFO] [batch_0107_384x384.parquet] Processed 1996 rows
141
+ 2026-02-01 11:36:17,025 [INFO] [batch_0108_384x384.parquet] Processed 1992 rows
142
+ 2026-02-01 11:37:26,799 [INFO] [batch_0109_384x384.parquet] Processed 1993 rows
143
+ 2026-02-01 11:38:39,274 [INFO] [batch_0110_384x384.parquet] Processed 1996 rows
144
+ 2026-02-01 11:39:49,743 [INFO] [batch_0111_384x384.parquet] Processed 1991 rows
145
+ 2026-02-01 11:40:47,923 [INFO] [batch_0112_384x384.parquet] Processed 1994 rows
146
+ 2026-02-01 11:41:53,376 [INFO] [batch_0113_384x384.parquet] Processed 1996 rows
147
+ 2026-02-01 11:42:53,847 [INFO] [batch_0114_384x384.parquet] Processed 1997 rows
148
+ 2026-02-01 11:43:47,456 [INFO] [batch_0115_384x384.parquet] Processed 1997 rows
149
+ 2026-02-01 11:44:47,188 [INFO] [batch_0116_384x384.parquet] Processed 1992 rows
150
+ 2026-02-01 11:45:44,350 [INFO] [batch_0117_384x384.parquet] Processed 1992 rows
151
+ 2026-02-01 11:46:51,765 [INFO] [batch_0118_384x384.parquet] Processed 1993 rows
152
+ 2026-02-01 11:47:54,777 [INFO] [batch_0119_384x384.parquet] Processed 1998 rows
153
+ 2026-02-01 11:48:58,907 [INFO] [batch_0120_384x384.parquet] Processed 1995 rows
154
+ 2026-02-01 11:49:59,917 [INFO] [batch_0121_384x384.parquet] Processed 1995 rows
155
+ 2026-02-01 11:51:00,476 [INFO] [batch_0122_384x384.parquet] Processed 1992 rows
156
+ 2026-02-01 11:52:05,414 [INFO] [batch_0123_384x384.parquet] Processed 1994 rows
157
+ 2026-02-01 11:53:06,075 [INFO] [batch_0124_384x384.parquet] Processed 1995 rows
158
+ 2026-02-01 11:53:54,915 [INFO] [batch_0125_384x384.parquet] Processed 1505 rows
159
+ 2026-02-01 11:53:54,915 [INFO] Merging all results...
160
+ 2026-02-01 11:53:54,970 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
161
+ 2026-02-01 11:53:54,971 [INFO] ✓ All unwanted columns removed
162
+ 2026-02-01 11:53:54,971 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
163
+ 2026-02-01 11:53:54,971 [INFO] ✓ Column 'pixel_bbox' added
164
+ 2026-02-01 11:53:54,971 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet...
165
+ 2026-02-01 11:54:03,559 [INFO] ================================================================================
166
+ 2026-02-01 11:54:03,559 [INFO] Processing complete!
167
+ 2026-02-01 11:54:03,559 [INFO] Total rows: 248,719
168
+ 2026-02-01 11:54:03,560 [INFO] Embedding dimension: 1024
169
+ 2026-02-01 11:54:03,560 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet
170
+ 2026-02-01 11:54:03,560 [INFO] ================================================================================
logs/compute_embeddings_farslip.log ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-02-01 09:54:48,604 [INFO] ================================================================================
2
+ 2026-02-01 09:54:48,605 [INFO] Computing FARSLIP embeddings
3
+ 2026-02-01 09:54:48,605 [INFO] Timestamp: 2026-02-01T09:54:48.605134
4
+ 2026-02-01 09:54:48,605 [INFO] Device: cuda:4
5
+ 2026-02-01 09:54:48,605 [INFO] ================================================================================
6
+ 2026-02-01 09:54:48,606 [INFO] Batch size: 64
7
+ 2026-02-01 09:54:48,607 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/farslip/FarSLIP_crop_384x384.parquet
8
+ 2026-02-01 09:54:48,607 [INFO] Loading farslip model...
9
+ 2026-02-01 09:54:48,613 [INFO] Loaded ViT-B-16 model config.
10
+ 2026-02-01 09:54:50,536 [INFO] Loading pretrained ViT-B-16 weights (/data1/zyj/checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt).
11
+ 2026-02-01 09:54:51,666 [INFO] Missing keys: []
12
+ 2026-02-01 09:54:51,745 [INFO] FarSLIP model loaded on cuda:4
13
+ 2026-02-01 09:54:51,745 [INFO] Found 125 input files
14
+ 2026-02-01 09:55:38,785 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
15
+ 2026-02-01 09:56:18,239 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
16
+ 2026-02-01 09:57:17,259 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
17
+ 2026-02-01 09:58:08,339 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
18
+ 2026-02-01 09:59:00,302 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
19
+ 2026-02-01 10:00:15,416 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
20
+ 2026-02-01 10:01:22,601 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
21
+ 2026-02-01 10:02:25,131 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
22
+ 2026-02-01 10:03:31,735 [INFO] [batch_0009_384x384.parquet] Processed 1993 rows
23
+ 2026-02-01 10:04:47,342 [INFO] [batch_0010_384x384.parquet] Processed 1993 rows
24
+ 2026-02-01 10:05:54,617 [INFO] [batch_0011_384x384.parquet] Processed 1990 rows
25
+ 2026-02-01 10:06:58,372 [INFO] [batch_0012_384x384.parquet] Processed 1993 rows
26
+ 2026-02-01 10:08:16,301 [INFO] [batch_0013_384x384.parquet] Processed 1993 rows
27
+ 2026-02-01 10:09:11,722 [INFO] [batch_0014_384x384.parquet] Processed 1991 rows
28
+ 2026-02-01 10:10:23,603 [INFO] [batch_0015_384x384.parquet] Processed 1993 rows
29
+ 2026-02-01 10:11:38,047 [INFO] [batch_0016_384x384.parquet] Processed 1990 rows
30
+ 2026-02-01 10:12:22,943 [INFO] [batch_0017_384x384.parquet] Processed 1991 rows
31
+ 2026-02-01 10:13:41,095 [INFO] [batch_0018_384x384.parquet] Processed 1991 rows
32
+ 2026-02-01 10:14:47,596 [INFO] [batch_0019_384x384.parquet] Processed 1996 rows
33
+ 2026-02-01 10:15:40,983 [INFO] [batch_0020_384x384.parquet] Processed 1994 rows
34
+ 2026-02-01 10:16:52,878 [INFO] [batch_0021_384x384.parquet] Processed 1996 rows
35
+ 2026-02-01 10:17:43,460 [INFO] [batch_0022_384x384.parquet] Processed 1995 rows
36
+ 2026-02-01 10:18:41,479 [INFO] [batch_0023_384x384.parquet] Processed 1992 rows
37
+ 2026-02-01 10:19:40,728 [INFO] [batch_0024_384x384.parquet] Processed 1989 rows
38
+ 2026-02-01 10:20:25,503 [INFO] [batch_0025_384x384.parquet] Processed 1993 rows
39
+ 2026-02-01 10:21:27,428 [INFO] [batch_0026_384x384.parquet] Processed 1995 rows
40
+ 2026-02-01 10:22:23,776 [INFO] [batch_0027_384x384.parquet] Processed 1997 rows
41
+ 2026-02-01 10:23:16,992 [INFO] [batch_0028_384x384.parquet] Processed 1991 rows
42
+ 2026-02-01 10:24:14,634 [INFO] [batch_0029_384x384.parquet] Processed 1992 rows
43
+ 2026-02-01 10:24:55,464 [INFO] [batch_0030_384x384.parquet] Processed 1993 rows
44
+ 2026-02-01 10:25:56,600 [INFO] [batch_0031_384x384.parquet] Processed 1988 rows
45
+ 2026-02-01 10:26:40,392 [INFO] [batch_0032_384x384.parquet] Processed 1994 rows
46
+ 2026-02-01 10:27:49,696 [INFO] [batch_0033_384x384.parquet] Processed 1992 rows
47
+ 2026-02-01 10:28:49,831 [INFO] [batch_0034_384x384.parquet] Processed 1994 rows
48
+ 2026-02-01 10:29:42,378 [INFO] [batch_0035_384x384.parquet] Processed 1994 rows
49
+ 2026-02-01 10:30:48,969 [INFO] [batch_0036_384x384.parquet] Processed 1996 rows
50
+ 2026-02-01 10:32:01,922 [INFO] [batch_0037_384x384.parquet] Processed 1995 rows
51
+ 2026-02-01 10:32:47,057 [INFO] [batch_0038_384x384.parquet] Processed 1993 rows
52
+ 2026-02-01 10:34:01,196 [INFO] [batch_0039_384x384.parquet] Processed 1989 rows
53
+ 2026-02-01 10:35:19,501 [INFO] [batch_0040_384x384.parquet] Processed 1990 rows
54
+ 2026-02-01 10:36:09,997 [INFO] [batch_0041_384x384.parquet] Processed 1998 rows
55
+ 2026-02-01 10:37:25,589 [INFO] [batch_0042_384x384.parquet] Processed 1997 rows
56
+ 2026-02-01 10:38:42,876 [INFO] [batch_0043_384x384.parquet] Processed 1988 rows
57
+ 2026-02-01 10:39:31,979 [INFO] [batch_0044_384x384.parquet] Processed 1991 rows
58
+ 2026-02-01 10:40:43,745 [INFO] [batch_0045_384x384.parquet] Processed 1993 rows
59
+ 2026-02-01 10:41:59,576 [INFO] [batch_0046_384x384.parquet] Processed 1994 rows
60
+ 2026-02-01 10:42:53,620 [INFO] [batch_0047_384x384.parquet] Processed 1995 rows
61
+ 2026-02-01 10:44:25,584 [INFO] [batch_0048_384x384.parquet] Processed 1992 rows
62
+ 2026-02-01 10:46:13,258 [INFO] [batch_0049_384x384.parquet] Processed 1996 rows
63
+ 2026-02-01 10:47:13,109 [INFO] [batch_0050_384x384.parquet] Processed 1991 rows
64
+ 2026-02-01 10:48:13,385 [INFO] [batch_0051_384x384.parquet] Processed 1997 rows
65
+ 2026-02-01 10:49:48,140 [INFO] [batch_0052_384x384.parquet] Processed 1993 rows
66
+ 2026-02-01 10:51:22,710 [INFO] [batch_0053_384x384.parquet] Processed 1995 rows
67
+ 2026-02-01 10:52:23,823 [INFO] [batch_0054_384x384.parquet] Processed 1997 rows
68
+ 2026-02-01 10:53:48,669 [INFO] [batch_0055_384x384.parquet] Processed 1995 rows
69
+ 2026-02-01 10:55:03,785 [INFO] [batch_0056_384x384.parquet] Processed 1997 rows
70
+ 2026-02-01 10:55:56,653 [INFO] [batch_0057_384x384.parquet] Processed 1991 rows
71
+ 2026-02-01 10:56:50,364 [INFO] [batch_0058_384x384.parquet] Processed 1994 rows
72
+ 2026-02-01 10:57:33,268 [INFO] [batch_0059_384x384.parquet] Processed 1993 rows
73
+ 2026-02-01 10:58:36,103 [INFO] [batch_0060_384x384.parquet] Processed 1995 rows
74
+ 2026-02-01 10:59:43,156 [INFO] [batch_0061_384x384.parquet] Processed 1995 rows
75
+ 2026-02-01 11:00:45,280 [INFO] [batch_0062_384x384.parquet] Processed 1998 rows
76
+ 2026-02-01 11:02:03,960 [INFO] [batch_0063_384x384.parquet] Processed 1997 rows
77
+ 2026-02-01 11:03:01,993 [INFO] [batch_0064_384x384.parquet] Processed 1992 rows
78
+ 2026-02-01 11:04:18,812 [INFO] [batch_0065_384x384.parquet] Processed 1994 rows
79
+ 2026-02-01 11:05:34,954 [INFO] [batch_0066_384x384.parquet] Processed 1992 rows
80
+ 2026-02-01 11:06:26,502 [INFO] [batch_0067_384x384.parquet] Processed 1993 rows
81
+ 2026-02-01 11:07:42,754 [INFO] [batch_0068_384x384.parquet] Processed 1994 rows
82
+ 2026-02-01 11:09:01,751 [INFO] [batch_0069_384x384.parquet] Processed 1992 rows
83
+ 2026-02-01 11:09:49,394 [INFO] [batch_0070_384x384.parquet] Processed 1997 rows
84
+ 2026-02-01 11:11:06,518 [INFO] [batch_0071_384x384.parquet] Processed 1996 rows
85
+ 2026-02-01 11:12:22,688 [INFO] [batch_0072_384x384.parquet] Processed 1992 rows
86
+ 2026-02-01 11:13:14,831 [INFO] [batch_0073_384x384.parquet] Processed 1995 rows
87
+ 2026-02-01 11:14:14,879 [INFO] [batch_0074_384x384.parquet] Processed 1992 rows
88
+ 2026-02-01 11:14:58,098 [INFO] [batch_0075_384x384.parquet] Processed 1991 rows
89
+ 2026-02-01 11:15:43,764 [INFO] [batch_0076_384x384.parquet] Processed 1998 rows
90
+ 2026-02-01 11:16:53,710 [INFO] [batch_0077_384x384.parquet] Processed 1996 rows
91
+ 2026-02-01 11:17:51,040 [INFO] [batch_0078_384x384.parquet] Processed 1992 rows
92
+ 2026-02-01 11:18:57,871 [INFO] [batch_0079_384x384.parquet] Processed 1995 rows
93
+ 2026-02-01 11:20:06,930 [INFO] [batch_0080_384x384.parquet] Processed 1993 rows
94
+ 2026-02-01 11:20:51,630 [INFO] [batch_0081_384x384.parquet] Processed 1995 rows
95
+ 2026-02-01 11:21:43,270 [INFO] [batch_0082_384x384.parquet] Processed 1989 rows
96
+ 2026-02-01 11:22:29,228 [INFO] [batch_0083_384x384.parquet] Processed 1995 rows
97
+ 2026-02-01 11:23:23,236 [INFO] [batch_0084_384x384.parquet] Processed 1996 rows
98
+ 2026-02-01 11:24:32,532 [INFO] [batch_0085_384x384.parquet] Processed 1997 rows
99
+ 2026-02-01 11:25:20,336 [INFO] [batch_0086_384x384.parquet] Processed 1996 rows
100
+ 2026-02-01 11:26:33,616 [INFO] [batch_0087_384x384.parquet] Processed 1994 rows
101
+ 2026-02-01 11:27:24,449 [INFO] [batch_0088_384x384.parquet] Processed 1992 rows
102
+ 2026-02-01 11:28:20,047 [INFO] [batch_0089_384x384.parquet] Processed 1993 rows
103
+ 2026-02-01 11:29:43,109 [INFO] [batch_0090_384x384.parquet] Processed 1993 rows
104
+ 2026-02-01 11:30:41,652 [INFO] [batch_0091_384x384.parquet] Processed 1995 rows
105
+ 2026-02-01 11:31:43,751 [INFO] [batch_0092_384x384.parquet] Processed 1994 rows
106
+ 2026-02-01 11:33:10,661 [INFO] [batch_0093_384x384.parquet] Processed 1998 rows
107
+ 2026-02-01 11:34:12,721 [INFO] [batch_0094_384x384.parquet] Processed 1993 rows
108
+ 2026-02-01 11:35:09,887 [INFO] [batch_0095_384x384.parquet] Processed 1995 rows
109
+ 2026-02-01 11:36:36,141 [INFO] [batch_0096_384x384.parquet] Processed 1997 rows
110
+ 2026-02-01 11:37:41,740 [INFO] [batch_0097_384x384.parquet] Processed 1990 rows
111
+ 2026-02-01 11:38:40,066 [INFO] [batch_0098_384x384.parquet] Processed 1995 rows
112
+ 2026-02-01 11:39:45,765 [INFO] [batch_0099_384x384.parquet] Processed 1992 rows
113
+ 2026-02-01 11:40:40,739 [INFO] [batch_0100_384x384.parquet] Processed 1993 rows
114
+ 2026-02-01 11:41:41,583 [INFO] [batch_0101_384x384.parquet] Processed 1994 rows
115
+ 2026-02-01 11:42:47,504 [INFO] [batch_0102_384x384.parquet] Processed 1991 rows
116
+ 2026-02-01 11:43:31,148 [INFO] [batch_0103_384x384.parquet] Processed 1990 rows
117
+ 2026-02-01 11:44:38,070 [INFO] [batch_0104_384x384.parquet] Processed 1995 rows
118
+ 2026-02-01 11:45:48,089 [INFO] [batch_0105_384x384.parquet] Processed 1993 rows
119
+ 2026-02-01 11:46:47,156 [INFO] [batch_0106_384x384.parquet] Processed 1988 rows
120
+ 2026-02-01 11:48:06,340 [INFO] [batch_0107_384x384.parquet] Processed 1996 rows
121
+ 2026-02-01 11:49:08,016 [INFO] [batch_0108_384x384.parquet] Processed 1992 rows
122
+ 2026-02-01 11:50:27,665 [INFO] [batch_0109_384x384.parquet] Processed 1993 rows
123
+ 2026-02-01 11:51:38,073 [INFO] [batch_0110_384x384.parquet] Processed 1996 rows
124
+ 2026-02-01 11:52:26,956 [INFO] [batch_0111_384x384.parquet] Processed 1991 rows
125
+ 2026-02-01 11:53:44,395 [INFO] [batch_0112_384x384.parquet] Processed 1994 rows
126
+ 2026-02-01 11:54:23,803 [INFO] [batch_0113_384x384.parquet] Processed 1996 rows
127
+ 2026-02-01 11:55:07,867 [INFO] [batch_0114_384x384.parquet] Processed 1997 rows
128
+ 2026-02-01 11:55:54,834 [INFO] [batch_0115_384x384.parquet] Processed 1997 rows
129
+ 2026-02-01 11:56:36,849 [INFO] [batch_0116_384x384.parquet] Processed 1992 rows
130
+ 2026-02-01 11:57:20,506 [INFO] [batch_0117_384x384.parquet] Processed 1992 rows
131
+ 2026-02-01 11:57:58,985 [INFO] [batch_0118_384x384.parquet] Processed 1993 rows
132
+ 2026-02-01 11:58:38,965 [INFO] [batch_0119_384x384.parquet] Processed 1998 rows
133
+ 2026-02-01 11:59:16,459 [INFO] [batch_0120_384x384.parquet] Processed 1995 rows
134
+ 2026-02-01 11:59:55,497 [INFO] [batch_0121_384x384.parquet] Processed 1995 rows
135
+ 2026-02-01 12:00:33,857 [INFO] [batch_0122_384x384.parquet] Processed 1992 rows
136
+ 2026-02-01 12:01:15,871 [INFO] [batch_0123_384x384.parquet] Processed 1994 rows
137
+ 2026-02-01 12:01:53,537 [INFO] [batch_0124_384x384.parquet] Processed 1995 rows
138
+ 2026-02-01 12:02:22,334 [INFO] [batch_0125_384x384.parquet] Processed 1505 rows
139
+ 2026-02-01 12:02:22,334 [INFO] Merging all results...
140
+ 2026-02-01 12:02:22,384 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
141
+ 2026-02-01 12:02:22,384 [INFO] ✓ All unwanted columns removed
142
+ 2026-02-01 12:02:22,384 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
143
+ 2026-02-01 12:02:22,384 [INFO] ✓ Column 'pixel_bbox' added
144
+ 2026-02-01 12:02:22,384 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/farslip/FarSLIP_crop_384x384.parquet...
145
+ 2026-02-01 12:02:25,588 [INFO] ================================================================================
146
+ 2026-02-01 12:02:25,588 [INFO] Processing complete!
147
+ 2026-02-01 12:02:25,588 [INFO] Total rows: 248,719
148
+ 2026-02-01 12:02:25,589 [INFO] Embedding dimension: 512
149
+ 2026-02-01 12:02:25,589 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/farslip/FarSLIP_crop_384x384.parquet
150
+ 2026-02-01 12:02:25,589 [INFO] ================================================================================
logs/compute_embeddings_satclip.log ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-02-01 09:09:57,720 [INFO] ================================================================================
2
+ 2026-02-01 09:09:57,720 [INFO] Computing SATCLIP embeddings
3
+ 2026-02-01 09:09:57,720 [INFO] Timestamp: 2026-02-01T09:09:57.720447
4
+ 2026-02-01 09:09:57,720 [INFO] Device: cuda:1
5
+ 2026-02-01 09:09:57,720 [INFO] ================================================================================
6
+ 2026-02-01 09:09:57,721 [INFO] Batch size: 128
7
+ 2026-02-01 09:09:57,721 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
8
+ 2026-02-01 09:09:57,721 [INFO] Loading satclip model...
9
+ 2026-02-01 09:09:57,727 [INFO] SatCLIP-MS model loaded on cuda:1
10
+ 2026-02-01 09:09:57,728 [INFO] Found 1 input files
11
+ 2026-02-01 09:10:21,830 [WARNING] No valid embeddings for batch_0001_384x384.parquet
12
+ 2026-02-01 09:10:22,107 [ERROR] No data processed!
13
+ 2026-02-01 09:39:17,993 [INFO] ================================================================================
14
+ 2026-02-01 09:39:17,993 [INFO] Computing SATCLIP embeddings
15
+ 2026-02-01 09:39:17,993 [INFO] Timestamp: 2026-02-01T09:39:17.993775
16
+ 2026-02-01 09:39:17,993 [INFO] Device: cuda:1
17
+ 2026-02-01 09:39:17,993 [INFO] ================================================================================
18
+ 2026-02-01 09:39:17,994 [INFO] Batch size: 128
19
+ 2026-02-01 09:39:17,994 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
20
+ 2026-02-01 09:39:17,994 [INFO] Loading satclip model...
21
+ 2026-02-01 09:39:20,179 [INFO] SatCLIP-MS model loaded on cuda:1
22
+ 2026-02-01 09:39:20,180 [INFO] Found 1 input files
23
+ 2026-02-01 09:40:01,084 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
24
+ 2026-02-01 09:40:01,084 [INFO] Merging all results...
25
+ 2026-02-01 09:40:01,085 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
26
+ 2026-02-01 09:40:01,085 [INFO] ✓ All unwanted columns removed
27
+ 2026-02-01 09:40:01,085 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
28
+ 2026-02-01 09:40:01,085 [INFO] ✓ Column 'pixel_bbox' added
29
+ 2026-02-01 09:40:01,085 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet...
30
+ 2026-02-01 09:40:01,134 [INFO] ================================================================================
31
+ 2026-02-01 09:40:01,134 [INFO] Processing complete!
32
+ 2026-02-01 09:40:01,134 [INFO] Total rows: 1,996
33
+ 2026-02-01 09:40:01,134 [INFO] Embedding dimension: 256
34
+ 2026-02-01 09:40:01,134 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
35
+ 2026-02-01 09:40:01,134 [INFO] ================================================================================
36
+ 2026-02-01 09:43:19,666 [INFO] ================================================================================
37
+ 2026-02-01 09:43:19,666 [INFO] Computing SATCLIP embeddings
38
+ 2026-02-01 09:43:19,666 [INFO] Timestamp: 2026-02-01T09:43:19.666577
39
+ 2026-02-01 09:43:19,666 [INFO] Device: cuda:3
40
+ 2026-02-01 09:43:19,666 [INFO] ================================================================================
41
+ 2026-02-01 09:43:19,668 [INFO] Batch size: 128
42
+ 2026-02-01 09:43:19,668 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
43
+ 2026-02-01 09:43:19,668 [INFO] Loading satclip model...
44
+ 2026-02-01 09:43:21,344 [INFO] SatCLIP-MS model loaded on cuda:3
45
+ 2026-02-01 09:43:21,345 [INFO] Found 125 input files
46
+ 2026-02-01 09:44:03,000 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
47
+ 2026-02-01 09:44:46,041 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
48
+ 2026-02-01 09:45:27,652 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
49
+ 2026-02-01 09:46:15,446 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
50
+ 2026-02-01 09:47:09,769 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
51
+ 2026-02-01 09:47:59,773 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
52
+ 2026-02-01 09:48:51,057 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
53
+ 2026-02-01 09:49:34,202 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
54
+ 2026-02-01 09:50:25,944 [INFO] [batch_0009_384x384.parquet] Processed 1993 rows
55
+ 2026-02-01 09:51:09,586 [INFO] [batch_0010_384x384.parquet] Processed 1993 rows
56
+ 2026-02-01 09:51:56,545 [INFO] [batch_0011_384x384.parquet] Processed 1990 rows
57
+ 2026-02-01 09:52:44,526 [INFO] [batch_0012_384x384.parquet] Processed 1993 rows
58
+ 2026-02-01 09:53:32,729 [INFO] [batch_0013_384x384.parquet] Processed 1993 rows
59
+ 2026-02-01 09:54:14,312 [INFO] [batch_0014_384x384.parquet] Processed 1991 rows
60
+ 2026-02-01 09:55:05,975 [INFO] [batch_0015_384x384.parquet] Processed 1993 rows
61
+ 2026-02-01 09:55:57,268 [INFO] [batch_0016_384x384.parquet] Processed 1990 rows
62
+ 2026-02-01 09:57:00,591 [INFO] [batch_0017_384x384.parquet] Processed 1991 rows
63
+ 2026-02-01 09:57:48,464 [INFO] [batch_0018_384x384.parquet] Processed 1991 rows
64
+ 2026-02-01 09:58:52,420 [INFO] [batch_0019_384x384.parquet] Processed 1996 rows
65
+ 2026-02-01 10:00:04,202 [INFO] [batch_0020_384x384.parquet] Processed 1994 rows
66
+ 2026-02-01 10:01:10,309 [INFO] [batch_0021_384x384.parquet] Processed 1996 rows
67
+ 2026-02-01 10:02:15,265 [INFO] [batch_0022_384x384.parquet] Processed 1995 rows
68
+ 2026-02-01 10:03:31,554 [INFO] [batch_0023_384x384.parquet] Processed 1992 rows
69
+ 2026-02-01 10:04:40,240 [INFO] [batch_0024_384x384.parquet] Processed 1989 rows
70
+ 2026-02-01 10:05:55,812 [INFO] [batch_0025_384x384.parquet] Processed 1993 rows
71
+ 2026-02-01 10:07:00,366 [INFO] [batch_0026_384x384.parquet] Processed 1995 rows
72
+ 2026-02-01 10:08:10,532 [INFO] [batch_0027_384x384.parquet] Processed 1997 rows
73
+ 2026-02-01 10:09:11,505 [INFO] [batch_0028_384x384.parquet] Processed 1991 rows
74
+ 2026-02-01 10:10:21,951 [INFO] [batch_0029_384x384.parquet] Processed 1992 rows
75
+ 2026-02-01 10:11:30,988 [INFO] [batch_0030_384x384.parquet] Processed 1993 rows
76
+ 2026-02-01 10:12:26,034 [INFO] [batch_0031_384x384.parquet] Processed 1988 rows
77
+ 2026-02-01 10:13:36,732 [INFO] [batch_0032_384x384.parquet] Processed 1994 rows
78
+ 2026-02-01 10:14:36,787 [INFO] [batch_0033_384x384.parquet] Processed 1992 rows
79
+ 2026-02-01 10:15:36,921 [INFO] [batch_0034_384x384.parquet] Processed 1994 rows
80
+ 2026-02-01 10:16:38,623 [INFO] [batch_0035_384x384.parquet] Processed 1994 rows
81
+ 2026-02-01 10:17:27,583 [INFO] [batch_0036_384x384.parquet] Processed 1996 rows
82
+ 2026-02-01 10:18:29,976 [INFO] [batch_0037_384x384.parquet] Processed 1995 rows
83
+ 2026-02-01 10:19:26,843 [INFO] [batch_0038_384x384.parquet] Processed 1993 rows
84
+ 2026-02-01 10:20:14,532 [INFO] [batch_0039_384x384.parquet] Processed 1989 rows
85
+ 2026-02-01 10:21:13,694 [INFO] [batch_0040_384x384.parquet] Processed 1990 rows
86
+ 2026-02-01 10:22:05,858 [INFO] [batch_0041_384x384.parquet] Processed 1998 rows
87
+ 2026-02-01 10:23:04,226 [INFO] [batch_0042_384x384.parquet] Processed 1997 rows
88
+ 2026-02-01 10:23:56,641 [INFO] [batch_0043_384x384.parquet] Processed 1988 rows
89
+ 2026-02-01 10:24:38,594 [INFO] [batch_0044_384x384.parquet] Processed 1991 rows
90
+ 2026-02-01 10:25:42,517 [INFO] [batch_0045_384x384.parquet] Processed 1993 rows
91
+ 2026-02-01 10:26:23,732 [INFO] [batch_0046_384x384.parquet] Processed 1994 rows
92
+ 2026-02-01 10:27:39,298 [INFO] [batch_0047_384x384.parquet] Processed 1995 rows
93
+ 2026-02-01 10:28:34,546 [INFO] [batch_0048_384x384.parquet] Processed 1992 rows
94
+ 2026-02-01 10:29:35,568 [INFO] [batch_0049_384x384.parquet] Processed 1996 rows
95
+ 2026-02-01 10:30:38,004 [INFO] [batch_0050_384x384.parquet] Processed 1991 rows
96
+ 2026-02-01 10:31:50,544 [INFO] [batch_0051_384x384.parquet] Processed 1997 rows
97
+ 2026-02-01 10:32:38,165 [INFO] [batch_0052_384x384.parquet] Processed 1993 rows
98
+ 2026-02-01 10:33:54,330 [INFO] [batch_0053_384x384.parquet] Processed 1995 rows
99
+ 2026-02-01 10:35:11,070 [INFO] [batch_0054_384x384.parquet] Processed 1997 rows
100
+ 2026-02-01 10:36:06,495 [INFO] [batch_0055_384x384.parquet] Processed 1995 rows
101
+ 2026-02-01 10:37:26,449 [INFO] [batch_0056_384x384.parquet] Processed 1997 rows
102
+ 2026-02-01 10:38:40,433 [INFO] [batch_0057_384x384.parquet] Processed 1991 rows
103
+ 2026-02-01 10:39:36,229 [INFO] [batch_0058_384x384.parquet] Processed 1994 rows
104
+ 2026-02-01 10:40:50,558 [INFO] [batch_0059_384x384.parquet] Processed 1993 rows
105
+ 2026-02-01 10:42:00,100 [INFO] [batch_0060_384x384.parquet] Processed 1995 rows
106
+ 2026-02-01 10:42:53,440 [INFO] [batch_0061_384x384.parquet] Processed 1995 rows
107
+ 2026-02-01 10:44:21,706 [INFO] [batch_0062_384x384.parquet] Processed 1998 rows
108
+ 2026-02-01 10:45:56,656 [INFO] [batch_0063_384x384.parquet] Processed 1997 rows
109
+ 2026-02-01 10:46:53,942 [INFO] [batch_0064_384x384.parquet] Processed 1992 rows
110
+ 2026-02-01 10:47:47,760 [INFO] [batch_0065_384x384.parquet] Processed 1994 rows
111
+ 2026-02-01 10:48:37,571 [INFO] [batch_0066_384x384.parquet] Processed 1992 rows
112
+ 2026-02-01 10:50:00,819 [INFO] [batch_0067_384x384.parquet] Processed 1993 rows
113
+ 2026-02-01 10:51:30,799 [INFO] [batch_0068_384x384.parquet] Processed 1994 rows
114
+ 2026-02-01 10:52:28,413 [INFO] [batch_0069_384x384.parquet] Processed 1992 rows
115
+ 2026-02-01 10:53:50,597 [INFO] [batch_0070_384x384.parquet] Processed 1997 rows
116
+ 2026-02-01 10:55:01,173 [INFO] [batch_0071_384x384.parquet] Processed 1996 rows
117
+ 2026-02-01 10:56:03,395 [INFO] [batch_0072_384x384.parquet] Processed 1992 rows
118
+ 2026-02-01 10:57:10,601 [INFO] [batch_0073_384x384.parquet] Processed 1995 rows
119
+ 2026-02-01 10:58:22,789 [INFO] [batch_0074_384x384.parquet] Processed 1992 rows
120
+ 2026-02-01 10:59:39,697 [INFO] [batch_0075_384x384.parquet] Processed 1991 rows
121
+ 2026-02-01 11:00:48,962 [INFO] [batch_0076_384x384.parquet] Processed 1998 rows
122
+ 2026-02-01 11:01:59,729 [INFO] [batch_0077_384x384.parquet] Processed 1996 rows
123
+ 2026-02-01 11:03:01,575 [INFO] [batch_0078_384x384.parquet] Processed 1992 rows
124
+ 2026-02-01 11:04:15,721 [INFO] [batch_0079_384x384.parquet] Processed 1995 rows
125
+ 2026-02-01 11:05:26,147 [INFO] [batch_0080_384x384.parquet] Processed 1993 rows
126
+ 2026-02-01 11:06:21,742 [INFO] [batch_0081_384x384.parquet] Processed 1995 rows
127
+ 2026-02-01 11:07:34,071 [INFO] [batch_0082_384x384.parquet] Processed 1989 rows
128
+ 2026-02-01 11:08:51,443 [INFO] [batch_0083_384x384.parquet] Processed 1995 rows
129
+ 2026-02-01 11:09:45,289 [INFO] [batch_0084_384x384.parquet] Processed 1996 rows
130
+ 2026-02-01 11:10:59,507 [INFO] [batch_0085_384x384.parquet] Processed 1997 rows
131
+ 2026-02-01 11:12:12,671 [INFO] [batch_0086_384x384.parquet] Processed 1996 rows
132
+ 2026-02-01 11:13:16,945 [INFO] [batch_0087_384x384.parquet] Processed 1994 rows
133
+ 2026-02-01 11:14:26,324 [INFO] [batch_0088_384x384.parquet] Processed 1992 rows
134
+ 2026-02-01 11:15:25,871 [INFO] [batch_0089_384x384.parquet] Processed 1993 rows
135
+ 2026-02-01 11:16:43,653 [INFO] [batch_0090_384x384.parquet] Processed 1993 rows
136
+ 2026-02-01 11:17:52,205 [INFO] [batch_0091_384x384.parquet] Processed 1995 rows
137
+ 2026-02-01 11:19:02,073 [INFO] [batch_0092_384x384.parquet] Processed 1994 rows
138
+ 2026-02-01 11:20:14,843 [INFO] [batch_0093_384x384.parquet] Processed 1998 rows
139
+ 2026-02-01 11:21:09,193 [INFO] [batch_0094_384x384.parquet] Processed 1993 rows
140
+ 2026-02-01 11:22:03,303 [INFO] [batch_0095_384x384.parquet] Processed 1995 rows
141
+ 2026-02-01 11:23:12,708 [INFO] [batch_0096_384x384.parquet] Processed 1997 rows
142
+ 2026-02-01 11:24:18,831 [INFO] [batch_0097_384x384.parquet] Processed 1990 rows
143
+ 2026-02-01 11:25:07,701 [INFO] [batch_0098_384x384.parquet] Processed 1995 rows
144
+ 2026-02-01 11:26:18,306 [INFO] [batch_0099_384x384.parquet] Processed 1992 rows
145
+ 2026-02-01 11:27:02,698 [INFO] [batch_0100_384x384.parquet] Processed 1993 rows
146
+ 2026-02-01 11:28:08,644 [INFO] [batch_0101_384x384.parquet] Processed 1994 rows
147
+ 2026-02-01 11:29:33,678 [INFO] [batch_0102_384x384.parquet] Processed 1991 rows
148
+ 2026-02-01 11:30:25,760 [INFO] [batch_0103_384x384.parquet] Processed 1990 rows
149
+ 2026-02-01 11:31:38,365 [INFO] [batch_0104_384x384.parquet] Processed 1995 rows
150
+ 2026-02-01 11:33:06,206 [INFO] [batch_0105_384x384.parquet] Processed 1993 rows
151
+ 2026-02-01 11:33:59,497 [INFO] [batch_0106_384x384.parquet] Processed 1988 rows
152
+ 2026-02-01 11:35:04,565 [INFO] [batch_0107_384x384.parquet] Processed 1996 rows
153
+ 2026-02-01 11:36:30,898 [INFO] [batch_0108_384x384.parquet] Processed 1992 rows
154
+ 2026-02-01 11:37:34,766 [INFO] [batch_0109_384x384.parquet] Processed 1993 rows
155
+ 2026-02-01 11:38:36,780 [INFO] [batch_0110_384x384.parquet] Processed 1996 rows
156
+ 2026-02-01 11:39:53,826 [INFO] [batch_0111_384x384.parquet] Processed 1991 rows
157
+ 2026-02-01 11:40:48,014 [INFO] [batch_0112_384x384.parquet] Processed 1994 rows
158
+ 2026-02-01 11:41:49,113 [INFO] [batch_0113_384x384.parquet] Processed 1996 rows
159
+ 2026-02-01 11:42:56,188 [INFO] [batch_0114_384x384.parquet] Processed 1997 rows
160
+ 2026-02-01 11:43:43,288 [INFO] [batch_0115_384x384.parquet] Processed 1997 rows
161
+ 2026-02-01 11:44:48,748 [INFO] [batch_0116_384x384.parquet] Processed 1992 rows
162
+ 2026-02-01 11:45:54,394 [INFO] [batch_0117_384x384.parquet] Processed 1992 rows
163
+ 2026-02-01 11:46:53,275 [INFO] [batch_0118_384x384.parquet] Processed 1993 rows
164
+ 2026-02-01 11:48:08,611 [INFO] [batch_0119_384x384.parquet] Processed 1998 rows
165
+ 2026-02-01 11:49:07,195 [INFO] [batch_0120_384x384.parquet] Processed 1995 rows
166
+ 2026-02-01 11:50:22,347 [INFO] [batch_0121_384x384.parquet] Processed 1995 rows
167
+ 2026-02-01 11:51:26,391 [INFO] [batch_0122_384x384.parquet] Processed 1992 rows
168
+ 2026-02-01 11:52:22,734 [INFO] [batch_0123_384x384.parquet] Processed 1994 rows
169
+ 2026-02-01 11:53:34,357 [INFO] [batch_0124_384x384.parquet] Processed 1995 rows
170
+ 2026-02-01 11:54:05,024 [INFO] [batch_0125_384x384.parquet] Processed 1505 rows
171
+ 2026-02-01 11:54:05,024 [INFO] Merging all results...
172
+ 2026-02-01 11:54:05,057 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
173
+ 2026-02-01 11:54:05,058 [INFO] ✓ All unwanted columns removed
174
+ 2026-02-01 11:54:05,058 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
175
+ 2026-02-01 11:54:05,058 [INFO] ✓ Column 'pixel_bbox' added
176
+ 2026-02-01 11:54:05,058 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet...
177
+ 2026-02-01 11:54:06,861 [INFO] ================================================================================
178
+ 2026-02-01 11:54:06,861 [INFO] Processing complete!
179
+ 2026-02-01 11:54:06,861 [INFO] Total rows: 248,719
180
+ 2026-02-01 11:54:06,862 [INFO] Embedding dimension: 256
181
+ 2026-02-01 11:54:06,862 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
182
+ 2026-02-01 11:54:06,862 [INFO] ================================================================================
logs/compute_embeddings_siglip.log ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-02-01 09:43:14,001 [INFO] ================================================================================
2
+ 2026-02-01 09:43:14,002 [INFO] Computing SIGLIP embeddings
3
+ 2026-02-01 09:43:14,002 [INFO] Timestamp: 2026-02-01T09:43:14.002069
4
+ 2026-02-01 09:43:14,002 [INFO] Device: cuda:2
5
+ 2026-02-01 09:43:14,002 [INFO] ================================================================================
6
+ 2026-02-01 09:43:14,003 [INFO] Batch size: 64
7
+ 2026-02-01 09:43:14,003 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet
8
+ 2026-02-01 09:43:14,004 [INFO] Loading siglip model...
9
+ 2026-02-01 09:43:14,196 [INFO] Parsing model identifier. Schema: None, Identifier: ViT-SO400M-14-SigLIP-384
10
+ 2026-02-01 09:43:14,196 [INFO] Loaded built-in ViT-SO400M-14-SigLIP-384 model config.
11
+ 2026-02-01 09:43:14,197 [INFO] `pretrained` specifies file path: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
12
+ 2026-02-01 09:43:14,197 [INFO] Instantiating model architecture: CustomTextCLIP
13
+ 2026-02-01 09:43:22,955 [INFO] Loading full pretrained weights from: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
14
+ 2026-02-01 09:43:24,815 [INFO] Final image preprocessing configuration set: {'size': (384, 384), 'mode': 'RGB', 'mean': (0.48145466, 0.4578275, 0.40821073), 'std': (0.26862954, 0.26130258, 0.27577711), 'interpolation': 'bicubic', 'resize_mode': 'shortest', 'fill_color': 0}
15
+ 2026-02-01 09:43:24,815 [INFO] Model ViT-SO400M-14-SigLIP-384 creation process complete.
16
+ 2026-02-01 09:43:25,908 [INFO] SigLIP model loaded on cuda:2
17
+ 2026-02-01 09:43:25,909 [INFO] Found 125 input files
18
+ 2026-02-01 09:44:47,927 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
19
+ 2026-02-01 09:46:05,633 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
20
+ 2026-02-01 09:47:28,903 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
21
+ 2026-02-01 09:48:39,715 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
22
+ 2026-02-01 09:49:56,387 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
23
+ 2026-02-01 09:51:18,436 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
24
+ 2026-02-01 09:52:45,064 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
25
+ 2026-02-01 09:54:13,231 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
26
+ 2026-02-01 09:55:40,342 [INFO] ================================================================================
27
+ 2026-02-01 09:55:40,343 [INFO] Computing SIGLIP embeddings
28
+ 2026-02-01 09:55:40,343 [INFO] Timestamp: 2026-02-01T09:55:40.343045
29
+ 2026-02-01 09:55:40,343 [INFO] Device: cuda:2
30
+ 2026-02-01 09:55:40,343 [INFO] ================================================================================
31
+ 2026-02-01 09:55:40,344 [INFO] Batch size: 256
32
+ 2026-02-01 09:55:40,344 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet
33
+ 2026-02-01 09:55:40,344 [INFO] Loading siglip model...
34
+ 2026-02-01 09:55:40,494 [INFO] Parsing model identifier. Schema: None, Identifier: ViT-SO400M-14-SigLIP-384
35
+ 2026-02-01 09:55:40,494 [INFO] Loaded built-in ViT-SO400M-14-SigLIP-384 model config.
36
+ 2026-02-01 09:55:40,494 [INFO] `pretrained` specifies file path: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
37
+ 2026-02-01 09:55:40,494 [INFO] Instantiating model architecture: CustomTextCLIP
38
+ 2026-02-01 09:55:50,054 [INFO] Loading full pretrained weights from: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
39
+ 2026-02-01 09:55:52,457 [INFO] Final image preprocessing configuration set: {'size': (384, 384), 'mode': 'RGB', 'mean': (0.48145466, 0.4578275, 0.40821073), 'std': (0.26862954, 0.26130258, 0.27577711), 'interpolation': 'bicubic', 'resize_mode': 'shortest', 'fill_color': 0}
40
+ 2026-02-01 09:55:52,457 [INFO] Model ViT-SO400M-14-SigLIP-384 creation process complete.
41
+ 2026-02-01 09:55:53,533 [INFO] SigLIP model loaded on cuda:2
42
+ 2026-02-01 09:55:53,534 [INFO] Found 125 input files
43
+ 2026-02-01 09:57:15,361 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
44
+ 2026-02-01 09:58:38,916 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
45
+ 2026-02-01 10:00:13,289 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
46
+ 2026-02-01 10:01:38,351 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
47
+ 2026-02-01 10:03:13,561 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
48
+ 2026-02-01 10:04:55,295 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
49
+ 2026-02-01 10:06:42,957 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
50
+ 2026-02-01 10:08:27,547 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
51
+ 2026-02-01 10:10:15,515 [INFO] [batch_0009_384x384.parquet] Processed 1993 rows
52
+ 2026-02-01 10:11:54,632 [INFO] [batch_0010_384x384.parquet] Processed 1993 rows
53
+ 2026-02-01 10:13:42,862 [INFO] [batch_0011_384x384.parquet] Processed 1990 rows
54
+ 2026-02-01 10:15:23,412 [INFO] [batch_0012_384x384.parquet] Processed 1993 rows
55
+ 2026-02-01 10:16:55,431 [INFO] [batch_0013_384x384.parquet] Processed 1993 rows
56
+ 2026-02-01 10:18:30,326 [INFO] [batch_0014_384x384.parquet] Processed 1991 rows
57
+ 2026-02-01 10:19:54,738 [INFO] [batch_0015_384x384.parquet] Processed 1993 rows
58
+ 2026-02-01 10:21:25,001 [INFO] [batch_0016_384x384.parquet] Processed 1990 rows
59
+ 2026-02-01 10:23:00,423 [INFO] [batch_0017_384x384.parquet] Processed 1991 rows
60
+ 2026-02-01 10:24:21,837 [INFO] [batch_0018_384x384.parquet] Processed 1991 rows
61
+ 2026-02-01 10:26:00,517 [INFO] [batch_0019_384x384.parquet] Processed 1996 rows
62
+ 2026-02-01 10:27:39,553 [INFO] [batch_0020_384x384.parquet] Processed 1994 rows
63
+ 2026-02-01 10:29:02,772 [INFO] [batch_0021_384x384.parquet] Processed 1996 rows
64
+ 2026-02-01 10:30:43,286 [INFO] [batch_0022_384x384.parquet] Processed 1995 rows
65
+ 2026-02-01 10:32:18,498 [INFO] [batch_0023_384x384.parquet] Processed 1992 rows
66
+ 2026-02-01 10:33:59,552 [INFO] [batch_0024_384x384.parquet] Processed 1989 rows
67
+ 2026-02-01 10:35:36,652 [INFO] [batch_0025_384x384.parquet] Processed 1993 rows
68
+ 2026-02-01 10:37:22,505 [INFO] [batch_0026_384x384.parquet] Processed 1995 rows
69
+ 2026-02-01 10:39:04,911 [INFO] [batch_0027_384x384.parquet] Processed 1997 rows
70
+ 2026-02-01 10:40:47,184 [INFO] [batch_0028_384x384.parquet] Processed 1991 rows
71
+ 2026-02-01 10:42:27,627 [INFO] [batch_0029_384x384.parquet] Processed 1992 rows
72
+ 2026-02-01 10:43:40,600 [INFO] ================================================================================
73
+ 2026-02-01 10:43:40,600 [INFO] Computing SIGLIP embeddings
74
+ 2026-02-01 10:43:40,600 [INFO] Timestamp: 2026-02-01T10:43:40.600706
75
+ 2026-02-01 10:43:40,600 [INFO] Device: cuda:5
76
+ 2026-02-01 10:43:40,600 [INFO] ================================================================================
77
+ 2026-02-01 10:43:40,602 [INFO] Batch size: 64
78
+ 2026-02-01 10:43:40,602 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet
79
+ 2026-02-01 10:43:40,602 [INFO] Loading siglip model...
80
+ 2026-02-01 10:43:40,778 [INFO] Parsing model identifier. Schema: None, Identifier: ViT-SO400M-14-SigLIP-384
81
+ 2026-02-01 10:43:40,778 [INFO] Loaded built-in ViT-SO400M-14-SigLIP-384 model config.
82
+ 2026-02-01 10:43:40,778 [INFO] `pretrained` specifies file path: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
83
+ 2026-02-01 10:43:40,778 [INFO] Instantiating model architecture: CustomTextCLIP
84
+ 2026-02-01 10:43:59,641 [INFO] Loading full pretrained weights from: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
85
+ 2026-02-01 10:44:04,702 [INFO] Final image preprocessing configuration set: {'size': (384, 384), 'mode': 'RGB', 'mean': (0.48145466, 0.4578275, 0.40821073), 'std': (0.26862954, 0.26130258, 0.27577711), 'interpolation': 'bicubic', 'resize_mode': 'shortest', 'fill_color': 0}
86
+ 2026-02-01 10:44:04,702 [INFO] Model ViT-SO400M-14-SigLIP-384 creation process complete.
87
+ 2026-02-01 10:44:06,271 [INFO] SigLIP model loaded on cuda:5
88
+ 2026-02-01 10:44:06,272 [INFO] Found 125 input files
89
+ 2026-02-01 10:44:20,369 [INFO] [batch_0030_384x384.parquet] Processed 1993 rows
90
+ 2026-02-01 10:45:59,867 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
91
+ 2026-02-01 10:46:32,133 [INFO] [batch_0031_384x384.parquet] Processed 1988 rows
92
+ 2026-02-01 10:47:08,397 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
93
+ 2026-02-01 10:48:03,827 [INFO] [batch_0032_384x384.parquet] Processed 1994 rows
94
+ 2026-02-01 10:48:20,770 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
95
+ 2026-02-01 10:50:02,578 [INFO] [batch_0033_384x384.parquet] Processed 1992 rows
96
+ 2026-02-01 10:50:06,189 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
97
+ 2026-02-01 10:52:02,296 [INFO] [batch_0034_384x384.parquet] Processed 1994 rows
98
+ 2026-02-01 10:53:52,804 [INFO] [batch_0035_384x384.parquet] Processed 1994 rows
99
+ 2026-02-01 10:55:40,379 [INFO] [batch_0036_384x384.parquet] Processed 1996 rows
100
+ 2026-02-01 10:57:08,912 [INFO] [batch_0037_384x384.parquet] Processed 1995 rows
101
+ 2026-02-01 10:58:42,083 [INFO] [batch_0038_384x384.parquet] Processed 1993 rows
102
+ 2026-02-01 11:00:31,963 [INFO] [batch_0039_384x384.parquet] Processed 1989 rows
103
+ 2026-02-01 11:02:16,803 [INFO] [batch_0040_384x384.parquet] Processed 1990 rows
104
+ 2026-02-01 11:04:12,580 [INFO] [batch_0041_384x384.parquet] Processed 1998 rows
105
+ 2026-02-01 11:05:52,695 [INFO] [batch_0042_384x384.parquet] Processed 1997 rows
106
+ 2026-02-01 11:07:38,215 [INFO] [batch_0043_384x384.parquet] Processed 1988 rows
107
+ 2026-02-01 11:09:18,740 [INFO] [batch_0044_384x384.parquet] Processed 1991 rows
108
+ 2026-02-01 11:10:59,852 [INFO] [batch_0045_384x384.parquet] Processed 1993 rows
109
+ 2026-02-01 11:12:35,695 [INFO] [batch_0046_384x384.parquet] Processed 1994 rows
110
+ 2026-02-01 11:14:12,998 [INFO] [batch_0047_384x384.parquet] Processed 1995 rows
111
+ 2026-02-01 11:15:30,214 [INFO] [batch_0048_384x384.parquet] Processed 1992 rows
112
+ 2026-02-01 11:17:05,225 [INFO] [batch_0049_384x384.parquet] Processed 1996 rows
113
+ 2026-02-01 11:18:50,252 [INFO] [batch_0050_384x384.parquet] Processed 1991 rows
114
+ 2026-02-01 11:20:25,931 [INFO] [batch_0051_384x384.parquet] Processed 1997 rows
115
+ 2026-02-01 11:21:43,527 [INFO] [batch_0052_384x384.parquet] Processed 1993 rows
116
+ 2026-02-01 11:23:12,150 [INFO] [batch_0053_384x384.parquet] Processed 1995 rows
117
+ 2026-02-01 11:24:47,385 [INFO] [batch_0054_384x384.parquet] Processed 1997 rows
118
+ 2026-02-01 11:26:31,520 [INFO] [batch_0055_384x384.parquet] Processed 1995 rows
119
+ 2026-02-01 11:28:03,476 [INFO] [batch_0056_384x384.parquet] Processed 1997 rows
120
+ 2026-02-01 11:29:48,548 [INFO] [batch_0057_384x384.parquet] Processed 1991 rows
121
+ 2026-02-01 11:31:29,605 [INFO] [batch_0058_384x384.parquet] Processed 1994 rows
122
+ 2026-02-01 11:33:17,760 [INFO] [batch_0059_384x384.parquet] Processed 1993 rows
123
+ 2026-02-01 11:34:50,684 [INFO] [batch_0060_384x384.parquet] Processed 1995 rows
124
+ 2026-02-01 11:36:38,080 [INFO] [batch_0061_384x384.parquet] Processed 1995 rows
125
+ 2026-02-01 11:38:19,287 [INFO] [batch_0062_384x384.parquet] Processed 1998 rows
126
+ 2026-02-01 11:40:01,382 [INFO] [batch_0063_384x384.parquet] Processed 1997 rows
127
+ 2026-02-01 11:41:28,396 [INFO] [batch_0064_384x384.parquet] Processed 1992 rows
128
+ 2026-02-01 11:43:07,187 [INFO] [batch_0065_384x384.parquet] Processed 1994 rows
129
+ 2026-02-01 11:44:47,035 [INFO] [batch_0066_384x384.parquet] Processed 1992 rows
130
+ 2026-02-01 11:46:38,657 [INFO] [batch_0067_384x384.parquet] Processed 1993 rows
131
+ 2026-02-01 11:48:25,045 [INFO] [batch_0068_384x384.parquet] Processed 1994 rows
132
+ 2026-02-01 11:50:24,090 [INFO] [batch_0069_384x384.parquet] Processed 1992 rows
133
+ 2026-02-01 11:52:05,360 [INFO] [batch_0070_384x384.parquet] Processed 1997 rows
134
+ 2026-02-01 11:53:51,383 [INFO] [batch_0071_384x384.parquet] Processed 1996 rows
135
+ 2026-02-01 11:55:00,188 [INFO] [batch_0072_384x384.parquet] Processed 1992 rows
136
+ 2026-02-01 11:56:16,122 [INFO] [batch_0073_384x384.parquet] Processed 1995 rows
137
+ 2026-02-01 11:57:30,601 [INFO] [batch_0074_384x384.parquet] Processed 1992 rows
138
+ 2026-02-01 11:58:47,717 [INFO] [batch_0075_384x384.parquet] Processed 1991 rows
139
+ 2026-02-01 12:00:01,207 [INFO] [batch_0076_384x384.parquet] Processed 1998 rows
140
+ 2026-02-01 12:01:14,471 [INFO] [batch_0077_384x384.parquet] Processed 1996 rows
141
+ 2026-02-01 12:02:31,575 [INFO] [batch_0078_384x384.parquet] Processed 1992 rows
142
+ 2026-02-01 12:03:52,303 [INFO] [batch_0079_384x384.parquet] Processed 1995 rows
143
+ 2026-02-01 12:05:06,370 [INFO] [batch_0080_384x384.parquet] Processed 1993 rows
144
+ 2026-02-01 12:06:16,989 [INFO] [batch_0081_384x384.parquet] Processed 1995 rows
145
+ 2026-02-01 12:07:32,029 [INFO] [batch_0082_384x384.parquet] Processed 1989 rows
146
+ 2026-02-01 12:08:47,568 [INFO] [batch_0083_384x384.parquet] Processed 1995 rows
147
+ 2026-02-01 12:10:03,544 [INFO] [batch_0084_384x384.parquet] Processed 1996 rows
148
+ 2026-02-01 12:11:20,376 [INFO] [batch_0085_384x384.parquet] Processed 1997 rows
149
+ 2026-02-01 12:12:38,318 [INFO] [batch_0086_384x384.parquet] Processed 1996 rows
150
+ 2026-02-01 12:13:56,314 [INFO] [batch_0087_384x384.parquet] Processed 1994 rows
151
+ 2026-02-01 12:15:14,513 [INFO] [batch_0088_384x384.parquet] Processed 1992 rows
152
+ 2026-02-01 12:16:32,334 [INFO] [batch_0089_384x384.parquet] Processed 1993 rows
153
+ 2026-02-01 12:17:52,186 [INFO] [batch_0090_384x384.parquet] Processed 1993 rows
154
+ 2026-02-01 12:19:10,443 [INFO] [batch_0091_384x384.parquet] Processed 1995 rows
155
+ 2026-02-01 12:20:24,543 [INFO] [batch_0092_384x384.parquet] Processed 1994 rows
156
+ 2026-02-01 12:21:42,150 [INFO] [batch_0093_384x384.parquet] Processed 1998 rows
157
+ 2026-02-01 12:22:50,203 [INFO] [batch_0094_384x384.parquet] Processed 1993 rows
158
+ 2026-02-01 12:24:08,849 [INFO] [batch_0095_384x384.parquet] Processed 1995 rows
159
+ 2026-02-01 12:25:14,387 [INFO] [batch_0096_384x384.parquet] Processed 1997 rows
160
+ 2026-02-01 12:26:27,496 [INFO] [batch_0097_384x384.parquet] Processed 1990 rows
161
+ 2026-02-01 12:27:38,051 [INFO] [batch_0098_384x384.parquet] Processed 1995 rows
162
+ 2026-02-01 12:28:46,151 [INFO] [batch_0099_384x384.parquet] Processed 1992 rows
163
+ 2026-02-01 12:29:56,731 [INFO] [batch_0100_384x384.parquet] Processed 1993 rows
164
+ 2026-02-01 12:31:13,328 [INFO] [batch_0101_384x384.parquet] Processed 1994 rows
165
+ 2026-02-01 12:32:22,428 [INFO] [batch_0102_384x384.parquet] Processed 1991 rows
166
+ 2026-02-01 12:33:34,185 [INFO] [batch_0103_384x384.parquet] Processed 1990 rows
167
+ 2026-02-01 12:34:42,817 [INFO] [batch_0104_384x384.parquet] Processed 1995 rows
168
+ 2026-02-01 12:35:53,075 [INFO] [batch_0105_384x384.parquet] Processed 1993 rows
169
+ 2026-02-01 12:37:03,504 [INFO] [batch_0106_384x384.parquet] Processed 1988 rows
170
+ 2026-02-01 12:38:12,118 [INFO] [batch_0107_384x384.parquet] Processed 1996 rows
171
+ 2026-02-01 12:39:26,579 [INFO] [batch_0108_384x384.parquet] Processed 1992 rows
172
+ 2026-02-01 12:40:38,968 [INFO] [batch_0109_384x384.parquet] Processed 1993 rows
173
+ 2026-02-01 12:41:50,225 [INFO] [batch_0110_384x384.parquet] Processed 1996 rows
174
+ 2026-02-01 12:42:59,782 [INFO] [batch_0111_384x384.parquet] Processed 1991 rows
175
+ 2026-02-01 12:44:13,255 [INFO] [batch_0112_384x384.parquet] Processed 1994 rows
176
+ 2026-02-01 12:45:25,863 [INFO] [batch_0113_384x384.parquet] Processed 1996 rows
177
+ 2026-02-01 12:46:42,753 [INFO] [batch_0114_384x384.parquet] Processed 1997 rows
178
+ 2026-02-01 12:47:54,480 [INFO] [batch_0115_384x384.parquet] Processed 1997 rows
179
+ 2026-02-01 12:49:00,711 [INFO] [batch_0116_384x384.parquet] Processed 1992 rows
180
+ 2026-02-01 12:50:12,844 [INFO] [batch_0117_384x384.parquet] Processed 1992 rows
181
+ 2026-02-01 12:51:27,205 [INFO] [batch_0118_384x384.parquet] Processed 1993 rows
182
+ 2026-02-01 12:52:36,479 [INFO] [batch_0119_384x384.parquet] Processed 1998 rows
183
+ 2026-02-01 12:53:54,416 [INFO] [batch_0120_384x384.parquet] Processed 1995 rows
184
+ 2026-02-01 12:55:03,501 [INFO] [batch_0121_384x384.parquet] Processed 1995 rows
185
+ 2026-02-01 12:56:14,997 [INFO] [batch_0122_384x384.parquet] Processed 1992 rows
186
+ 2026-02-01 12:57:29,495 [INFO] [batch_0123_384x384.parquet] Processed 1994 rows
187
+ 2026-02-01 12:58:37,341 [INFO] [batch_0124_384x384.parquet] Processed 1995 rows
188
+ 2026-02-01 12:59:35,927 [INFO] [batch_0125_384x384.parquet] Processed 1505 rows
189
+ 2026-02-01 12:59:35,927 [INFO] Merging all results...
190
+ 2026-02-01 12:59:35,965 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
191
+ 2026-02-01 12:59:35,965 [INFO] ✓ All unwanted columns removed
192
+ 2026-02-01 12:59:35,965 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
193
+ 2026-02-01 12:59:35,965 [INFO] ✓ Column 'pixel_bbox' added
194
+ 2026-02-01 12:59:35,965 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet...
195
+ 2026-02-01 12:59:44,647 [INFO] ================================================================================
196
+ 2026-02-01 12:59:44,648 [INFO] Processing complete!
197
+ 2026-02-01 12:59:44,648 [INFO] Total rows: 248,719
198
+ 2026-02-01 12:59:44,648 [INFO] Embedding dimension: 1152
199
+ 2026-02-01 12:59:44,648 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet
200
+ 2026-02-01 12:59:44,648 [INFO] ================================================================================
models/FarSLIP/.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **/logs/
2
+ **/wandb/
3
+ models/
4
+ features/
5
+ results/
6
+ src/open_clip_train/config.py
7
+ src/open_clip_train/output_samples/
8
+ **/results_retrieval/
9
+ **/results_classification/
10
+ checkpoints/
11
+
12
+ tests/data/
13
+ *.pt
14
+
15
+ # Byte-compiled / optimized / DLL files
16
+ __pycache__/
17
+ *.py[cod]
18
+ *$py.class
19
+
20
+ # C extensions
21
+ *.so
22
+
23
+ # Distribution / packaging
24
+ .Python
25
+ build/
26
+ develop-eggs/
27
+ dist/
28
+ downloads/
29
+ eggs/
30
+ .eggs/
31
+ lib/
32
+ lib64/
33
+ parts/
34
+ sdist/
35
+ var/
36
+ wheels/
37
+ pip-wheel-metadata/
38
+ share/python-wheels/
39
+ *.egg-info/
40
+ .installed.cfg
41
+ *.egg
42
+ MANIFEST
43
+
44
+ # PyInstaller
45
+ # Usually these files are written by a python script from a template
46
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
47
+ *.manifest
48
+ *.spec
49
+
50
+ # Installer logs
51
+ pip-log.txt
52
+ pip-delete-this-directory.txt
53
+
54
+ # Unit test / coverage reports
55
+ htmlcov/
56
+ .tox/
57
+ .nox/
58
+ .coverage
59
+ .coverage.*
60
+ .cache
61
+ nosetests.xml
62
+ coverage.xml
63
+ *.cover
64
+ *.py,cover
65
+ .hypothesis/
66
+ .pytest_cache/
67
+
68
+ # Translations
69
+ *.mo
70
+ *.pot
71
+
72
+ # Django stuff:
73
+ *.log
74
+ local_settings.py
75
+ db.sqlite3
76
+ db.sqlite3-journal
77
+
78
+ # Flask stuff:
79
+ instance/
80
+ .webassets-cache
81
+
82
+ # Scrapy stuff:
83
+ .scrapy
84
+
85
+ # Sphinx documentation
86
+ docs/_build/
87
+
88
+ # PyBuilder
89
+ target/
90
+
91
+ # Jupyter Notebook
92
+ .ipynb_checkpoints
93
+
94
+ # IPython
95
+ profile_default/
96
+ ipython_config.py
97
+
98
+ # pyenv
99
+ .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
109
+ __pypackages__/
110
+
111
+ # Celery stuff
112
+ celerybeat-schedule
113
+ celerybeat.pid
114
+
115
+ # SageMath parsed files
116
+ *.sage.py
117
+
118
+ # Environments
119
+ .env
120
+ .venv
121
+ env/
122
+ venv/
123
+ ENV/
124
+ env.bak/
125
+ venv.bak/
126
+
127
+ # Spyder project settings
128
+ .spyderproject
129
+ .spyproject
130
+
131
+ # Rope project settings
132
+ .ropeproject
133
+
134
+ # mkdocs documentation
135
+ /site
136
+
137
+ # mypy
138
+ .mypy_cache/
139
+ .dmypy.json
140
+ dmypy.json
141
+
142
+ # Pyre type checker
143
+ .pyre/
144
+ sync.sh
145
+ gpu1sync.sh
146
+ .idea
147
+ *.pdf
148
+ **/._*
149
+ **/*DS_*
150
+ **.jsonl
151
+ src/sbatch
152
+ src/misc
153
+ .vscode
154
+ src/debug
155
+ core.*
156
+
157
+ *.out
158
+
159
+ # Allow
160
+ !src/evaluation/misc/results_dbs/*
models/FarSLIP/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 LHRS
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
models/FarSLIP/README.md ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center"> FarSLIP: Discovering Effective CLIP Adaptation for Fine-Grained Remote Sensing Understanding </h1>
2
+
3
+ <p align="center">
4
+ <a href="https://huggingface.co/datasets/ZhenShiL/MGRS-200k">
5
+ <img alt="Hugging Face Dataset" src="https://img.shields.io/badge/🤗%20Hugging%20Face-Dataset-blue">
6
+ </a>
7
+ <a href="https://huggingface.co/ZhenShiL/FarSLIP">
8
+ <img alt="Hugging Face Model" src="https://img.shields.io/badge/🤗%20Hugging%20Face-Model-yellow">
9
+ </a>
10
+ <a href="https://arxiv.org/abs/2511.14901">
11
+ <img alt="arXiv" src="https://img.shields.io/badge/arXiv-2511.14901-b31b1b">
12
+ </a>
13
+ </p>
14
+
15
+
16
+ ## Introduction
17
+ We introduce FarSLIP, a vision-language foundation model for remote sensing (RS) that achieves fine-grained vision-language alignment. FarSLIP demonstrates state-of-the-art performance on both fine-grained and image-level tasks, including open-vocabulary semantic segmentation, zero-shot classification, and image-text retrieval.
18
+ We also construct MGRS-200k, the first multi-granularity image-text dataset for RS. Each image is annotated with both short and long global-level captions, along with multiple object-category pairs.
19
+
20
+ <figure>
21
+ <div align="center">
22
+ <img src=assets/model.png width="60%">
23
+ </div>
24
+ </figure>
25
+
26
+
27
+ ## Table of Contents
28
+ - [Introduction](#Introduction)
29
+ - [Preparation](#Preparation)
30
+ - [Installation](#Installation)
31
+ - [Checkpoints](#Checkpoints)
32
+ - [Dataset](#Dataset)
33
+ - [Training](#Training)
34
+ - [Testing](#Testing)
35
+ - [Open-vocabulary semantic segmentation](#open-vocabulary-semantic-segmentation)
36
+ - [Zero-shot scene classification](#zero-shot-scene-classification)
37
+ - [Zero-shot image-text retrieval](#zero-shot-image-text-retrieval)
38
+ - [Acknowledgement](#Acknowledgement)
39
+ - [Citing](#Citing)
40
+
41
+
42
+
43
+
44
+
45
+ ## Preparation
46
+
47
+ ### Installation
48
+
49
+ 1. Clone this repository.
50
+
51
+ ~~~shell
52
+ git clone git@github.com:NJU-LHRS/FarSLIP.git
53
+ cd FarSLIP
54
+ ~~~
55
+
56
+ 2. Create a new virtual environment.
57
+
58
+ ~~~shell
59
+ conda create -n farslip python=3.10
60
+ conda activate farslip
61
+ ~~~
62
+
63
+ 3. Install dependencies.
64
+
65
+ ~~~shell
66
+ pip install -r requirements.txt
67
+ ~~~
68
+
69
+ ### Checkpoints
70
+ You can download all our checkpoints from [Huggingface](https://huggingface.co/ZhenShiL/FarSLIP), or selectively download them through the links below.
71
+
72
+ | Model name | ViT-arch. | Test encoder | OVSS mIoU (%) | ZSC top-1 acc. (%) | Download |
73
+ |-------------|-----------|--------------|----------------|--------------------|----------------|
74
+ | FarSLIP-s1 | ViT-B-32 | Vanilla | 29.87 | 58.64 | [FarSLIP1_ViT-B-32](https://huggingface.co/ZhenShiL/FarSLIP/resolve/main/FarSLIP1_ViT-B-32.pt?download=true) |
75
+ | FarSLIP-s1 | ViT-B-16 | LongCLIP | 35.44 | 61.89 | [FarSLIP1_ViT-B-16](https://huggingface.co/ZhenShiL/FarSLIP/resolve/main/FarSLIP1_ViT-B-16.pt?download=true) |
76
+ | FarSLIP-s2 | ViT-B-32 | Vanilla | 30.49 | 60.12 | [FarSLIP2_ViT-B-32](https://huggingface.co/ZhenShiL/FarSLIP/resolve/main/FarSLIP2_ViT-B-32.pt?download=true) |
77
+ | FarSLIP-s2 | ViT-B-16 | LongCLIP | 35.41 | 62.24 | [FarSLIP2_ViT-B-16](https://huggingface.co/ZhenShiL/FarSLIP/resolve/main/FarSLIP2_ViT-B-16.pt?download=true) |
78
+
79
+
80
+ ### Dataset
81
+ FarSLIP is trained in two stages.
82
+ + In the first stage, we use the [RS5M](https://github.com/om-ai-lab/RS5M) dataset. A quick portal to the RS5M dataset: [link](https://huggingface.co/datasets/omlab/RS5M).
83
+ + In the second stage, we use the proposed MGRS-200k dataset, which is available on [Huggingface](https://huggingface.co/datasets/ZhenShiL/MGRS-200k).
84
+
85
+ [//]: # (<figure>)
86
+
87
+ [//]: # (<div align="center">)
88
+
89
+ [//]: # (<img src=assets/dataset.png width="80%">)
90
+
91
+ [//]: # (</div>)
92
+
93
+ [//]: # (<figcaption align="center"><em>Examples from MGRS-200k</em></figcaption>)
94
+
95
+ [//]: # (</figure>)
96
+
97
+ <p align="center">
98
+ <img src="assets/dataset.png" width="100%">
99
+ <br>
100
+ <em>Examples from MGRS-200k</em>
101
+ </p>
102
+
103
+ ## Training
104
+
105
+ + Validation data preparation
106
+ + Replace --root-val-img-dir and --val-data in [config.py](./open_clip_train/config.py) with the paths to your [SkyScript](https://github.com/wangzhecheng/SkyScript?tab=readme-ov-file#download) validation dataset ('SkyScript_val_5K_filtered_by_CLIP_openai').
107
+ + Stage1
108
+ ~~~shell
109
+ torchrun --nproc_per_node=4 -m open_clip_train.main \
110
+ --train-dataset-name RS5M \
111
+ --train-data '/your/path/to/rs5m/{pub11,rs3}-train-{0000..0031}.tar' \
112
+ --train-dataset-type webdataset \
113
+ --train-num-samples 5070186 \
114
+ --method farslip1 \
115
+ --use-imagecrop-aug \
116
+ --local-method randomcrops \
117
+ --warmup 1000 \
118
+ --batch-size 40 \
119
+ --lr 1e-6 \
120
+ --wd 1.0 \
121
+ --epochs 1 \
122
+ --model ViT-B-16 \
123
+ --loss-type global_itc distill \
124
+ --distill-align roi2pooled
125
+ ~~~
126
+
127
+ + Stage2
128
+ ~~~shell
129
+ torchrun --nproc_per_node=4 -m open_clip_train.main \
130
+ --train-dataset-name MGRS \
131
+ --root-train-img-dir '/your/path/to/mgrs/global_imgs/' \
132
+ --train-data '/your/path/to/mgrs/text_info.json' \
133
+ --train-dataset-type json \
134
+ --method farslip2 \
135
+ --warmup 250 \
136
+ --batch-size 40 \
137
+ --lr 4e-9 \
138
+ --wd 1.0 \
139
+ --epochs 10 \
140
+ --model ViT-B-16 \
141
+ --loss-type global_itc local_itc \
142
+ --local-itc-align cls
143
+ ~~~
144
+
145
+ ## Testing
146
+ ### Open-vocabulary semantic segmentation
147
+ + Please check out [FarSLIP-OVSS](https://github.com/NJU-LHRS/FarSLIP-OVSS) for evaluation of open-vocabulary semantic segmentation in RS images.
148
+
149
+ <p align="center">
150
+ <img src="assets/ovss.png" width="100%">
151
+ <br>
152
+ <em>
153
+ OVSS accuracies across RS benchmarks (mIoU, %). G denotes general-domain models, and RS refers to RS-specific models.
154
+ f. indicates models specifically designed with fine-grained optimization. All models use an input image size of 224, except TIPS (448)
155
+ </em>
156
+ </p>
157
+
158
+
159
+
160
+ ### Zero-shot scene classification
161
+ + Please refer to [SkyScript](https://github.com/wangzhecheng/SkyScript?tab=readme-ov-file#download-benchmark-datasets) for scene classification dataset preparation, including 'SkyScript_cls', 'aid', 'eurosat', 'fmow', 'millionaid', 'patternnet', 'rsicb', 'nwpu'.
162
+ + Replace the BENCHMARK_DATASET_ROOT_DIR in [tests/test_scene_classification.py](./tests/test_scene_classification.py) to your own path.
163
+
164
+ + Run testing:
165
+ + FarSLIP-s1
166
+ ```
167
+ python -m tests.test_scene_classification --model-arch $VIT --model-name FarSLIP1 --force-quick-gelu --pretrained checkpoints/FarSLIP1_$VIT.pt
168
+ ```
169
+ <!-- + FarSLIP-s2 with vanilla CLIP text encoder
170
+ ```
171
+ python -m tests.test_scene_classification --model-arch $VIT --model-name FarSLIP2_VC --force-quick-gelu --pretrained checkpoints/FarSLIP2_VC_$VIT.pt
172
+ ``` -->
173
+ + FarSLIP-s2 with LongCLIP text encoder (supporting long text)
174
+ ```
175
+ python -m tests.test_scene_classification --model-arch $VIT --model-name FarSLIP2 --force-quick-gelu --pretrained checkpoints/FarSLIP2_$VIT.pt --use-long-clip
176
+ ```
177
+ - `$VIT` options: `ViT-B-16`, `ViT-B-32`
178
+
179
+ <figure>
180
+ <div align="center">
181
+ <img src=assets/classification.png width="100%">
182
+ </div>
183
+ <figcaption align="center">
184
+ <em>Comparison of zero-shot classification accuracies (Top-1 acc., %) of different RS-specific CLIP variants across multiple benchmarks.</em>
185
+ </figcaption>
186
+ </figure>
187
+
188
+
189
+ ### Zero-shot image-text retrieval
190
+ + Please refer to [SkyScript](https://github.com/wangzhecheng/SkyScript?tab=readme-ov-file#download-benchmark-datasets) for image-text retrieval dataset preparation, including 'RSICD', 'RSITMD', 'ucmcaptions', and ['SkyScript-retrieval'](https://github.com/wangzhecheng/SkyScript?tab=readme-ov-file#download) ('SkyScript_test_30K_filtered_by_CLIP_openai.csv').
191
+ + Replace the DATA_CSV_PATH_DICT, SKYSCRIPT_IMAGE_DIR, RETRIEVAL_IMAGE_DIR in [tests/test_retrieval.py](./tests/test_retrieval.py) to your own path.
192
+
193
+ + Run testing:
194
+ + FarSLIP-s1
195
+ ```
196
+ python -m tests.test_retrieval --model-arch $VIT --model-name FarSLIP1 --force-quick-gelu --pretrained checkpoints/FarSLIP1_$VIT.pt
197
+ ```
198
+ <!-- + FarSLIP-s2 with vanilla CLIP text encoder
199
+ ```
200
+ python -m tests.test_retrieval --model-arch $VIT --model-name FarSLIP2_VC --force-quick-gelu --pretrained checkpoints/FarSLIP2_VC_$VIT.pt
201
+ ``` -->
202
+ + FarSLIP-s2 with LongCLIP text encoder (supporting long text)
203
+ ```
204
+ python -m tests.test_retrieval --model-arch $VIT --model-name FarSLIP2 --force-quick-gelu --pretrained checkpoints/FarSLIP2_$VIT.pt --use-long-clip
205
+ ```
206
+ - `$VIT` options: `ViT-B-16`, `ViT-B-32`
207
+
208
+
209
+ <div align="center">
210
+ <img src=assets/retrieval.png width="50%">
211
+ </div>
212
+ <figcaption align="center">
213
+ <em>Comparison of cross-modal retrieval accuracies (%) of different RS-specific CLIP variants across multiple benchmarks. *
214
+ indicates models trained with in-hold supervision.</em>
215
+ </figcaption>
216
+ </figure>
217
+
218
+
219
+
220
+
221
+ ## Acknowledgement
222
+
223
+ + We are grateful to the following repositories for their wonderful work: [Open-CLIP](https://github.com/mlfoundations/open_clip), [CLIPSelf](https://github.com/wusize/CLIPSelf), [FineCLIP](https://github.com/Timsty1/FineCLIP), [Long-CLIP](https://github.com/beichenzbc/Long-CLIP), [SkyScript](https://github.com/wangzhecheng/SkyScript), [SegEarth](https://github.com/likyoo/SegEarth-OV).
224
+
225
+
226
+ ## Citing
227
+
228
+ + If you find our work useful, please give us a 🌟 on GitHub and consider citing our paper:
229
+
230
+ ~~~tex
231
+ @article{li2025farslip,
232
+ title={FarSLIP: Discovering Effective CLIP Adaptation for Fine-Grained Remote Sensing Understanding},
233
+ author={Zhenshi Li and Weikang Yu and Dilxat Muhtar and Xueliang Zhang and Pengfeng Xiao and Pedram Ghamisi and Xiao Xiang Zhu},
234
+ journal={arXiv preprint arXiv:2511.14901},
235
+ year={2025}
236
+ }
237
+ ~~~
models/FarSLIP/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .open_clip import *
models/FarSLIP/open_clip/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .version import __version__
2
+
3
+ from .coca_model import CoCa
4
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
5
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
6
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
7
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
8
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
9
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
10
+ get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
11
+ from .openai import load_openai_model, list_openai_models
12
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
13
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
14
+ from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
15
+ from .tokenizer import SimpleTokenizer, tokenize, decode
16
+ from .transform import image_transform, AugmentationCfg
17
+ from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
18
+ from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
models/FarSLIP/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
models/FarSLIP/open_clip/coca_model.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ from .transformer import (
10
+ LayerNormFp32,
11
+ LayerNorm,
12
+ QuickGELU,
13
+ MultimodalTransformer,
14
+ )
15
+ from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
16
+
17
+ try:
18
+ from transformers import (
19
+ BeamSearchScorer,
20
+ LogitsProcessorList,
21
+ TopPLogitsWarper,
22
+ TopKLogitsWarper,
23
+ RepetitionPenaltyLogitsProcessor,
24
+ MinLengthLogitsProcessor,
25
+ MaxLengthCriteria,
26
+ StopStringCriteria,
27
+ EosTokenCriteria,
28
+ StoppingCriteriaList
29
+ )
30
+
31
+ GENERATION_TYPES = {
32
+ "top_k": TopKLogitsWarper,
33
+ "top_p": TopPLogitsWarper,
34
+ "beam_search": "beam_search"
35
+ }
36
+ _has_transformers = True
37
+ except ImportError as e:
38
+ GENERATION_TYPES = {
39
+ "top_k": None,
40
+ "top_p": None,
41
+ "beam_search": "beam_search"
42
+ }
43
+ _has_transformers = False
44
+
45
+
46
+ @dataclass
47
+ class MultimodalCfg(CLIPTextCfg):
48
+ mlp_ratio: int = 4
49
+ dim_head: int = 64
50
+ heads: int = 8
51
+ n_queries: int = 256
52
+ attn_pooler_heads: int = 8
53
+
54
+
55
+ def _build_text_decoder_tower(
56
+ embed_dim,
57
+ multimodal_cfg,
58
+ quick_gelu: bool = False,
59
+ cast_dtype: Optional[torch.dtype] = None,
60
+ ):
61
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
62
+ act_layer = QuickGELU if quick_gelu else nn.GELU
63
+ norm_layer = (
64
+ LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
65
+ )
66
+
67
+ decoder = MultimodalTransformer(
68
+ context_length=multimodal_cfg.context_length,
69
+ width=multimodal_cfg.width,
70
+ heads=multimodal_cfg.heads,
71
+ layers=multimodal_cfg.layers,
72
+ ls_init_value=multimodal_cfg.ls_init_value,
73
+ output_dim=embed_dim,
74
+ act_layer=act_layer,
75
+ norm_layer=norm_layer,
76
+ )
77
+
78
+ return decoder
79
+
80
+
81
+ def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
82
+ if not isinstance(token_id, torch.Tensor):
83
+ if isinstance(token_id, int):
84
+ token_id = [token_id]
85
+ token_id = torch.tensor(token_id, device=device)
86
+ return token_id
87
+
88
+
89
+ class CoCa(nn.Module):
90
+ def __init__(
91
+ self,
92
+ embed_dim,
93
+ multimodal_cfg: MultimodalCfg,
94
+ text_cfg: CLIPTextCfg,
95
+ vision_cfg: CLIPVisionCfg,
96
+ quick_gelu: bool = False,
97
+ init_logit_scale: float = np.log(1 / 0.07),
98
+ init_logit_bias: Optional[float] = None,
99
+ nonscalar_logit_scale: bool = False,
100
+ cast_dtype: Optional[torch.dtype] = None,
101
+ pad_id: int = 0,
102
+ ):
103
+ super().__init__()
104
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
105
+ text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
106
+ vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
107
+
108
+ self.text = _build_text_tower(
109
+ embed_dim=embed_dim,
110
+ text_cfg=text_cfg,
111
+ quick_gelu=quick_gelu,
112
+ cast_dtype=cast_dtype,
113
+ )
114
+
115
+ vocab_size = (
116
+ text_cfg.vocab_size # for hf models
117
+ if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
118
+ else text_cfg.vocab_size
119
+ )
120
+
121
+ self.visual = _build_vision_tower(
122
+ embed_dim=embed_dim,
123
+ vision_cfg=vision_cfg,
124
+ quick_gelu=quick_gelu,
125
+ cast_dtype=cast_dtype,
126
+ )
127
+
128
+ self.text_decoder = _build_text_decoder_tower(
129
+ vocab_size,
130
+ multimodal_cfg=multimodal_cfg,
131
+ quick_gelu=quick_gelu,
132
+ cast_dtype=cast_dtype,
133
+ )
134
+
135
+ lshape = [1] if nonscalar_logit_scale else []
136
+ self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
137
+ if init_logit_bias is not None:
138
+ self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
139
+ else:
140
+ self.logit_bias = None
141
+ self.pad_id = pad_id
142
+
143
+ self.context_length = multimodal_cfg.context_length
144
+
145
+ @torch.jit.ignore
146
+ def set_grad_checkpointing(self, enable: bool = True):
147
+ self.visual.set_grad_checkpointing(enable)
148
+ self.text.set_grad_checkpointing(enable)
149
+ self.text_decoder.set_grad_checkpointing(enable)
150
+
151
+ def _encode_image(self, images, normalize: bool = True):
152
+ image_latent, tokens_embs = self.visual(images)
153
+ image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
154
+ return image_latent, tokens_embs
155
+
156
+ def _encode_text(self, text, normalize: bool = True):
157
+ text_latent, token_emb = self.text(text)
158
+ text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
159
+ return text_latent, token_emb
160
+
161
+ def encode_image(self, images, normalize: bool = True):
162
+ image_latent, _ = self._encode_image(images, normalize=normalize)
163
+ return image_latent
164
+
165
+ def encode_text(self, text, normalize: bool = True):
166
+ text_latent, _ = self._encode_text(text, normalize=normalize)
167
+ return text_latent
168
+
169
+ def forward_intermediates(
170
+ self,
171
+ image: Optional[torch.Tensor] = None,
172
+ text: Optional[torch.Tensor] = None,
173
+ image_indices: Optional[Union[int, List[int]]] = None,
174
+ text_indices: Optional[Union[int, List[int]]] = None,
175
+ stop_early: bool = False,
176
+ normalize: bool = True,
177
+ normalize_intermediates: bool = False,
178
+ intermediates_only: bool = False,
179
+ image_output_fmt: str = 'NCHW',
180
+ image_output_extra_tokens: bool = False,
181
+ text_output_fmt: str = 'NLC',
182
+ text_output_extra_tokens: bool = False,
183
+ output_logits: bool = False,
184
+ output_logit_scale_bias: bool = False,
185
+ ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
186
+ """ Forward features that returns intermediates.
187
+
188
+ Args:
189
+ image: Input image tensor
190
+ text: Input text tensor
191
+ image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
192
+ text_indices: Take last n blocks if int, all if None, select matching indices if sequence
193
+ stop_early: Stop iterating over blocks when last desired intermediate hit
194
+ normalize: L2 Normalize final image and text features (if present)
195
+ normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
196
+ intermediates_only: Only return intermediate features, do not return final features
197
+ image_output_fmt: Shape of intermediate image feature outputs
198
+ image_output_extra_tokens: Return both prefix and spatial intermediate tokens
199
+ text_output_fmt: Shape of intermediate text feature outputs
200
+ text_output_extra_tokens: Return both prefix and spatial intermediate tokens
201
+ output_logits: Include logits in output
202
+ output_logit_scale_bias: Include the logit scale bias in the output
203
+ Returns:
204
+
205
+ """
206
+ output = {}
207
+ if intermediates_only:
208
+ # intermediates only disables final feature normalization, and include logits
209
+ normalize = False
210
+ output_logits = False
211
+ if output_logits:
212
+ assert False, 'FIXME, needs implementing'
213
+
214
+ if image is not None:
215
+ image_output = self.visual.forward_intermediates(
216
+ image,
217
+ indices=image_indices,
218
+ stop_early=stop_early,
219
+ normalize_intermediates=normalize_intermediates,
220
+ intermediates_only=intermediates_only,
221
+ output_fmt=image_output_fmt,
222
+ output_extra_tokens=image_output_extra_tokens,
223
+ )
224
+ if normalize and "image_features" in image_output:
225
+ image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
226
+ output.update(image_output)
227
+
228
+ if text is not None:
229
+ text_output = self.text.forward_intermediates(
230
+ text,
231
+ indices=text_indices,
232
+ stop_early=stop_early,
233
+ normalize_intermediates=normalize_intermediates,
234
+ intermediates_only=intermediates_only,
235
+ output_fmt=text_output_fmt,
236
+ output_extra_tokens=text_output_extra_tokens,
237
+ )
238
+ if normalize and "text_features" in text_output:
239
+ text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
240
+ output.update(text_output)
241
+
242
+ # FIXME text decoder
243
+ logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
244
+ if output_logit_scale_bias:
245
+ output["logit_scale"] = logit_scale_exp
246
+ if self.logit_bias is not None:
247
+ output['logit_bias'] = self.logit_bias
248
+
249
+ return output
250
+
251
+ def forward(
252
+ self,
253
+ image,
254
+ text: Optional[torch.Tensor] = None,
255
+ image_latent: Optional[torch.Tensor] = None,
256
+ image_embs: Optional[torch.Tensor] = None,
257
+ output_labels: bool = True,
258
+ ):
259
+ if image_latent is None or image_embs is None:
260
+ image_latent, image_embs = self._encode_image(image)
261
+
262
+ if text is None:
263
+ return {"image_features": image_latent, "image_embs": image_embs}
264
+
265
+ text_latent, token_embs = self._encode_text(text)
266
+
267
+ # FIXME this isn't an ideal solution, would like to improve -RW
268
+ labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
269
+ if output_labels:
270
+ # align text_embs and thus logits with labels for teacher-forcing caption loss
271
+ token_embs = token_embs[:, :-1]
272
+
273
+ logits = self.text_decoder(image_embs, token_embs)
274
+ out_dict = {
275
+ "image_features": image_latent,
276
+ "text_features": text_latent,
277
+ "logits": logits,
278
+ "logit_scale": self.logit_scale.exp()
279
+ }
280
+ if labels is not None:
281
+ out_dict["labels"] = labels
282
+ if self.logit_bias is not None:
283
+ out_dict["logit_bias"] = self.logit_bias
284
+ return out_dict
285
+
286
    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,  # keep tokens in the 1 - top_p quantile
        top_k=1,  # keeps the top_k most probable tokens
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False  # if True output.shape == (batch_size, seq_len)
    ):
        """Autoregressively generate caption token ids for `image`.

        generation_type selects 'beam_search' (delegates to _generate_beamsearch)
        or per-step sampling filtered by a 'top_p' / 'top_k' logits warper.
        Returns a tensor of token ids, padded to exactly seq_len per row when
        fixed_output_length is True.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = image.device

        with torch.no_grad():
            # 49406/49407 defaults — presumably the CLIP BPE <start>/<end> ids;
            # TODO confirm they match any custom tokenizer in use.
            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    # Right-pad beam-search output up to the requested length.
                    pad_len = seq_len - output.shape[1]
                    return torch.cat((
                            output,
                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                        ),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            # Encode the image once; reused for every decoding step below.
            image_latent, image_embs = self._encode_image(image)

            if text is None:
                # Seed each sequence with the start-of-text token.
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            while True:
                # Condition on at most the last max_seq_len tokens.
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    image,
                    x,
                    image_latent=image_latent,
                    image_embs=image_embs,
                    output_labels=False,
                )["logits"][:, -1]
                # Rows whose last token is EOS or PAD are finished; they only
                # receive further PAD tokens.
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if (cur_len + 1 == seq_len):
                        # Force EOS on the final step so sequences terminate.
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            # Restore the caller's train/eval mode.
            self.train(was_training)
            return out
412
+
413
    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        """Group-wise (diverse) beam search over image-conditioned captions.

        Mirrors HuggingFace's group beam search: the beams are split into
        num_beam_groups groups; each group is advanced in turn, with the logits
        processor seeing the tokens already chosen by earlier groups this step.
        Returns the finalized `sequences` tensor from BeamSearchScorer.
        """
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        # Replicate each image num_beams times so every beam decodes against it.
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        # Every beam starts from a single start-of-text token.
        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9.
        # This ensures that the beams in the same group don't all produce the same
        # tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                image_latent=image_latent,
                image_embs=image_embs,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                # Add the running beam scores so topk ranks whole hypotheses.
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # Keep 2*group_size candidates so finished (EOS) hypotheses don't
                # starve the live beams.
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                # Decompose flat candidate ids into (beam index, token id).
                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        # NOTE(review): finalize receives next_tokens/next_indices left over from the
        # last group of the last step — this matches the HF reference implementation.
        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
561
+
562
+
563
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Assemble per-step decoder inputs (HuggingFace GenerationMixin style).

    Returns a dict with keys 'text', 'images', 'past_key_values',
    'position_ids' and 'attention_mask'.
    """
    if past:
        # With cached past key/values only the newest token must be fed.
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)

    # NOTE: a caller-supplied position_ids is intentionally discarded; positions
    # are only derived on the fly from the attention mask when one is present.
    position_ids = None
    if attention_mask is not None and kwargs.get("position_ids", None) is None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
models/FarSLIP/open_clip/constants.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Per-channel RGB mean/std normalization constants for common pretrained image
# towers (presumably applied to [0, 1]-scaled pixels — confirm against transforms).
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)

# Default names for weights/config files hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin"  # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"  # safetensors version
HF_CONFIG_NAME = 'open_clip_config.json'
models/FarSLIP/open_clip/convert.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
2
+ """
3
+ from typing import Union
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from .model import CLIP, CustomTextCLIP
9
+ from .transformer import TextTransformer, Transformer
10
+
11
+
12
@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models

    Currently, the SigLIP source models are supported and a CustomTextCLIP destination model
    w/ timm image encoder.
    """
    from timm.layers import resample_patch_embed, resample_abs_pos_embed

    def _n2p(w, t=True, idx=None):
        # numpy -> torch: optionally select one layer (idx) from a stacked
        # per-layer array, squeeze (1,1,1,N) kernels, and transpose from the
        # big_vision (JAX) layout to the torch layout when t=True.
        if idx is not None:
            w = w[idx]
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    # Lazily-read npz archive of all checkpoint arrays, keyed by big_vision path.
    w = np.load(checkpoint_path)
    interpolation = 'bilinear'
    antialias = False

    def _convert_timm_img(module, prefix):
        # Copy the image tower into a timm VisionTransformer `module`.
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            # Patch size mismatch: resample the patch embedding kernel to fit.
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            # NOTE(review): the assert makes the resample path below unreachable —
            # position embedding size mismatches currently hard-fail.
            assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)

        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
                # Newer checkpoints stack all blocks into single arrays; index by layer.
                block_prefix = f'{prefix}Transformer/encoderblock/'
                idx = i
            else:
                block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
                idx = None
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
            # Fuse separate q/k/v projections into timm's single qkv projection.
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
            for r in range(2):
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        if module.attn_pool is not None:
            # Attention-pooling (MAP) head.
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        # Copy text-tower transformer blocks into an open_clip Transformer.
        for i, block in enumerate(module.resblocks.children()):
            if f'{prefix}encoderblock/LayerNorm_0/scale' in w:
                block_prefix = f'{prefix}encoderblock/'
                idx = i
            else:
                block_prefix = f'{prefix}encoderblock_{i}/'
                idx = None
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'], idx=idx))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'], idx=idx))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'], idx=idx))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'], idx=idx))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'], idx=idx))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'], idx=idx))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        # Copy the full text tower: embeddings, blocks, final norm, projection.
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        if module.text_projection is not None:
            module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
            module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

    # Some checkpoints nest everything under a 'params/' root.
    root_prefix = 'params/' if 'params/b' in w else ''
    _convert_timm_img(model.visual.trunk, f'{root_prefix}img/')
    _convert_openclip_txt(model.text, f'{root_prefix}txt/')
    model.logit_bias.copy_(_n2p(w[f'{root_prefix}b'])[0])
    model.logit_scale.copy_(_n2p(w[f'{root_prefix}t'])[0])
+
156
+
157
@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
    """Convert an Apple MobileCLIP checkpoint to open_clip naming.

    The image tower is remapped through timm's checkpoint filter (FastViT by
    default, hybrid ViT when fastvit=False); the text tower keys are renamed
    to open_clip's TextTransformer layout. Returns a new state dict.
    """

    def _convert_timm_img(state_dict):
        # Route image-tower keys through the matching timm filter, then prefix
        # them for open_clip's `visual.trunk` wrapper.
        if fastvit:
            from timm.models.fastvit import checkpoint_filter_fn
        else:
            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
        timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
        timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
        return timm_state_dict

    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
        # Rename MobileCLIP text-encoder keys to open_clip TextTransformer keys.
        text_dict = {}
        for k, v in state_dict.items():
            if not k.startswith(prefix):
                continue
            k = k.replace(prefix, '')
            k = k.replace('projection_layer', 'text_projection')
            k = k.replace('embedding_layer', 'token_embedding')
            if k.startswith('positional_embedding.pos_embed.pos_embed'):
                k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
                # Drop the singleton batch dim stored by the source model.
                v = v.squeeze()
            k = k.replace('final_layer_norm', 'ln_final')
            k = k.replace('pre_norm_mha.0', 'ln_1')
            k = k.replace('pre_norm_mha.1', 'attn')
            k = k.replace('pre_norm_ffn.0', 'ln_2')
            k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
            k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
            k = k.replace('qkv_proj.weight', 'in_proj_weight')
            k = k.replace('qkv_proj.bias', 'in_proj_bias')
            k = k.replace('transformer.', 'transformer.resblocks.')
            text_dict['text.' + k] = v
        return text_dict

    image_dict = _convert_timm_img(state_dict)
    text_dict = _convert_openclip_txt(state_dict)
    out_dict = {**image_dict, **text_dict}
    out_dict['logit_scale'] = state_dict['logit_scale']
    return out_dict
+
198
+
199
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
    """Dispatch third-party checkpoint conversion based on sentinel keys.

    State dicts that don't match any known foreign format pass through unchanged.
    """
    # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
    fastvit_key = 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight'
    # MobileCLIP 'b' variant (non-FastViT image tower)
    hybrid_key = 'image_encoder.model.patch_emb.0.block.conv.weight'
    if fastvit_key in state_dict:
        state_dict = convert_mobile_clip_state_dict(model, state_dict)
    if hybrid_key in state_dict:
        state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
    return state_dict
models/FarSLIP/open_clip/factory.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ import warnings
6
+ from copy import deepcopy
7
+ from dataclasses import asdict
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional, Tuple, Union
10
+
11
+ import torch
12
+
13
+ from .convert import convert_state_dict
14
+ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
15
+ resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
16
+ from .coca_model import CoCa
17
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss, MultiPosConLossMM
18
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
19
+ list_pretrained_tags_by_model, download_pretrained_from_hf
20
+ from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
21
+ from .tokenizer import HFTokenizer, SimpleTokenizer, SigLipTokenizer, DEFAULT_CONTEXT_LENGTH
22
+
23
HF_HUB_PREFIX = 'hf-hub:'  # model_name prefix that selects HuggingFace Hub loading
# Search paths for model architecture config files. Plain string (the original
# used a pointless f-string) and no trailing slash (pathlib drops it anyway).
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs"]
_MODEL_CONFIGS = {}  # dictionary of (model_name: config) model architecture configs
26
+
27
+
28
+ def _natural_key(string_):
29
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
30
+
31
+
32
def _rescan_model_configs():
    """Rebuild the global model-config registry from _MODEL_CONFIG_PATHS.

    Each registered path may be a single .json file or a directory of them.
    Configs defining the required top-level keys are kept, and the registry is
    re-sorted in natural (human) order by model name.
    """
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    candidates = []
    for path in _MODEL_CONFIG_PATHS:
        if path.is_file() and path.suffix in config_ext:
            candidates.append(path)
        elif path.is_dir():
            for ext in config_ext:
                candidates.extend(path.glob(f'*{ext}'))

    required_keys = ('embed_dim', 'vision_cfg', 'text_cfg')
    for candidate in candidates:
        with open(candidate, 'r') as fh:
            cfg = json.load(fh)
        if all(key in cfg for key in required_keys):
            _MODEL_CONFIGS[candidate.stem] = cfg

    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0])))


_rescan_model_configs()  # initial populate of model config registry
54
+
55
+
56
def list_models():
    """Return the names of all registered model architectures."""
    return list(_MODEL_CONFIGS)
59
+
60
+
61
def add_model_config(path):
    """Register a model config file or directory and rescan the registry."""
    entry = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(entry)
    _rescan_model_configs()
67
+
68
+
69
def get_model_config(model_name):
    """Return a deep copy of a builtin model config, or None when unknown."""
    if model_name not in _MODEL_CONFIGS:
        return None
    return deepcopy(_MODEL_CONFIGS[model_name])
76
+
77
+
78
def _get_hf_config(
        model_id: str,
        cache_dir: Optional[str] = None,
):
    """Fetch and parse a model's open_clip_config.json from the HuggingFace Hub."""
    config_path = download_pretrained_from_hf(
        model_id,
        filename='open_clip_config.json',
        cache_dir=cache_dir,
    )
    with open(config_path, 'r', encoding='utf-8') as cfg_file:
        return json.load(cfg_file)
92
+
93
+
94
def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        cache_dir: Optional[str] = None,
        **kwargs,
):
    """Build the tokenizer matching a model's config.

    For 'hf-hub:' names the hub config is consulted (falling back to a bare
    HFTokenizer if it can't be fetched); otherwise the builtin config registry
    is used. The resolved text_cfg then selects an HFTokenizer, a
    SigLipTokenizer (models with 'siglip' in the name) or SimpleTokenizer.
    """
    if model_name.startswith(HF_HUB_PREFIX):
        model_name = model_name[len(HF_HUB_PREFIX):]
        try:
            config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg']
        except Exception:
            # No usable open_clip config on the hub; assume the repo itself is
            # a HF tokenizer and use library defaults.
            tokenizer = HFTokenizer(
                model_name,
                context_length=context_length or DEFAULT_CONTEXT_LENGTH,
                cache_dir=cache_dir,
                **kwargs,
            )
            return tokenizer
    else:
        config = get_model_config(model_name)
        assert config is not None, f"No valid model config found for {model_name}."

    text_config = config.get('text_cfg', {})
    # Explicit kwargs override tokenizer_kwargs baked into the config.
    if 'tokenizer_kwargs' in text_config:
        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
    else:
        tokenizer_kwargs = kwargs

    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

    # Lowercased only for the 'siglip' substring checks below.
    model_name = model_name.lower()
    if text_config.get('hf_tokenizer_name', ''):
        tokenizer = HFTokenizer(
            text_config['hf_tokenizer_name'],
            context_length=context_length,
            cache_dir=cache_dir,
            **tokenizer_kwargs,
        )
    elif 'siglip' in model_name:
        # Pick the sentencepiece vocab variant by model family.
        tn = 'gemma' if 'siglip2' in model_name else 'mc4' if 'i18n' in model_name else 'c4-en'
        tokenizer = SigLipTokenizer(
            tn,
            context_length=context_length,
            # **tokenizer_kwargs,
        )
    else:
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )

    return tokenizer
147
+
148
+
149
def load_state_dict(
    checkpoint_path: str,
    device='cpu',
    weights_only=True,
):
    """Load a checkpoint file and return its (possibly unwrapped) state dict.

    Supports ``.safetensors`` files, regular torch pickles (optionally with a
    nested ``'state_dict'`` entry), and TorchScript archives. A DDP-style
    ``'module.'`` key prefix is stripped when present.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Map location passed to the loader (default ``'cpu'``).
        weights_only: Forwarded to ``torch.load`` for safe unpickling; falls
            back to a plain load if the installed torch rejects it.

    Returns:
        A state dict mapping parameter names to tensors.
    """
    # Check if safetensors or not and load weights accordingly
    if str(checkpoint_path).endswith(".safetensors"):
        from safetensors.torch import load_file
        checkpoint = load_file(checkpoint_path, device=device)
    else:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
        except Exception:
            # older torch without weights_only support, or non-weights pickle content
            checkpoint = torch.load(checkpoint_path, map_location=device)

    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, torch.jit.ScriptModule):
        state_dict = checkpoint.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = checkpoint
    # FIX: guard the empty case (next(iter(...)) raised StopIteration on an empty
    # state dict) and strip only the exact DistributedDataParallel 'module.' prefix
    # ('module' alone mangled unrelated keys such as 'modulex.weight').
    if isinstance(state_dict, dict) and state_dict and next(iter(state_dict)).startswith('module.'):
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    return state_dict
175
+
176
+
177
def load_checkpoint(
    model: Union[CLIP, CustomTextCLIP],
    checkpoint_path: str,
    strict: bool = True,
    weights_only: bool = True,
    device='cpu',
):
    """Load pretrained weights from ``checkpoint_path`` into ``model``.

    Converts third-party / legacy state dicts to the current layout, fixes up
    ``logit_scale`` / ``logit_bias`` shape mismatches, resizes positional
    embeddings, and finally calls ``model.load_state_dict``.

    Args:
        model: Target model instance the weights are loaded into.
        checkpoint_path: Path to the checkpoint (torch pickle, safetensors,
            or numpy ``.npz``/``.npy`` big_vision weights).
        strict: Forwarded to ``model.load_state_dict``.
        weights_only: Forwarded to ``load_state_dict`` for safe unpickling.
        device: Map location used while reading the checkpoint.

    Returns:
        The incompatible-keys result of ``model.load_state_dict`` (an empty
        dict for the numpy big_vision path, which loads in place).
    """
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        # Separate path loading numpy big_vision (SigLIP) weights
        from open_clip.convert import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)

    # Detect & convert 3rd party state_dicts -> open_clip
    state_dict = convert_state_dict(model, state_dict)

    # Detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # correct if logit_scale differs in being scaler vs 1d param
    # NOTE(review): assumes model always defines logit_scale / logit_bias attrs —
    # verify for model classes that omit them (would raise AttributeError here).
    if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
        state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)

    # correct if logit_bias differs in being scaler vs 1d param
    if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
        state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)

    # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
    if 'logit_bias' not in state_dict and model.logit_bias is not None:
        state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

    # Certain text transformers no longer expect position_ids after transformers==4.31
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    # Interpolate visual/text positional embeddings to the model's current sizes.
    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    # Finally, load the massaged state_dict into model
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    if incompatible_keys.missing_keys:
        print("Missing keys:", incompatible_keys.missing_keys)
    if incompatible_keys.unexpected_keys:
        print("Unexpected keys:", incompatible_keys.unexpected_keys)

    logging.info(f"Missing keys: {incompatible_keys.missing_keys}")
    return incompatible_keys
228
+
229
+
230
def create_model(
    model_name: str,
    pretrained: Optional[str] = None,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_patch_dropout: Optional[float] = None,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    force_preprocess_cfg: Optional[Dict[str, Any]] = None,
    pretrained_image: bool = False,
    pretrained_hf: bool = True,
    cache_dir: Optional[str] = None,
    output_dict: Optional[bool] = None,
    require_pretrained: bool = False,
    load_weights_only: bool = True,
    long_clip: Optional[str] = 'disable',
    **model_kwargs,
):
    """Creates and configures a contrastive vision-language model.

    Args:
        model_name: Name of the model architecture to create. Can be a local model name
            or a Hugging Face model ID prefixed with 'hf-hub:'.
        pretrained: Tag/path for pretrained model weights. Can be:
            - A pretrained tag name (e.g., 'openai')
            - A path to local weights
            - None to initialize with random weights
        precision: Model precision/AMP configuration. Options:
            - 'fp32': 32-bit floating point
            - 'fp16'/'bf16': Mixed precision with FP32 for certain layers
            - 'pure_fp16'/'pure_bf16': Pure 16-bit precision
        device: Device to load the model on ('cpu', 'cuda', or torch.device object)
        jit: If True, JIT compile the model
        force_quick_gelu: Force use of QuickGELU activation
        force_custom_text: Force use of custom text encoder
        force_patch_dropout: Override default patch dropout value
        force_image_size: Override default image size for vision encoder
        force_preprocess_cfg: Override default preprocessing configuration
        pretrained_image: Load pretrained weights for timm vision models
        pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
        cache_dir: Override default cache directory for downloaded model files
        output_dict: If True and model supports it, return dictionary of features
        require_pretrained: Raise error if pretrained weights cannot be loaded
        load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety)
        long_clip: Forwarded into the model config under the 'long_clip' key
            (default 'disable').
        **model_kwargs: Additional keyword arguments passed to model constructor

    Returns:
        Created and configured model instance

    Raises:
        RuntimeError: If model config is not found or required pretrained weights
        cannot be loaded

    Examples:
        # Create basic CLIP model
        model = create_model('ViT-B/32')

        # Create CLIP model with mixed precision on GPU
        model = create_model('ViT-B/32', precision='fp16', device='cuda')

        # Load pretrained OpenAI weights
        model = create_model('ViT-B/32', pretrained='openai')

        # Load Hugging Face model
        model = create_model('hf-hub:organization/model-name')
    """

    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        # Hub path: both the checkpoint and the model/preprocess config come from HF.
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir=cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        pretrained_hf = False  # override, no need to load original HF text weights
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    model_cfg = model_cfg or get_model_config(model_name)
    if model_cfg is not None:
        logging.info(f'Loaded {model_name} model config.')
    else:
        logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
        raise RuntimeError(f'Model config for {model_name} not found.')

    if force_quick_gelu:
        # override for use of QuickGELU on non-OpenAI transformer models
        model_cfg["quick_gelu"] = True

    if force_patch_dropout is not None:
        # override the default patch dropout value
        model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

    if force_image_size is not None:
        # override model config's image size
        model_cfg["vision_cfg"]["image_size"] = force_image_size

    is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
    if pretrained_image:
        if is_timm_model:
            # pretrained weight loading for timm models set via vision_cfg
            model_cfg['vision_cfg']['timm_model_pretrained'] = True
        else:
            assert False, 'pretrained image towers currently only supported for timm models'

    # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
    cast_dtype = get_cast_dtype(precision)
    is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
    if is_hf_model:
        # load pretrained weights for HF text model IFF no CLIP weights being loaded
        model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
    custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

    model_cfg.update({"long_clip": long_clip})
    model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
    if custom_text:
        # CoCa when a multimodal decoder cfg is present, otherwise custom-text CLIP.
        if "multimodal_cfg" in model_cfg:
            model = CoCa(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
    else:
        model = CLIP(**model_cfg, cast_dtype=cast_dtype)

    if precision in ("fp16", "bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        # manual mixed precision that matches original OpenAI behaviour
        if is_timm_model:
            # FIXME this is a bit janky, create timm based model in low-precision and
            # then cast only LayerNormFp32 instances back to float32 so they don't break.
            # Why? The convert_weights_to_lp fn only works with native models.
            model.to(device=device, dtype=dtype)
            from .transformer import LayerNormFp32

            def _convert_ln(m):
                if isinstance(m, LayerNormFp32):
                    m.weight.data = m.weight.data.to(torch.float32)
                    m.bias.data = m.bias.data.to(torch.float32)
            model.apply(_convert_ln)
        else:
            model.to(device=device)
            convert_weights_to_lp(model, dtype=dtype)
    elif precision in ("pure_fp16", "pure_bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        model.to(device=device, dtype=dtype)
    else:
        model.to(device=device)

    pretrained_loaded = False
    if pretrained:
        # Resolve 'pretrained' first as a known tag, then as a local file path.
        checkpoint_path = ''
        pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
        if pretrained_cfg:
            checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
            model_quick_gelu = model_cfg.get('quick_gelu', False)
            if pretrained_quick_gelu and not model_quick_gelu:
                warnings.warn(
                    f'These pretrained weights were trained with QuickGELU activation but the model config does '
                    f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
            elif not pretrained_quick_gelu and model_quick_gelu:
                warnings.warn(
                    f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
                    f'model config, consider using a model config without QuickGELU or disable override flags.')
        elif os.path.exists(pretrained):
            checkpoint_path = pretrained

        if checkpoint_path:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path, weights_only=load_weights_only, strict=False)
        else:
            error_str = (
                f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
            logging.warning(error_str)
            raise RuntimeError(error_str)
        pretrained_loaded = True
    elif has_hf_hub_prefix:
        logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
        load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
        pretrained_loaded = True

    if require_pretrained and not pretrained_loaded:
        # callers of create_model_from_pretrained always expect pretrained weights
        raise RuntimeError(
            f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    # set image preprocessing configuration in model attributes for convenience
    if getattr(model.visual, 'image_size', None) is not None:
        # use image_size set on model creation (via config or force_image_size arg)
        force_preprocess_cfg['size'] = model.visual.image_size
    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

    return model
439
+
440
+
441
def create_loss(args):
    """Pick the training loss implementation implied by the run arguments.

    Precedence: distillation > CoCa (detected via model name) > SigLIP > CLIP.
    """
    def _contrastive_kwargs():
        # Shared by Clip/DistillClip/CoCa losses; built lazily so each branch
        # only touches the args attributes it actually needs.
        return dict(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )

    if args.distill:
        return DistillClipLoss(**_contrastive_kwargs())

    if "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            **_contrastive_kwargs(),
        )

    if args.siglip:
        assert not args.horovod, "Horovod not currently supported for SigLip"
        return SigLipLoss(
            rank=args.rank,
            world_size=args.world_size,
            dist_impl=args.loss_dist_impl,  # siglip has multiple distributed implementations to choose from
        )

    return ClipLoss(**_contrastive_kwargs())
484
+
485
+
486
def create_model_and_transforms(
    model_name: str,
    pretrained: Optional[str] = None,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_patch_dropout: Optional[float] = None,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    image_interpolation: Optional[str] = None,
    image_resize_mode: Optional[str] = None,  # only effective for inference
    aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
    pretrained_image: bool = False,
    pretrained_hf: bool = True,
    cache_dir: Optional[str] = None,
    output_dict: Optional[bool] = None,
    load_weights_only: bool = True,
    long_clip: Optional[str] = 'disable',

    use_imagecrop_aug: Optional[bool] = False,
    max_boxes: Optional[int] = 10,
    local_method: str = 'grids',
    **model_kwargs,
):
    """Create a model together with its train and val image transforms."""
    # Collect caller-supplied preprocessing overrides (mean/std/interp/resize).
    preprocess_overrides = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=preprocess_overrides,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        load_weights_only=load_weights_only,
        long_clip=long_clip,
        **model_kwargs,
    )

    # The model carries the final merged preprocess config after creation.
    pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)

    # Train transform may apply box/crop augmentation; val transform is deterministic.
    preprocess_train = image_transform_v2(
        pp_cfg,
        is_train=True,
        use_imagecrop_aug=use_imagecrop_aug,
        max_boxes=max_boxes,
        local_method=local_method,
        aug_cfg=aug_cfg,
    )
    preprocess_val = image_transform_v2(pp_cfg, is_train=False)

    return model, preprocess_train, preprocess_val
558
+
559
+
560
def create_model_from_pretrained(
    model_name: str,
    pretrained: Optional[str] = None,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    image_interpolation: Optional[str] = None,
    image_resize_mode: Optional[str] = None,  # only effective for inference
    return_transform: bool = True,
    cache_dir: Optional[str] = None,
    load_weights_only: bool = True,
    **model_kwargs,
):
    """Create a model that must load pretrained weights, plus (optionally) its inference transform."""
    preprocess_overrides = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    # require_pretrained=True: creation fails unless weights are actually loaded.
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        force_preprocess_cfg=preprocess_overrides,
        cache_dir=cache_dir,
        require_pretrained=True,
        load_weights_only=load_weights_only,
        **model_kwargs,
    )

    if not return_transform:
        return model

    inference_transform = image_transform_v2(
        PreprocessCfg(**model.visual.preprocess_cfg),
        is_train=False,
    )

    return model, inference_transform