Upload all files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- MajorTOM/MajorTOMDataset.py +64 -0
- MajorTOM/__init__.py +5 -0
- MajorTOM/embedder/MajorTOM_Embedder.py +191 -0
- MajorTOM/embedder/__init__.py +2 -0
- MajorTOM/embedder/__pycache__/MajorTOM_Embedder.cpython-311.pyc +0 -0
- MajorTOM/embedder/__pycache__/__init__.cpython-311.pyc +0 -0
- MajorTOM/embedder/__pycache__/grid_cell_fragment.cpython-311.pyc +0 -0
- MajorTOM/embedder/grid_cell_fragment.py +164 -0
- MajorTOM/embedder/models/DINOv2_S2RGB.py +91 -0
- MajorTOM/embedder/models/SSL4EO_S1RTC.py +125 -0
- MajorTOM/embedder/models/SSL4EO_S2L1C.py +97 -0
- MajorTOM/embedder/models/SigLIP_S2RGB.py +65 -0
- MajorTOM/embedder/models/__init__.py +4 -0
- MajorTOM/embedder/models/__pycache__/DINOv2_S2RGB.cpython-311.pyc +0 -0
- MajorTOM/embedder/models/__pycache__/SSL4EO_S1RTC.cpython-311.pyc +0 -0
- MajorTOM/embedder/models/__pycache__/SSL4EO_S2L1C.cpython-311.pyc +0 -0
- MajorTOM/embedder/models/__pycache__/SigLIP_S2RGB.cpython-311.pyc +0 -0
- MajorTOM/embedder/models/__pycache__/__init__.cpython-311.pyc +0 -0
- MajorTOM/extras/coverage-example.png +3 -0
- MajorTOM/extras/coverage_vis.py +149 -0
- MajorTOM/extras/extract-sample-from-raw-S2.ipynb +0 -0
- MajorTOM/extras/thumbnail_dem.py +77 -0
- MajorTOM/extras/thumbnail_s1rtc.py +80 -0
- MajorTOM/extras/thumbnail_s2.py +68 -0
- MajorTOM/grid.py +284 -0
- MajorTOM/metadata_helpers.py +159 -0
- MajorTOM/sample_helpers.py +20 -0
- app.py +799 -0
- compute_embeddings.py +606 -0
- configs/huggingface.yaml +15 -0
- countries.geo.json +0 -0
- data_utils.py +223 -0
- examples/example1.png +3 -0
- examples/example2.png +3 -0
- examples/example3.png +3 -0
- logs/compute_embeddings_dinov2.log +170 -0
- logs/compute_embeddings_farslip.log +150 -0
- logs/compute_embeddings_satclip.log +182 -0
- logs/compute_embeddings_siglip.log +200 -0
- models/FarSLIP/.gitignore +160 -0
- models/FarSLIP/LICENSE +21 -0
- models/FarSLIP/README.md +237 -0
- models/FarSLIP/__init__.py +1 -0
- models/FarSLIP/open_clip/__init__.py +18 -0
- models/FarSLIP/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- models/FarSLIP/open_clip/coca_model.py +582 -0
- models/FarSLIP/open_clip/constants.py +11 -0
- models/FarSLIP/open_clip/convert.py +206 -0
- models/FarSLIP/open_clip/factory.py +610 -0
.gitattributes
CHANGED
|
@@ -39,3 +39,9 @@ EarthEmbeddingExplorer/examples/example3.png filter=lfs diff=lfs merge=lfs -text
|
|
| 39 |
EarthEmbeddingExplorer/MajorTOM/extras/coverage-example.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
EarthEmbeddingExplorer/models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
| 41 |
EarthEmbeddingExplorer/models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
EarthEmbeddingExplorer/MajorTOM/extras/coverage-example.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
EarthEmbeddingExplorer/models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
| 41 |
EarthEmbeddingExplorer/models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
examples/example1.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
examples/example2.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
examples/example3.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
MajorTOM/extras/coverage-example.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
models/SatCLIP/satclip/positional_encoding/__pycache__/spherical_harmonics_ylm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
MajorTOM/MajorTOMDataset.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import rasterio as rio
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torchvision.transforms as transforms
|
| 9 |
+
|
| 10 |
+
class MajorTOM(Dataset):
    """MajorTOM Dataset (https://huggingface.co/Major-TOM)

    Reads per-band .tif (via rasterio) and .png (via PIL) files for each
    metadata row from a local copy of the dataset.

    Args:
        df ((geo)pandas.DataFrame): Metadata dataframe
        local_dir (string or Path): Root directory of the local dataset version
        tif_bands (list or str): tif file names (without extension) to be read
        png_bands (list or str): png file names (without extension) to be read
        tif_transforms (list): Transforms applied to each tif band
            (None leaves the raw rasterio array untouched)
        png_transforms (list): Transforms applied to each png band
            (None leaves the PIL image untouched)
    """

    def __init__(self,
                 df,
                 local_dir = None,
                 tif_bands=['B04','B03','B02'],
                 png_bands=['thumbnail'],
                 tif_transforms=[transforms.ToTensor()],
                 png_transforms=[transforms.ToTensor()]
                ):
        super().__init__()
        self.df = df
        # Accept either a str or a Path for the dataset root
        self.local_dir = Path(local_dir) if isinstance(local_dir, str) else local_dir
        # Allow a single band name to be passed as a plain string
        self.tif_bands = tif_bands if not isinstance(tif_bands, str) else [tif_bands]
        self.png_bands = png_bands if not isinstance(png_bands, str) else [png_bands]
        self.tif_transforms = transforms.Compose(tif_transforms) if tif_transforms is not None else None
        self.png_transforms = transforms.Compose(png_transforms) if png_transforms is not None else None

    def __len__(self):
        """Number of samples (rows of the metadata dataframe)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Read all requested bands for sample ``idx``.

        Returns:
            dict: {'meta': metadata row, <band name>: band data per requested band}
        """
        meta = self.df.iloc[idx]

        product_id = meta.product_id
        grid_cell = meta.grid_cell
        # MajorTOM directory layout: <grid row>/<grid cell>/<product id>/
        row = grid_cell.split('_')[0]

        path = self.local_dir / Path("{}/{}/{}".format(row, grid_cell, product_id))
        out_dict = {'meta' : meta}

        for band in self.tif_bands:
            with rio.open(path / '{}.tif'.format(band)) as f:
                out = f.read()
            if self.tif_transforms is not None:
                out = self.tif_transforms(out)
            out_dict[band] = out

        for band in self.png_bands:
            # Use a context manager so the underlying file handle is closed
            # (Image.open is lazy and otherwise keeps the file open); load()
            # forces the pixel data into memory before the file is closed.
            with Image.open(path / '{}.png'.format(band)) as im:
                im.load()
                out = im
            if self.png_transforms is not None:
                out = self.png_transforms(out)
            out_dict[band] = out

        return out_dict
|
MajorTOM/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .sample_helpers import *
|
| 2 |
+
from .metadata_helpers import *
|
| 3 |
+
from .MajorTOMDataset import *
|
| 4 |
+
from .grid import *
|
| 5 |
+
from .embedder import *
|
MajorTOM/embedder/MajorTOM_Embedder.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import geopandas as gpd
|
| 3 |
+
import hashlib
|
| 4 |
+
from rasterio.io import MemoryFile
|
| 5 |
+
|
| 6 |
+
from .grid_cell_fragment import *
|
| 7 |
+
from .models import *
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
class MajorTOM_Embedder(torch.nn.Module):
    """
    MajorTOM Embedder class that applies a model to geospatial image fragments,
    computes embeddings, and returns metadata for each fragment.

    This class is designed to work with raster data, where the image is fragmented
    into smaller tiles, and embeddings are computed for each tile using the provided
    embedder model. The output is a GeoDataFrame containing spatial metadata and
    the corresponding embeddings for each tile.

    Attributes:
        embedder: A model that generates embeddings for image fragments.
        frag_params: Dictionary containing fragmentation parameters such as the
            target overlap and border shift.
        column_types: Dictionary specifying data types for the output GeoDataFrame columns.
    """

    def __init__(self, embedder, target_overlap=0.1, border_shift=True):
        """
        Initializes the MajorTOM Embedder with the given parameters.

        Args:
            embedder (torch.nn.Module): A model that generates embeddings for image fragments.
            target_overlap (float): The target overlap between image fragments. Default is 0.1.
            border_shift (bool): Whether to shift the borders of fragments to avoid edge artifacts. Default is True.
        """
        super().__init__()

        # Model
        self.embedder = embedder

        # Fragmentation Settings (fragment size follows the embedder's input size)
        self.frag_params = {
            'fragment_size' : self.embedder.size[0],
            'target_overlap' : target_overlap,
            'border_shift' : border_shift
        }

        # Data types for the output dataframe (columns not listed need no conversion:
        # unique_id, embedding, timestamp, product_id, grid_cell, utm_footprint,
        # utm_crs, pixel_bbox)
        self.column_types = {
            'grid_row_u' : 'int16',
            'grid_col_r' : 'int16',
            'centre_lat' : 'float32',
            'centre_lon' : 'float32',
        }

    def bands(self):
        """
        Returns the set of input bands in the correct order.

        Returns:
            list: List of input bands used by the embedder.
        """
        return self.embedder.bands

    def size(self):
        """
        Returns the input image size.

        Returns:
            tuple: Tuple representing the image size (height, width).
        """
        return self.embedder.size

    def calculate_checksum(self, geometry, timestamp, product_id, embedding):
        """
        Calculates a checksum for the given geometry, timestamp, product ID, and embedding.

        Args:
            geometry (shapely.geometry): The geometry object representing the fragment's footprint.
            timestamp (str): Timestamp of the data.
            product_id (str): Product identifier.
            embedding (np.ndarray): The embedding of the image fragment.

        Returns:
            str: A SHA256 checksum of the concatenated input parameters.
        """
        combined = f"{geometry}_{timestamp}_{product_id}_{embedding}"
        return hashlib.sha256(combined.encode()).hexdigest()

    def _read_image(self, row):
        """
        Reads and processes the image bands for a given row, performs optional upsampling
        if the resolution is mismatched, and returns the image data, footprint, and CRS.

        Args:
            row (pandas.Series): The input row containing the image bands
                (each band value is expected to hold GTiff bytes; accessed via
                ``row[band][0].as_py()`` — pyarrow-style rows).

        Returns:
            torch.Tensor: A tensor containing the stacked image bands (H, W, C).
            shapely.geometry: The footprint of the image.
            rasterio.crs.CRS: The CRS of the image.
        """
        # Read each band from the in-memory GTiff payload
        img = []
        for band in self.embedder.bands:
            with MemoryFile(row[band][0].as_py()) as mem_f:
                with mem_f.open(driver='GTiff') as f:
                    crs = f.crs
                    footprint = box(*f.bounds)
                    img.append(f.read()[0])

        # Optional upsampling of lower-resolution bands to the maximum size
        shapes = [layer.shape for layer in img]
        if any(el != shapes[0] for el in shapes):  # if any resolution mismatch
            h = max(el[0] for el in shapes)
            w = max(el[1] for el in shapes)
            for layer_idx, layer in enumerate(img):
                if layer.shape != (h, w):
                    # NOTE: cv2.resize expects dsize as (width, height); passing
                    # (h, w) would transpose the target size for non-square layers.
                    img[layer_idx] = cv2.resize(layer, (w, h), interpolation=cv2.INTER_NEAREST)
        img = torch.from_numpy(np.stack(img, -1).astype(np.float32))

        return img, footprint, crs


    def forward(self, row, row_meta, device='cuda'):
        """
        Forward pass of the model: Reads the image, fragments it, computes embeddings
        for each fragment, and returns a GeoDataFrame with the spatial metadata and
        embeddings.

        Args:
            row (pandas.Series): The input row containing the image data.
            row_meta (pandas.Series): Metadata associated with the row (e.g., timestamp, product_id).
            device (str): The device to run the model on ('cpu' or 'cuda'). Default is 'cuda'.

        Returns:
            geopandas.GeoDataFrame: A GeoDataFrame containing metadata and embeddings for each fragment.
        """
        # Read file
        img, footprint, crs = self._read_image(row)

        # Fragment the sample
        fragments, xys = fragment_fn(img, **self.frag_params, return_indices=True, verbose=False)

        nrows, ncols, c, h, w = fragments.shape
        # Apply the model
        with torch.no_grad():
            embeddings = self.embedder(fragments.reshape(-1, c, h, w).to(device)).view(nrows, ncols, -1)

        # The source CRS is fixed for the whole sample, so build the
        # WGS84 transformer once instead of once per fragment.
        transformer = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True)

        df_rows = []

        # Pack rows for geoparquet
        for r_idx in range(nrows):
            for c_idx in range(ncols):
                embedding = embeddings[r_idx, c_idx].cpu().numpy()
                # spatial features per fragment
                # NOTE(review): xys axis order assumed to match pixel_bbox
                # consumers — confirm against fragment_fn's offset layout.
                x_offset, y_offset = xys[r_idx, c_idx].int().tolist()
                pixel_bbox = [x_offset, y_offset, x_offset + h, y_offset + w]  # in pixels
                utm_footprint = crop_footprint(footprint, *img.shape[:2], pixel_bbox)
                # main footprint is in WGS84 (needs to be consistent across parquet)
                geometry = transform(transformer.transform, utm_footprint)  # WGS84
                centre_lon, centre_lat = geometry.centroid.coords[0]

                row_dict = {
                    'unique_id' : self.calculate_checksum(geometry, row_meta.timestamp.item(), row_meta.product_id.item(), embedding),
                    'embedding' : embedding,
                    'timestamp' : row_meta.timestamp.item(),
                    'product_id' : row_meta.product_id.item(),
                    'grid_cell' : row_meta.grid_cell.item(),
                    'grid_row_u' : row_meta.grid_row_u.item(),
                    'grid_col_r' : row_meta.grid_col_r.item(),
                    'geometry' : geometry,
                    'centre_lat' : centre_lat,
                    'centre_lon' : centre_lon,
                    'utm_footprint' : utm_footprint.wkt,
                    'utm_crs' : crs.to_string(),
                    'pixel_bbox' : pixel_bbox,
                }
                df_rows.append(row_dict)

        return gpd.GeoDataFrame(df_rows).astype(self.column_types)
|
MajorTOM/embedder/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .MajorTOM_Embedder import *
|
| 2 |
+
from .grid_cell_fragment import *
|
MajorTOM/embedder/__pycache__/MajorTOM_Embedder.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
MajorTOM/embedder/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (244 Bytes). View file
|
|
|
MajorTOM/embedder/__pycache__/grid_cell_fragment.cpython-311.pyc
ADDED
|
Binary file (8.37 kB). View file
|
|
|
MajorTOM/embedder/grid_cell_fragment.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from shapely.ops import transform
|
| 5 |
+
from pyproj import CRS, Transformer
|
| 6 |
+
import geopandas as gpd
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
from shapely.geometry import Polygon, box
|
| 10 |
+
from rasterio.transform import from_bounds, xy
|
| 11 |
+
#from rasterio.windows import Window, from_bounds
|
| 12 |
+
import rasterio as rio
|
| 13 |
+
|
| 14 |
+
def crop_footprint(footprint, height, width, crop_bbox):
    """
    Crops the given footprint to the specified bounding box.

    Args:
        footprint (shapely.geometry.Polygon): The original footprint of the image or area.
        height (int): Height of the image (in pixels).
        width (int): Width of the image (in pixels).
        crop_bbox (list): The bounding box to crop the footprint. The format is
            [col_start, row_start, col_end, row_end], where:
            - col_start, row_start: top-left corner
            - col_end, row_end: bottom-right corner

    Returns:
        shapely.geometry.Polygon: The cropped bounding box in the same coordinate reference system (CRS) as the original footprint.
    """

    # Affine transform mapping pixel coords to the footprint's CRS.
    # Named `affine` (not `transform`) to avoid shadowing shapely.ops.transform,
    # which this module imports and uses elsewhere.
    affine = from_bounds(*footprint.bounds, width, height)

    # Convert pixel coordinates (col, row) to spatial coordinates (e.g., UTM)
    # using the raster's affine transform
    min_x, min_y = affine * (crop_bbox[0], crop_bbox[1])  # (col_start, row_start)
    max_x, max_y = affine * (crop_bbox[2], crop_bbox[3])  # (col_end, row_end)

    # Create a Shapely polygon for the crop's bounding box in UTM
    return box(min_x, min_y, max_x, max_y)
|
| 40 |
+
|
| 41 |
+
def fragment_unfold(image, fragment_size, overlap):
    """
    Extract overlapping patches (fragments) from an image.

    Wraps ``torch.nn.functional.unfold`` so that patches of ``fragment_size``
    are taken with a stride of ``fragment_size - overlap`` along each axis.

    Args:
        image (torch.Tensor or np.ndarray): Input image. A numpy array is
            expected channels-last (H, W, C); a tensor channels-first.
        fragment_size (int or list): Patch size — one int for square patches
            or a pair [height, width].
        overlap (int or list): Overlap between neighbouring patches — one int
            or a pair [vertical, horizontal].

    Returns:
        torch.Tensor: Patches of shape (num_patches, C, *fragment_size).
    """
    # Numpy input arrives channels-last; convert and move channels first.
    if not torch.is_tensor(image):
        image = torch.from_numpy(image).permute(2, 0, 1)
    # Ensure a leading batch dimension for unfold.
    if image.dim() < 4:
        image = image.unsqueeze(0)

    b, c, h, w = image.shape

    # Normalise scalar arguments to per-axis pairs.
    fragment_size = [fragment_size, fragment_size] if isinstance(fragment_size, int) else fragment_size
    overlap = [overlap, overlap] if isinstance(overlap, int) else overlap

    # Stride = patch size minus requested overlap, per axis.
    stride = [fs - ov for fs, ov in zip(fragment_size, overlap)]

    # C-speed sliding-window extraction.
    patches = torch.nn.functional.unfold(image, fragment_size, dilation=1, padding=0, stride=stride)

    # (b, c*kh*kw, L) -> (b, L, c, kh, kw) -> drop batch dim.
    return patches.view(b, c, *fragment_size, -1).permute(0, 4, 1, 2, 3)[0]
|
| 78 |
+
|
| 79 |
+
def fragment_fn(img,
                fragment_size,
                target_overlap,
                border_shift=True, # determines whether the outer border is shifted to ensure full coverage
                return_indices=False,
                verbose=False
                ):
    """
    Fragment an image into smaller patches with a specified fragment size and overlap.

    This function handles different scenarios based on image size, fragment size, and overlap,
    and creates fragments from the input image accordingly. It also supports shifting the outer
    border of fragments to ensure full coverage of the image.

    Args:
        img (np.ndarray or torch.Tensor): The input image to be fragmented (height, width, channels).
        fragment_size (int or list): The size of the fragments. Can be a single integer (square) or a list of two integers (non-square).
        target_overlap (float): The target overlap between adjacent fragments, in pixels.
        border_shift (bool): Whether to shift the border of fragments to ensure full coverage of the image. Default is True.
        return_indices (bool): If True, the function will also return the indices (offsets) for each fragment. Default is False.
        verbose (bool): If True, the function will print additional details about the overlap. Default is False.

    Returns:
        torch.Tensor or tuple:
            - If `return_indices` is False, a tensor containing the image fragments.
            - If `return_indices` is True, a tuple of the image fragments and their offsets.
    """

    h,w,c=img.shape

    assert h==w # SQUARE IMAGES SUPPORT ONLY

    # Fragment height/width and the requested overlaps (in pixels).
    hf, wf = fragment_size, fragment_size
    ho, wo = target_overlap*hf, target_overlap*wf

    assert h >= hf and w >= wf # reject Scenario 1 (fragment larger than image)

    # Scenario 2: the fragment covers the whole image — return a single 1x1 grid.
    if h == hf or w == wf:
        if not torch.is_tensor(img):
            img=torch.from_numpy(img).permute(2,0,1)
        return img.view(1,1,c,h,w)

    # Scenario 3 & 4: the image must be tiled.

    # determine number of segments between the centers of outermost fragments
    h_n = max(1, int(np.round((h-hf)/(hf-ho))))
    w_n = max(1, int(np.round((w-wf)/(wf-wo))))

    # adjust practical overlap (divide the distance between the centers of outermost fragments by the true number of segments)
    aho = int(np.ceil(hf-(h-hf)/(h_n)))
    awo = int(np.ceil(wf-(w-wf)/(w_n)))

    # compute fragments (might not exactly fill the outermost border)
    topleft = fragment_unfold(img.permute(2,0,1),fragment_size=(hf,wf), overlap=(aho,awo)).view(1+h_n, 1+w_n, c, hf, wf)

    full = topleft

    if border_shift:

        # Only needed when the regular grid does not reach the image border.
        if h > hf+h_n*(hf-aho) or w > wf+w_n*(wf-awo):
            #print('Outers...')
            # Re-extract the last row/column anchored to the image border.
            bottomleft = fragment_unfold(img[-hf:,:,:],fragment_size=(hf,wf), overlap=(aho,awo)).view(1,1+w_n,c,hf,wf)
            topright = fragment_unfold(img[:,-wf:,:],fragment_size=(hf,wf), overlap=(aho,awo)).view(1+h_n,1,c,hf,wf)

            # Shift last row and col to the border of the original
            full[:,-1,None] = topright
            full[-1] = bottomleft

    if verbose:
        print('Target Overlap: {} pixels. Feasible Overlap: {} pixels.'.format(ho,aho))

    if not return_indices:
        return full
    else:
        # Per-fragment pixel offsets; initialised to -1 as a sentinel.
        # NOTE(review): offset[...,0] appears to carry the row offset and
        # offset[...,1] the column offset — confirm against consumers.
        offset=-1*torch.ones(*full.shape[:2],2)
        for ridx in range(full.shape[0]):
            for cidx in range(full.shape[1]):
                offset[ridx,cidx,1] = cidx * (hf-aho)
                offset[ridx,cidx,0] = ridx * (wf-awo)

                if border_shift:
                    # Border fragments were anchored to the image edge above,
                    # so their offsets are pinned to it as well.
                    offset[ridx,-1,1] = h-hf
                    offset[-1,cidx,0] = w-wf

        return full,offset
|
MajorTOM/embedder/models/DINOv2_S2RGB.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 3 |
+
|
| 4 |
+
class DINOv2_S2RGB_Embedder(torch.nn.Module):
    """
    Embedding wrapper around DINOv2 for Sentinel-2 RGB (true-colour) input.

    Raw Sentinel-2 reflectances are rescaled to [0, 1] true-colour values
    (divide by 10,000, multiply by 2.5, clip), preprocessed with the DINOv2
    image processor, run through the pre-trained 'facebook/dinov2-base'
    backbone, and the resulting token sequence is mean-pooled into a single
    fixed-size embedding per image.

    Attributes:
        processor (AutoImageProcessor): Preprocessing pipeline for the model.
        model (AutoModel): Pre-trained DINOv2 transformer backbone.
        bands (list): Sentinel-2 bands consumed, in order (B04, B03, B02).
        size (tuple): Expected input size (height, width).
    """

    def __init__(self):
        """
        Load the pre-trained DINOv2 processor and model from Hugging Face and
        record the Sentinel-2 RGB band ordering plus the expected crop size.
        """
        super().__init__()

        # Pre-trained DINOv2 backbone and its matching preprocessing pipeline
        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.model = AutoModel.from_pretrained('facebook/dinov2-base')

        # Sentinel-2 true-colour bands: B04 (Red), B03 (Green), B02 (Blue)
        self.bands = ['B04', 'B03', 'B02']

        # Input resolution dictated by the processor's crop settings
        self.size = self.processor.crop_size['height'], self.processor.crop_size['width']


    def normalize(self, input):
        """
        Map raw Sentinel-2 reflectance to clipped true-colour values.

        Reflectances are scaled to [0, 1] (divide by 10,000) then brightened
        by a factor of 2.5 and clipped, which is the conventional Sentinel-2
        true-colour rendering.

        Args:
            input (torch.Tensor): Raw Sentinel-2 image tensor.

        Returns:
            torch.Tensor: True-colour image with values in [0, 1].
        """
        return (2.5 * (input / 1e4)).clip(0,1)

    def forward(self, input):
        """
        Embed a batch of Sentinel-2 RGB images.

        The input is normalised, run through the DINOv2 processor and model,
        and the last hidden state is mean-pooled over the token dimension.

        Args:
            input (torch.Tensor): Raw Sentinel-2 image tensor with shape [C, H, W], C=3.

        Returns:
            torch.Tensor: Embedding vector(s) on CPU, averaged over the sequence dimension.
        """
        batch = self.processor(self.normalize(input), return_tensors="pt")
        hidden = self.model(batch['pixel_values'].to(self.model.device)).last_hidden_state
        return hidden.mean(dim=1).cpu()
|
MajorTOM/embedder/models/SSL4EO_S1RTC.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchgeo.models import ResNet50_Weights
|
| 3 |
+
import timm
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
class SSL4EO_S1RTC_Embedder(torch.nn.Module):
    """
    SSL4EO embedder for Sentinel-1 RTC data using a pre-trained ResNet50.

    Based on the SSL4EO (Self-Supervised Learning for Earth Observation)
    approach: a ResNet50 backbone pre-trained on Sentinel-1 SAR data is used
    directly as a frozen feature extractor.

    Project Code:
        https://github.com/zhu-xlab/SSL4EO-S12

    Publication:
        https://arxiv.org/abs/2211.07044
    """

    def __init__(self, s1_mean=(-12.54847273, -20.19237134), s1_std=(5.25697717, 5.91150917)):
        """
        Initializes the embedder with normalization statistics and the
        pre-trained backbone.

        Args:
            s1_mean (sequence, optional): Per-channel (VV, VH) mean values in dB.
                Defaults to SSL4EO's values. (Defaults are tuples rather than
                lists to avoid the mutable-default-argument pitfall.)
            s1_std (sequence, optional): Per-channel (VV, VH) standard deviations
                in dB. Defaults to SSL4EO's values.

        Attributes:
            s1_mean (torch.FloatTensor): Mean values used for normalization.
            s1_std (torch.FloatTensor): Standard deviations used for normalization.
            model (torch.nn.Module): ResNet50 initialized with pre-trained weights.
            bands (list): Sentinel-1 polarizations expected as input (VV, VH).
            size (tuple): Spatial input size expected by the model (224x224).
        """
        super().__init__()

        # Normalization statistics (dB scale) from SSL4EO
        self.s1_mean = torch.FloatTensor(s1_mean)
        self.s1_std = torch.FloatTensor(s1_std)

        # load model
        self.model = self.init_model()
        self.bands = ['vv', 'vh']
        self.size = 224, 224

    def init_model(self):
        """
        Initializes the ResNet50 model with pre-trained Sentinel-1 weights.

        Loads `ResNet50_Weights.SENTINEL1_ALL_MOCO` (provided by torchgeo)
        into a `timm` ResNet50 and replaces the fully connected head with an
        identity so the forward pass emits pooled embeddings directly.

        Returns:
            torch.nn.Module: The initialized ResNet50 feature extractor.
        """
        weights = ResNet50_Weights.SENTINEL1_ALL_MOCO
        model = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
        model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
        # Identity head -> embeddings come straight from the pooling layer
        model.fc = torch.nn.Identity()

        return model

    def normalize(self, img, scale=1.0):
        """
        Normalizes dB-scaled Sentinel-1 SAR data to [0, scale].

        The per-channel range [mean - 2*std, mean + 2*std] is mapped linearly
        onto [0, scale] and the result is clipped.

        Args:
            img (torch.Tensor): dB-scaled Sentinel-1 image of shape [C, H, W].
            scale (float, optional): Upper bound of the output range. Default 1.0.

        Returns:
            torch.Tensor: The normalized, clipped float image.
        """
        min_value = (self.s1_mean - 2 * self.s1_std).to(img.device)
        max_value = (self.s1_mean + 2 * self.s1_std).to(img.device)
        img = (img - min_value[:, None, None]) / (max_value - min_value)[:, None, None] * scale
        img = img.clip(0, scale).float()

        return img

    def preprocess(self, input):
        """
        Preprocesses linear-power Sentinel-1 data before the model.

        Converts the input from linear scale to decibels (dB) and normalizes
        the result with `normalize`.

        Args:
            input (torch.Tensor): Linear-power Sentinel-1 image of shape [C, H, W].

        Returns:
            torch.Tensor: Preprocessed, normalized image in dB scale.
        """
        # Linear -> dB; clip to prevent log(0).
        # BUGFIX: the original called `input.log10(input.clip(min=1e-10))`,
        # but Tensor.log10 takes no arguments and this raised a TypeError.
        dB_input = 10 * torch.log10(input.clip(min=1e-10))

        # Normalize the dB-scaled image
        return self.normalize(dB_input)

    def forward(self, input):
        """
        Forward pass: preprocess the input and run it through the backbone.

        Args:
            input (torch.Tensor): Linear-power Sentinel-1 image (e.g. [C, H, W]).

        Returns:
            torch.Tensor: The output embedding from the model.
        """
        return self.model(self.preprocess(input))
|
MajorTOM/embedder/models/SSL4EO_S2L1C.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchgeo.models import ResNet50_Weights
|
| 3 |
+
import timm
|
| 4 |
+
|
| 5 |
+
class SSL4EO_S2L1C_Embedder(torch.nn.Module):
    """
    SSL4EO embedder for Sentinel-2 L1C data using a pre-trained ResNet50.

    Wraps the SSL4EO (Self-Supervised Learning for Earth Observation)
    ResNet50 backbone pre-trained on Sentinel-2 imagery and exposes it as a
    feature extractor mapping a 13-band Sentinel-2 L1C image to an embedding.

    Project Code:
        https://github.com/zhu-xlab/SSL4EO-S12

    Publication:
        https://arxiv.org/abs/2211.07044
    """

    def __init__(self):
        """
        Sets up the pre-trained backbone, the list of expected Sentinel-2 L1C
        bands, and the spatial input size of the network (224x224 pixels).

        Attributes:
            model (torch.nn.Module): ResNet50 with pre-trained Sentinel-2 weights.
            bands (list): Sentinel-2 L1C bands expected as input, in order.
            size (tuple): Spatial input size expected by the model.
        """
        super().__init__()

        # Pre-trained SSL4EO ResNet50 backbone
        self.model = self.init_model()

        # Sentinel-2 L1C spectral bands, in the order fed to the model
        self.bands = [
            'B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07',
            'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'
        ]

        # Spatial dimensions expected by the backbone
        self.size = 224, 224

    def init_model(self):
        """
        Builds the ResNet50 feature extractor with Sentinel-2 weights.

        A `timm` ResNet50 is created with the channel count declared by the
        torchgeo weights (`ResNet50_Weights.SENTINEL2_ALL_DINO`), the weights
        are loaded, and the classification head is replaced by an identity so
        that the forward pass returns pooled features.

        Returns:
            torch.nn.Module: The initialized ResNet50 feature extractor.
        """
        weights = ResNet50_Weights.SENTINEL2_ALL_DINO
        backbone = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
        backbone.load_state_dict(weights.get_state_dict(progress=True), strict=False)
        # Identity head -> forward() yields the pooled embedding directly
        backbone.fc = torch.nn.Identity()

        return backbone

    def preprocess(self, input):
        """
        Scales Sentinel-2 digital numbers to reflectance.

        Dividing by 10,000 maps the raw values into the range the
        pre-trained network expects.

        Args:
            input (torch.Tensor): Sentinel-2 image tensor (e.g. shape [C, H, W]).

        Returns:
            torch.Tensor: The input scaled by a factor of 1/10,000.
        """
        return input / 1e4

    def forward(self, input):
        """
        Computes the embedding of a Sentinel-2 image.

        The input is preprocessed and passed through the ResNet50 backbone.

        Args:
            input (torch.Tensor): Sentinel-2 image tensor (e.g. shape [C, H, W]).

        Returns:
            torch.Tensor: The output embedding from the model.
        """
        scaled = self.preprocess(input)
        return self.model(scaled)
|
MajorTOM/embedder/models/SigLIP_S2RGB.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from open_clip import create_model_from_pretrained, get_tokenizer
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class SigLIP_S2RGB_Embedder(torch.nn.Module):
    """
    Embedding wrapper for SigLIP and Sentinel-2 RGB data.

    This model processes Sentinel-2 RGB data and embeds it into a feature
    space using the SigLIP (ViT-SO400M-14-SigLIP-384) image tower loaded via
    open_clip. The preprocessing normalizes Sentinel-2 reflectance values to a
    True-Colour image before passing it through the model. The output is a
    high-dimensional feature vector representing the input image.

    Preprocessing:
    - Sentinel-2 bands are divided by 10,000 to scale the reflectance values.
    - The values are then multiplied by 2.5 and clipped to [0, 1] to form the
      True-Colour image.
    - Only the final (tensor-normalization) transform of the open_clip
      preprocessing pipeline is applied in `forward`.

    Model:
    - Takes an RGB input of 384x384 pixels and produces an embedding vector.
    """

    def __init__(self):
        # Loads the SigLIP model and its preprocessing pipeline from the
        # HuggingFace hub via open_clip.
        super().__init__()

        # load model
        self.model, self.preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
        # Sentinel-2 RGB bands (B04 - Red, B03 - Green, B02 - Blue)
        self.bands = ['B04', 'B03', 'B02']
        # Expected spatial input size, taken from the first (resize) transform
        self.size = self.preprocess.transforms[0].size

    def normalize(self, input):
        """
        Normalizes Sentinel-2 image data to create a True-Colour image.

        This function:
        - Divides the input by 10,000 to scale Sentinel-2 reflectance values.
        - Multiplies the result by 2.5 to map the values into the
          True-Colour image range.
        - Clips the result to [0, 1].

        Args:
            input (torch.Tensor or np.ndarray): Input image with Sentinel-2 reflectance values.

        Returns:
            torch.Tensor: Normalized True-Colour image, clipped to the range [0, 1].
        """
        return (2.5 * (input / 1e4)).clip(0,1)

    def forward(self, input):
        """
        Forward pass through the SigLIP model.

        Normalizes the input Sentinel-2 image to a True-Colour representation
        and encodes it with the SigLIP image tower.

        Args:
            input (torch.Tensor): A Sentinel-2 image, typically of shape (C, H, W), where C=3 (RGB),
                                  H=384, and W=384.
                                  NOTE(review): the input is assumed to already
                                  match the model's spatial size, since the
                                  resize transform is skipped below -- confirm
                                  against callers.

        Returns:
            torch.Tensor: The image embedding produced by the model.
        """
        preprocess_input = self.normalize(input)

        # Apply only the final transform (tensor normalization); the earlier
        # resize/crop transforms of the open_clip pipeline are skipped
        model_input = self.preprocess.transforms[-1](preprocess_input)

        return self.model.encode_image(model_input)
|
MajorTOM/embedder/models/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .SigLIP_S2RGB import *
|
| 2 |
+
from .DINOv2_S2RGB import *
|
| 3 |
+
from .SSL4EO_S2L1C import *
|
| 4 |
+
from .SSL4EO_S1RTC import *
|
MajorTOM/embedder/models/__pycache__/DINOv2_S2RGB.cpython-311.pyc
ADDED
|
Binary file (5.58 kB). View file
|
|
|
MajorTOM/embedder/models/__pycache__/SSL4EO_S1RTC.cpython-311.pyc
ADDED
|
Binary file (7.02 kB). View file
|
|
|
MajorTOM/embedder/models/__pycache__/SSL4EO_S2L1C.cpython-311.pyc
ADDED
|
Binary file (4.75 kB). View file
|
|
|
MajorTOM/embedder/models/__pycache__/SigLIP_S2RGB.cpython-311.pyc
ADDED
|
Binary file (3.72 kB). View file
|
|
|
MajorTOM/embedder/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (308 Bytes). View file
|
|
|
MajorTOM/extras/coverage-example.png
ADDED
|
Git LFS Details
|
MajorTOM/extras/coverage_vis.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from mpl_toolkits.basemap import Basemap
|
| 5 |
+
import PIL
|
| 6 |
+
|
| 7 |
+
def get_mask(df):
    """
    Takes a Major TOM dataframe and renders a binary coverage mask of the
    available cells as a PIL image (255 = cell present, 0 = absent) over the
    2004x4008 global grid.
    """
    grid = np.zeros((2004, 4008), dtype=np.uint8)

    # offsets of the Major TOM grid origin in row/column coordinates
    ROW_OFFSET = -1002
    COL_OFFSET = -2004

    # keep only cells that are not flagged as nodata
    valid = ~(df['nodata'].values > 0.5)

    # grid rows count upwards, image rows count downwards -> flip vertically
    rows = grid.shape[0] - (np.array(df['grid_row_u']) - ROW_OFFSET) - 1
    cols = np.array(df['grid_col_r']) - COL_OFFSET

    grid[rows[valid], cols[valid]] = 255

    return PIL.Image.fromarray(grid)
|
| 27 |
+
|
| 28 |
+
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    import io
    # render the figure into an in-memory PNG buffer and reopen it with PIL
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    return PIL.Image.open(buffer)
|
| 36 |
+
|
| 37 |
+
def _render_basemap(land_color, water_color, line_color):
    """
    Shared renderer for the coverage basemaps.

    Draws a sinusoidal-projection world map with the given land, water and
    contour colours, strips all axes/margins, and returns it as a PIL Image.
    (Extracted because light_basemap/dark_basemap were near-identical copies.)

    Args:
        land_color (str): Fill colour for continents.
        water_color (str): Fill colour for lakes and the map boundary (ocean).
        line_color (str): Colour for country borders and coastlines.

    Returns:
        PIL.Image.Image: The rendered map.
    """
    with plt.ioff():
        fig, ax = plt.subplots(figsize=(48,24), dpi=167)

        m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
        m.fillcontinents(color=land_color, lake_color=water_color)
        m.drawmapboundary(fill_color=water_color)
        m.drawcountries(color=line_color, linewidth=1)
        m.drawcoastlines(color=line_color, linewidth=1)

        # remove axes, padding and margins so only the map itself is rendered
        plt.gca().set_axis_off()
        plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
                            hspace = 0, wspace = 0)
        plt.margins(0,0)

        return fig2img(fig)

def light_basemap():
    """
    Bright coloured contours
    """
    return _render_basemap("#9eba9b", '#CCDDFF', "#666666")

def dark_basemap():
    """
    Dark contours
    """
    return _render_basemap("#242424", '#242424', "#000000")
|
| 78 |
+
|
| 79 |
+
def get_coveragemap(input, input2=None):
    """
    Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation

    Optionally, input2 can be provided and then, the map plots a map with extra colours indicating cells available only in input (green) or only input2 (blue)

    Args:
        input: Major TOM metadata DataFrame (or pre-computed mask image).
        input2: Optional second dataset to overlap against the first.

    Returns:
        PIL.Image.Image: The rendered coverage map.
    """

    if input2 is None:
        return single_coveragemap(input)
    else:
        cmap1 = single_coveragemap(input)
        cmap2 = single_coveragemap(input2)

        # arrays for mixing (drop the alpha channel)
        inp1_arr = np.array(cmap1)[...,:3]
        inp2_arr = np.array(cmap2)[...,:3]

        # pixels where both maps agree (same channel sum) keep map 1's value...
        common_arr = inp1_arr*(inp1_arr.sum(-1) == inp2_arr.sum(-1))[:,:,None]
        # ...reduced to the red channel only
        common_arr[:,:,(1,2)] = 0
        inp1_arr[:,:,(0,2)] = 0 # Green - indicates presence of S2 only
        inp2_arr[:,:,(0,1)] = 0 # Blue - indicates presence of DEM only

        # additive composite of the three exclusive channels
        return PIL.Image.fromarray(((common_arr + inp1_arr + inp2_arr)).astype(np.uint8))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def single_coveragemap(input):
    """
    Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation

    Args:
        input: Major TOM metadata DataFrame, or an already-computed PIL mask.

    Returns:
        PIL.Image.Image: Dark basemap with covered cells shown in light colours.
    """

    # compute mask if df is provided
    if isinstance(input, pd.DataFrame):
        mask = get_mask(input)
    else:
        mask = input

    basemap = light_basemap()
    basemap_d = dark_basemap()

    # pixels outside the sinusoidal projection render as white (red == 255)
    outside_earth = np.array(basemap.convert('RGBA'))[:, :, 0] == 255
    outside_earth = PIL.Image.fromarray(outside_earth)

    # nearest-neighbour keeps the mask binary while matching the basemap size
    mask = mask.resize(basemap.size, PIL.Image.NEAREST)

    # covered cells become opaque, uncovered ones transparent
    basemap.putalpha(mask)

    # Mask outside of earth
    basemap.paste(outside_earth, (0,0), outside_earth)

    # composite the light (covered) layer over the dark background
    basemap_d.paste(basemap, (0,0), basemap)

    return basemap_d
|
| 131 |
+
|
| 132 |
+
if __name__ == '__main__':
    # Example: fetch the metadata of a Major TOM Core dataset from HuggingFace
    DATASET_NAME = 'Major-TOM/Core-S2L2A'
    meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
    df = pd.read_parquet(meta_path)

    # This is how you make a coverage figure!
    coverage_img = get_coveragemap(df)

    coverage_img.save('coverage-example.png', format='PNG')

    # and this is how you can create an overlap figure for 2 datasets!
    DATASET_NAME = 'Major-TOM/Core-DEM'
    meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
    dem_df = pd.read_parquet(meta_path)

    coverage_img = get_coveragemap(df,dem_df)

    coverage_img.save('overlap-coverage-example.png', format='PNG')
|
MajorTOM/extras/extract-sample-from-raw-S2.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
MajorTOM/extras/thumbnail_dem.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
|
| 3 |
+
|
| 4 |
+
Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from rasterio.io import MemoryFile
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import rasterio as rio
|
| 13 |
+
from matplotlib.colors import LightSource
|
| 14 |
+
|
| 15 |
+
def get_grayscale(x):
    """
    Normalized grayscale visualisation

    Shifts the array to start at zero, rescales it to [0, 1] by its maximum,
    and quantises to 8-bit.
    """

    # min-max normalization to [0, 1]
    shifted = x - x.min()
    scaled = shifted / shifted.max()

    # quantise to 8-bit
    return np.uint8(scaled * 255)
|
| 25 |
+
|
| 26 |
+
def get_hillshade(x, azdeg=315, altdeg=45,ve=1):
    """
    Hillshade visualisation for DEM

    Renders the elevation array as an 8-bit shaded-relief image using
    matplotlib's LightSource.

    Args:
        x: 2D numpy array of elevations.
        azdeg: Azimuth of the light source in degrees (default 315, i.e. NW).
        altdeg: Altitude of the light source above the horizon in degrees.
        ve: Vertical exaggeration applied before shading.

    Returns:
        2D numpy uint8 array with values in [0, 255].
    """
    ls = LightSource(azdeg=azdeg, altdeg=altdeg)

    return np.uint8(255*ls.hillshade(x, vert_exag=ve))
|
| 33 |
+
|
| 34 |
+
def dem_thumbnail(dem, dem_NODATA = -32768.0, hillshade=True):
    """
    Renders a DEM as an 8-bit thumbnail.

    Args:
        dem: 2D numpy array of elevations.
        dem_NODATA: NODATA value of the DEM (default is -32768.0).
            NOTE(review): currently unused -- nodata pixels are fed into the
            visualisation unmasked; confirm whether masking was intended.
        hillshade: If True, render a hillshade; otherwise a normalized grayscale.

    Returns a numpy uint8 array with the thumbnail
    """
    if hillshade:
        return get_hillshade(dem)
    else:
        return get_grayscale(dem)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def dem_thumbnail_from_datarow(datarow):
    """
    Takes a datarow directly from one of the data parquet files

    Returns a PIL Image (single-channel grayscale)
    """

    # decode the GeoTIFF bytes stored in the 'DEM' column entirely in memory
    with MemoryFile(datarow['DEM'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            dem=f.read().squeeze()
            dem_NODATA = f.nodata  # NODATA value declared by the raster

    img = dem_thumbnail(dem, dem_NODATA)

    # 'L' = 8-bit single-channel (grayscale) PIL mode
    return Image.fromarray(img,'L')
|
| 61 |
+
|
| 62 |
+
if __name__ == '__main__':
    # Example usage: download one row group from the Core-DEM dataset on
    # HuggingFace and compute its thumbnail.
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-DEM/resolve/main/images/part_01001.parquet"
    with open_parquet_file(url) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1)

    print('[example run] computing the thumbnail...')
    thumbnail = dem_thumbnail_from_datarow(first_row_group)

    thumbnail_fname = 'example_thumbnail.png'
    thumbnail.save(thumbnail_fname, format = 'PNG')
    print('[example run] saved as "{}"'.format(thumbnail_fname))
|
MajorTOM/extras/thumbnail_s1rtc.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
|
| 3 |
+
|
| 4 |
+
Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from rasterio.io import MemoryFile
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
def s1rtc_thumbnail(vv, vh, vv_NODATA = -32768.0, vh_NODATA = -32768.0):
    """
    Takes vv and vh numpy arrays along with the corresponding NODATA values (default is -32768.0)

    Builds a false-colour composite: R = VV (dB), G = normalised VV+VH,
    B = VH (dB); NODATA pixels are rendered black.

    The caller's arrays are left untouched (the original implementation
    mutated them in place when replacing negative values).

    Returns a numpy array (H, W, 3) of dtype uint8 with the thumbnail
    """

    # BUGFIX: work on copies so the caller's arrays are not mutated by the
    # in-place negative-value replacement below
    vv = np.asarray(vv).copy()
    vh = np.asarray(vh).copy()

    # valid data masks
    vv_mask = vv != vv_NODATA
    vh_mask = vh != vh_NODATA

    # remove invalid (negative) values before log op
    vv[vv<0] = vv[vv>=0].min()
    vh[vh<0] = vh[vh>=0].min()

    # apply log op (linear power -> decibels)
    vv_dB = 10*np.log10(vv)
    vh_dB = 10*np.log10(vh)

    # scale to 0-255 using the statistics of valid pixels only
    vv_dB = (vv_dB - vv_dB[vv_mask].min()) / (vv_dB[vv_mask].max() - vv_dB[vv_mask].min()) * 255
    vh_dB = (vh_dB - vh_dB[vh_mask].min()) / (vh_dB[vh_mask].max() - vh_dB[vh_mask].min()) * 255

    # represent nodata as 0
    vv_dB[vv_mask==0] = 0
    vh_dB[vh_mask==0] = 0

    # false colour composite
    return np.stack([vv_dB,
                     255*(vv_dB+vh_dB)/np.max(vv_dB+vh_dB),
                     vh_dB
                     ],-1).astype(np.uint8)
|
| 43 |
+
|
| 44 |
+
def s1rtc_thumbnail_from_datarow(datarow):
    """
    Takes a datarow directly from one of the data parquet files

    Returns a PIL Image
    """

    # decode the in-memory GeoTIFF bytes of the VV polarisation
    with MemoryFile(datarow['vv'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            vv=f.read().squeeze()
            vv_NODATA = f.nodata  # NODATA value declared by the raster

    # decode the in-memory GeoTIFF bytes of the VH polarisation
    with MemoryFile(datarow['vh'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            vh=f.read().squeeze()
            vh_NODATA = f.nodata

    img = s1rtc_thumbnail(vv, vh, vv_NODATA=vv_NODATA, vh_NODATA=vh_NODATA)

    return Image.fromarray(img)
|
| 64 |
+
|
| 65 |
+
if __name__ == '__main__':
    # Example usage: download one row group from the Core-S1RTC dataset on
    # HuggingFace and compute its thumbnail.
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-S1RTC/resolve/main/images/part_00001.parquet"
    with open_parquet_file(url) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1)

    print('[example run] computing the thumbnail...')
    thumbnail = s1rtc_thumbnail_from_datarow(first_row_group)

    thumbnail_fname = 'example_thumbnail.png'
    thumbnail.save(thumbnail_fname, format = 'PNG')
    print('[example run] saved as "{}"'.format(thumbnail_fname))
|
MajorTOM/extras/thumbnail_s2.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NOTE: Major TOM standard does not require any specific type of thumbnail to be computed.
|
| 3 |
+
|
| 4 |
+
Instead these are shared as optional help since this is how the Core dataset thumbnails have been computed.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from rasterio.io import MemoryFile
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
def s2l2a_thumbnail(B04, B03, B02, gain=1.3, gamma=0.6):
    """
    Takes B04 (red), B03 (green) and B02 (blue) numpy arrays and renders an
    8-bit RGB thumbnail.

    Reflectance values are scaled by 1/10,000, then gain and gamma correction
    are applied before quantising.

    Returns a numpy array (H, W, 3) of dtype uint8 with the thumbnail
    """

    # stack the bands into an RGB cube
    rgb = np.stack([B04, B03, B02], axis=-1)

    # reflectance scaling followed by gamma curve and gain
    rgb = gain * ((rgb / 10_000) ** gamma)

    # clip to [0, 1] and quantise to 8-bit
    return (rgb.clip(0, 1) * 255).astype(np.uint8)
|
| 25 |
+
|
| 26 |
+
def s2l2a_thumbnail_from_datarow(datarow):
    """
    Takes a datarow directly from one of the data parquet files

    Returns a PIL Image
    """

    # red -- decode the GeoTIFF bytes stored in the 'B04' column in memory
    with MemoryFile(datarow['B04'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            B04=f.read().squeeze()
            B04_NODATA = f.nodata  # NOTE(review): NODATA values are read but unused downstream

    # green
    with MemoryFile(datarow['B03'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            B03=f.read().squeeze()
            B03_NODATA = f.nodata

    # blue
    with MemoryFile(datarow['B02'][0].as_py()) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            B02=f.read().squeeze()
            B02_NODATA = f.nodata

    img = s2l2a_thumbnail(B04,B03,B02)

    return Image.fromarray(img)
|
| 54 |
+
|
| 55 |
+
if __name__ == '__main__':
    # Example usage: download the RGB bands of one row group from the
    # Core-S2L2A dataset on HuggingFace and compute its thumbnail.
    from fsspec.parquet import open_parquet_file
    import pyarrow.parquet as pq

    print('[example run] reading file from HuggingFace...')
    url = "https://huggingface.co/datasets/Major-TOM/Core-S2L2A/resolve/main/images/part_01000.parquet"
    # restrict the remote read to the three RGB columns
    with open_parquet_file(url, columns = ["B04", "B03", "B02"]) as f:
        with pq.ParquetFile(f) as pf:
            first_row_group = pf.read_row_group(1, columns = ["B04", "B03", "B02"])

    print('[example run] computing the thumbnail...')
    thumbnail = s2l2a_thumbnail_from_datarow(first_row_group)

    thumbnail.save('example_thumbnail.png', format = 'PNG')
|
MajorTOM/grid.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import math
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import geopandas as gpd
|
| 5 |
+
from shapely.geometry import LineString, Polygon
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Grid():
    """Global grid of points spaced approximately `dist` km apart.

    Rows are latitude bands labelled '0U','1U',... going north from the
    equator and '1D','2D',... going south; columns are labelled
    '0R','1R',... east of the Greenwich meridian and '1L','2L',... west of
    it. Each grid point is treated as the bottom-left corner of its cell.
    """

    # Earth's equatorial radius in km; used for all arc-length computations.
    RADIUS_EQUATOR = 6378.137 # km

    def __init__(self,dist,latitude_range=(-85,85),longitude_range=(-180,180),utm_definition='bottomleft'):
        """
        Args:
            dist: target spacing between neighbouring grid points, in km.
            latitude_range: inclusive (min, max) latitudes kept in the grid.
            longitude_range: inclusive (min, max) longitudes kept in the grid.
            utm_definition: 'bottomleft' derives each cell's UTM zone from its
                bottom-left corner; 'center' derives it from the cell centre.
        """
        self.dist = dist
        self.latitude_range = latitude_range
        self.longitude_range = longitude_range
        self.utm_definition = utm_definition
        self.rows,self.lats = self.get_rows()
        self.points, self.points_by_row = self.get_points()

    def get_rows(self):
        """Return (row_labels, latitudes) for all latitude bands within range."""

        # Define set of latitudes to use, based on the grid distance
        arc_pole_to_pole = math.pi * self.RADIUS_EQUATOR
        num_divisions_in_hemisphere = math.ceil(arc_pole_to_pole / self.dist)

        latitudes = np.linspace(-90, 90, num_divisions_in_hemisphere+1)[:-1]
        latitudes = np.mod(latitudes, 180) - 90

        # order should be from south to north
        latitudes = np.sort(latitudes)

        # index of the first non-negative latitude (the '0U' row)
        zeroth_row = np.searchsorted(latitudes,0)

        # From 0U-NU and 1D-ND
        rows = [None] * len(latitudes)
        rows[zeroth_row:] = [f'{i}U' for i in range(len(latitudes)-zeroth_row)]
        rows[:zeroth_row] = [f'{abs(i-zeroth_row)}D' for i in range(zeroth_row)]

        # bound to range
        idxs = (latitudes>=self.latitude_range[0]) * (latitudes<=self.latitude_range[1])
        rows,latitudes = np.array(rows), np.array(latitudes)
        rows,latitudes = rows[idxs],latitudes[idxs]

        return rows,latitudes

    def get_circumference_at_latitude(self,lat):
        """Return the circumference (km) of Earth's cross-section at `lat` degrees."""

        # Circumference of the cross-section of a sphere at a given latitude

        radius_at_lat = self.RADIUS_EQUATOR * math.cos(lat * math.pi / 180)
        circumference = 2 * math.pi * radius_at_lat

        return circumference

    def subdivide_circumference(self,lat,return_cols=False):
        """Return longitudes subdividing the parallel at `lat` into ~`dist`-km arcs.

        When return_cols is True, also return the matching column labels
        ('0R','1R',... eastwards; '1L','2L',... westwards).
        """
        # Provide a list of longitudes that subdivide the circumference of the earth at a given latitude
        # into equal parts as close as possible to dist

        circumference = self.get_circumference_at_latitude(lat)
        num_divisions = math.ceil(circumference / self.dist)
        longitudes = np.linspace(-180,180, num_divisions+1)[:-1]
        longitudes = np.mod(longitudes, 360) - 180
        longitudes = np.sort(longitudes)


        if return_cols:
            cols = [None] * len(longitudes)
            # NOTE(review): assumes 0.0 is always among the longitudes
            # (true because linspace(-180, 180, ...) includes -180 and the
            # spacing divides the range evenly) — verify if the construction changes.
            zeroth_idx = np.where(longitudes==0)[0][0]
            cols[zeroth_idx:] = [f'{i}R' for i in range(len(longitudes)-zeroth_idx)]
            cols[:zeroth_idx] = [f'{abs(i-zeroth_idx)}L' for i in range(zeroth_idx)]
            return np.array(cols),np.array(longitudes)

        return np.array(longitudes)

    def get_points(self):
        """Build all grid points.

        Returns:
            points: one GeoDataFrame with every grid point (name, row/col
                labels and indices, UTM zone, EPSG code, point geometry).
            points_by_row: list of per-row GeoDataFrames in south-to-north order.
        """

        r_idx = 0
        points_by_row = [None]*len(self.rows)
        for r,lat in zip(self.rows,self.lats):
            point_names,grid_row_names,grid_col_names,grid_row_idx,grid_col_idx,grid_lats,grid_lons,utm_zones,epsgs = [],[],[],[],[],[],[],[],[]
            cols,lons = self.subdivide_circumference(lat,return_cols=True)

            cols,lons = self.filter_longitude(cols,lons)
            c_idx = 0
            for c,lon in zip(cols,lons):
                point_names.append(f'{r}_{c}')
                grid_row_names.append(r)
                grid_col_names.append(c)
                grid_row_idx.append(r_idx)
                grid_col_idx.append(c_idx)
                grid_lats.append(lat)
                grid_lons.append(lon)
                if self.utm_definition == 'bottomleft':
                    utm_zones.append(get_utm_zone_from_latlng([lat,lon]))
                elif self.utm_definition == 'center':
                    # Offset by half a cell; 111,120 m is ~1 degree of latitude.
                    center_lat = lat + (1000*self.dist/2)/111_120
                    center_lon = lon + (1000*self.dist/2)/(111_120*math.cos(center_lat*math.pi/180))
                    utm_zones.append(get_utm_zone_from_latlng([center_lat,center_lon]))
                else:
                    raise ValueError(f'Invalid utm_definition {self.utm_definition}')
                epsgs.append(f'EPSG:{utm_zones[-1]}')

                c_idx += 1
            points_by_row[r_idx] = gpd.GeoDataFrame({
                'name':point_names,
                'row':grid_row_names,
                'col':grid_col_names,
                'row_idx':grid_row_idx,
                'col_idx':grid_col_idx,
                'utm_zone':utm_zones,
                'epsg':epsgs
            },geometry=gpd.points_from_xy(grid_lons,grid_lats))
            r_idx += 1
        points = gpd.GeoDataFrame(pd.concat(points_by_row))
        # points.reset_index(inplace=True,drop=True)
        return points, points_by_row

    def group_points_by_row(self):
        """Re-derive the per-row list of GeoDataFrames from self.points."""
        # Make list of different gdfs for each row
        points_by_row = [None]*len(self.rows)
        for i,row in enumerate(self.rows):
            points_by_row[i] = self.points[self.points.row==row]
        return points_by_row

    def filter_longitude(self,cols,lons):
        """Keep only (col, lon) pairs whose longitude lies inside longitude_range."""
        idxs = (lons>=self.longitude_range[0]) * (lons<=self.longitude_range[1])
        cols,lons = cols[idxs],lons[idxs]
        return cols,lons

    def latlon2rowcol(self,lats,lons,return_idx=False,integer=False):
        """
        Convert latitude and longitude to row and column number from the grid

        Args:
            lats, lons: sequences of coordinates to map onto grid cells.
            return_idx: also append the self.points table indices to the output.
            integer: convert labels to signed ints ('U'/'R' positive, 'D'/'L' negative).

        Returns:
            [rows, cols] (plus [idx] when return_idx is True).
        """
        # Always take bottom left corner of grid cell
        rows = np.searchsorted(self.lats,lats)-1

        # Get the possible points of the grid cells at the given latitude
        possible_points = [self.points_by_row[row] for row in rows]

        # For each point, find the rightmost point that is still to the left of the given longitude
        cols = [poss_points.iloc[np.searchsorted(poss_points.geometry.x,lon)-1].col for poss_points,lon in zip(possible_points,lons)]
        rows = self.rows[rows].tolist()

        outputs = [rows, cols]
        if return_idx:
            # Get the table index for self.points with each row,col pair in rows, cols
            idx = [self.points[(self.points.row==row) & (self.points.col==col)].index.values[0] for row,col in zip(rows,cols)]
            outputs.append(idx)

        # return raw numbers
        if integer:
            outputs[0] = [int(el[:-1]) if el[-1] == 'U' else -int(el[:-1]) for el in outputs[0]]
            outputs[1] = [int(el[:-1]) if el[-1] == 'R' else -int(el[:-1]) for el in outputs[1]]

        return outputs

    def rowcol2latlon(self,rows,cols):
        """Map (row, col) label pairs back to the (lat, lon) of each grid point."""
        point_geoms = [self.points.loc[(self.points.row==row) & (self.points.col==col),'geometry'].values[0] for row,col in zip(rows,cols)]
        lats = [point.y for point in point_geoms]
        lons = [point.x for point in point_geoms]
        return lats,lons

    def get_bounded_footprint(self,point,buffer_ratio=0):
        """Return the (optionally buffered) bounding Polygon of a grid point's cell."""
        # Gets the polygon footprint of the grid cell for a given point, bounded by the other grid points' cells.
        # Grid point defined as bottom-left corner of polygon. Buffer ratio is the ratio of the grid cell's width/height to buffer by.

        bottom,left = point.geometry.y,point.geometry.x
        row_idx = point.row_idx
        col_idx = point.col_idx
        next_row_idx = row_idx+1
        next_col_idx = col_idx+1

        if next_row_idx >= len(self.lats): # If at top row, use difference between top and second-to-top row for height
            height = (self.lats[row_idx] - self.lats[row_idx-1])
            top = self.lats[row_idx] + height
        else:
            top = self.lats[next_row_idx]

        max_col = len(self.points_by_row[row_idx].col_idx)-1
        if next_col_idx > max_col: # If at rightmost column, use difference between rightmost and second-to-rightmost column for width
            width = (self.points_by_row[row_idx].iloc[col_idx].geometry.x - self.points_by_row[row_idx].iloc[col_idx-1].geometry.x)
            right = self.points_by_row[row_idx].iloc[col_idx].geometry.x + width
        else:
            right = self.points_by_row[row_idx].iloc[next_col_idx].geometry.x

        # Buffer the polygon by the ratio of the grid cell's width/height
        width = right - left
        height = top - bottom

        buffer_horizontal = width * buffer_ratio
        buffer_vertical = height * buffer_ratio

        new_left = left - buffer_horizontal
        new_right = right + buffer_horizontal

        new_bottom = bottom - buffer_vertical
        new_top = top + buffer_vertical

        bbox = Polygon([(new_left,new_bottom),(new_left,new_top),(new_right,new_top),(new_right,new_bottom)])

        return bbox
|
| 205 |
+
|
| 206 |
+
def get_utm_zone_from_latlng(latlng):
    """
    Get the UTM zone from a latlng list and return the corresponding EPSG code.

    Parameters
    ----------
    latlng : List[Union[int, float]]
        The latlng list to get the UTM zone from.

    Returns
    -------
    str
        The EPSG code for the UTM zone.
    """
    assert isinstance(latlng, (list, tuple)), "latlng must be in the form of a list or tuple."

    latitude, longitude = latlng[0], latlng[1]

    # Standard 6-degree bands, wrapped into 1..60.
    zone_number = (math.floor((longitude + 180) / 6)) % 60 + 1

    # Special zones for Svalbard and Norway
    if 56.0 <= latitude < 64.0 and 3.0 <= longitude < 12.0:
        zone_number = 32
    elif 72.0 <= latitude < 84.0:
        if 0.0 <= longitude < 9.0:
            zone_number = 31
        elif 9.0 <= longitude < 21.0:
            zone_number = 33
        elif 21.0 <= longitude < 33.0:
            zone_number = 35
        elif 33.0 <= longitude < 42.0:
            zone_number = 37

    # Determine the hemisphere and construct the EPSG code
    # (southern hemisphere -> 327xx, northern -> 326xx).
    hemisphere_prefix = "327" if latitude < 0 else "326"
    epsg_code = f"{hemisphere_prefix}{zone_number:02d}"

    # Sanity check: the resulting zone must be 01..60.
    if not re.match(r"32[6-7](0[1-9]|[1-5][0-9]|60)", epsg_code):
        print(f"latlng: {latlng}, epsg_code: {epsg_code}")
        raise ValueError(f"out of bound latlng resulted in incorrect EPSG code for the point")

    return epsg_code
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
if __name__ == '__main__':

    # EPSG sanity checks covering both hemispheres and the
    # Norway/Svalbard special zones.
    assert get_utm_zone_from_latlng([-1,-174.34]) == "32701"
    assert get_utm_zone_from_latlng([48,-4]) == "32630"
    assert get_utm_zone_from_latlng([78,13]) == "32633"
    assert get_utm_zone_from_latlng([-34,19.7]) == "32734"
    assert get_utm_zone_from_latlng([-36,175.7]) == "32760"


    # Round-trip smoke test: map random coordinates onto the grid and back.
    dist = 100
    grid = Grid(dist)

    np.random.seed(0)
    test_lons = np.random.uniform(-20,20,size=(1000)) % 180 # Checks edge-case of crossing 180th meridian
    test_lats = np.random.uniform(-20,68,size=(1000))

    test_rows,test_cols = grid.latlon2rowcol(test_lats,test_lons)
    test_lats2,test_lons2 = grid.rowcol2latlon(test_rows,test_cols)

    print(test_lons[:10])
    print(test_lats[:10])
    print(test_rows[:10])
    print(test_cols[:10])

    # Make line segments from the points to their corresponding grid points
    lines = []
    for i in range(len(test_lats)):
        lines.append([(test_lons[i],test_lats[i]),(test_lons2[i],test_lats2[i])])

    lines = gpd.GeoDataFrame(geometry=gpd.GeoSeries([LineString(line) for line in lines]))

    # Write both the test segments and the full grid out for visual inspection.
    lines.to_file(f'testlines_{dist}km.geojson',driver='GeoJSON')
    grid.points.to_file(f'testgrid_{dist}km.geojson',driver='GeoJSON')
|
MajorTOM/metadata_helpers.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pyarrow.parquet as pq
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import geopandas as gpd
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import urllib.request
|
| 6 |
+
import fsspec
|
| 7 |
+
from fsspec.parquet import open_parquet_file
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from rasterio.io import MemoryFile
|
| 11 |
+
from tqdm.notebook import tqdm
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
from .sample_helpers import *
|
| 15 |
+
|
| 16 |
+
def metadata_from_url(access_url, local_url):
    """Download a Major-TOM metadata parquet and load it as a GeoDataFrame.

    Args:
        access_url (str): remote URL of the metadata parquet file.
        local_url (str): local path the file is downloaded to.

    Returns:
        geopandas.GeoDataFrame: metadata with parsed timestamps and point
        geometry built from the centre_lon/centre_lat columns; the CRS is
        taken from the table's own 'crs' column.
    """
    local_url, response = urllib.request.urlretrieve(access_url, local_url)
    df = pq.read_table(local_url).to_pandas()
    # Parse string timestamps so date-range filtering works downstream.
    df['timestamp'] = pd.to_datetime(df.timestamp)
    gdf = gpd.GeoDataFrame(
        df, geometry=gpd.points_from_xy(df.centre_lon, df.centre_lat), crs=df.crs.iloc[0]
    )
    return gdf
|
| 24 |
+
|
| 25 |
+
def filter_metadata(df,
                    region=None,
                    daterange=None,
                    cloud_cover=(0,100),
                    nodata=(0, 1.0)
                    ):
    """Filters the Major-TOM dataframe based on several parameters

    Args:
        df (geopandas dataframe): Parent dataframe
        region (shapely geometry object) : Region of interest
        daterange (tuple) : Inclusive range of dates (example format: '2020-01-01')
        cloud_cover (tuple) : Inclusive percentage range (0-100) of cloud cover
        nodata (tuple) : Inclusive fraction (0.0-1.0) of no data allowed in a sample

    Returns:
        df: a filtered dataframe
    """
    # temporal filtering: keep rows whose timestamp falls inside the range
    if daterange is not None:
        assert (isinstance(daterange, list) or isinstance(daterange, tuple)) and len(daterange)==2
        in_range = (df.timestamp >= daterange[0]) & (df.timestamp <= daterange[1])
        df = df[in_range]

    # spatial filtering via the spatial index
    if region is not None:
        df = df.take(df.sindex.query(region))

    # cloud filtering (inclusive bounds)
    if cloud_cover is not None:
        lo, hi = cloud_cover
        df = df[(df.cloud_cover >= lo) & (df.cloud_cover <= hi)]

    # nodata filtering (inclusive bounds)
    if nodata is not None:
        lo, hi = nodata
        df = df[(df.nodata >= lo) & (df.nodata <= hi)]

    return df
|
| 64 |
+
|
| 65 |
+
def read_row(row, columns=["thumbnail"]):
    """Reads a row from a Major-TOM dataframe

    Args:
        row (row from geopandas dataframe): The row of metadata
        columns (list): columns to be read from the file (never mutated,
            so the shared default list is safe)

    Returns:
        data (dict): dictionary with returned data from requested columns
        (or a PIL Image directly when only the thumbnail is requested)
    """
    # Fetch just the requested columns of the single relevant row group over HTTP.
    with open_parquet_file(row.parquet_url, columns=columns, footer_sample_size=2000000) as f:
        with pq.ParquetFile(f) as pf:
            row_group = pf.read_row_group(row.parquet_row, columns=columns)

    # Thumbnail-only requests return the decoded image directly.
    if columns == ["thumbnail"]:
        stream = BytesIO(row_group['thumbnail'][0].as_py())
        return Image.open(stream)

    row_output = {}
    for col in columns:
        # Renamed from `bytes` — the original shadowed the builtin.
        payload = row_group[col][0].as_py()

        if col != 'thumbnail':
            # Band columns hold GeoTIFF bytes.
            row_output[col] = read_tif_bytes(payload)
        else:
            row_output[col] = Image.open(BytesIO(payload))

    return row_output
|
| 94 |
+
|
| 95 |
+
def filter_download(df, local_dir, source_name, by_row = False, verbose = False, tif_columns=None):
    """Downloads and unpacks the data of Major-TOM based on a metadata dataframe

    Args:
        df (geopandas dataframe): Metadata dataframe
        local_dir (str or Path) : Path to the where the data is to be stored locally
        source_name (str) : Name alias of the resulting dataset
        by_row (bool): If True, it will access individual rows of parquet via http - otherwise entire parquets are downloaded temporarily
        verbose (bool) : option for potential internal state printing
        tif_columns (list of str) : Optionally specified columns to be downloaded as .tifs, e.g. ['B04', 'B03', 'B02']

    Returns:
        None
    """
    if isinstance(local_dir, str):
        local_dir = Path(local_dir)

    temp_file = local_dir / 'temp.parquet'

    # identify all parquets that need to be downloaded (group them)
    urls = df.parquet_url.unique()
    if verbose:
        print('Starting download of {} parquet files.'.format(len(urls)))

    for url in tqdm(urls, desc='Downloading and unpacking...', disable=not verbose):
        # identify all relevant rows
        rows = df[df.parquet_url == url].parquet_row.unique()

        remote_file = None
        if not by_row: # (downloads entire parquet)
            # download a temporary file
            temp_path, http_resp = urllib.request.urlretrieve(url, temp_file)
        else:
            remote_file = fsspec.open(url)
            temp_path = remote_file.open()

        try:
            # populate the bands
            with pq.ParquetFile(temp_path) as pf:
                for row_idx in rows:
                    table = pf.read_row_group(row_idx)

                    product_id = table['product_id'][0].as_py()
                    grid_cell = table['grid_cell'][0].as_py()
                    row = grid_cell.split('_')[0]

                    dest = local_dir / Path("{}/{}/{}/{}".format(source_name, row, grid_cell, product_id))
                    dest.mkdir(exist_ok=True, parents=True)

                    # default: every 'B*' band column plus the cloud mask
                    columns = [col for col in table.column_names if col[0] == 'B'] + ['cloud_mask'] if tif_columns is None else tif_columns
                    # tifs
                    # NOTE: the output handle is named out_f (the original reused
                    # `f`, which clobbered the by_row remote handle so the final
                    # close targeted the wrong object and leaked the fsspec file).
                    for col in columns:
                        with open(dest / "{}.tif".format(col), "wb") as out_f:
                            # Write bytes to file
                            out_f.write(table[col][0].as_py())

                    # thumbnail (png)
                    col = 'thumbnail'
                    with open(dest / "{}.png".format(col), "wb") as out_f:
                        # Write bytes to file
                        out_f.write(table[col][0].as_py())
        finally:
            if not by_row:
                # remove downloaded file
                os.remove(temp_path)
            else:
                temp_path.close()
                remote_file.close()
|
MajorTOM/sample_helpers.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rasterio.io import MemoryFile
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
|
| 7 |
+
def plot(sample, bands = ['B04', 'B03', 'B02'], scaling=2e3):
    """Display an RGB composite of the requested bands with matplotlib.

    Args:
        sample: mapping of band name -> GeoTIFF bytes (as stored in the parquet).
        bands: band names to stack as the image channels (default RGB order).
        scaling: divisor applied to raw reflectance values for display.
    """
    img = [read_tif_bytes(sample[b]) for b in bands]
    # Fix: the original divided by a hard-coded 2e3, silently ignoring
    # the `scaling` parameter. The default preserves the old behavior.
    plt.imshow(np.stack(img, -1)/scaling)
|
| 12 |
+
|
| 13 |
+
def read_tif_bytes(tif_bytes):
    """Decode in-memory GeoTIFF bytes into a numpy array (singleton dims squeezed)."""
    with MemoryFile(tif_bytes) as mem_f:
        with mem_f.open(driver='GTiff') as f:
            return f.read().squeeze()
|
| 17 |
+
|
| 18 |
+
def read_png_bytes(png_bytes):
    """Open raw PNG bytes as a PIL Image via an in-memory stream."""
    return Image.open(BytesIO(png_bytes))
|
app.py
ADDED
|
@@ -0,0 +1,799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import time
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
import zipfile
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 10 |
+
|
| 11 |
+
# Import custom modules
|
| 12 |
+
from models.siglip_model import SigLIPModel
|
| 13 |
+
from models.satclip_model import SatCLIPModel
|
| 14 |
+
from models.farslip_model import FarSLIPModel
|
| 15 |
+
from models.dinov2_model import DINOv2Model
|
| 16 |
+
from models.load_config import load_and_process_config
|
| 17 |
+
from visualize import format_results_for_gallery, plot_top5_overview, plot_location_distribution, plot_global_map_static, plot_geographic_distribution
|
| 18 |
+
from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image
|
| 19 |
+
from PIL import Image as PILImage
|
| 20 |
+
from PIL import ImageDraw, ImageFont
|
| 21 |
+
|
| 22 |
+
# Configuration
|
| 23 |
+
# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")

# Load and process configuration
config = load_and_process_config()
print(config)

# Initialize Models
print("Initializing models...")
models = {}

# One spec per model: (display name, constructor, config section, extra config
# keys beyond the common ckpt_path/embedding_path). Replaces four copy-pasted
# try/except blocks with a single data-driven loop; behavior and messages are
# identical.
_MODEL_SPECS = [
    ('DINOv2', DINOv2Model, 'dinov2', ()),
    ('SigLIP', SigLIPModel, 'siglip', ('tokenizer_path',)),
    ('SatCLIP', SatCLIPModel, 'satclip', ()),
    ('FarSLIP', FarSLIPModel, 'farslip', ('model_name',)),
]

for _name, _cls, _cfg_key, _extra_keys in _MODEL_SPECS:
    try:
        if config and _cfg_key in config:
            _kwargs = {
                'ckpt_path': config[_cfg_key].get('ckpt_path'),
                'embedding_path': config[_cfg_key].get('embedding_path'),
            }
            for _field in _extra_keys:
                _kwargs[_field] = config[_cfg_key].get(_field)
            models[_name] = _cls(device=device, **_kwargs)
        else:
            # No config section: fall back to each model's defaults.
            models[_name] = _cls(device=device)
    except Exception as e:
        # A model failing to load should not prevent the app from starting;
        # get_active_model reports missing models to the user.
        print(f"Failed to load {_name}: {e}")
|
| 87 |
+
|
| 88 |
+
def get_active_model(model_name):
    """Look up a loaded model by display name.

    Returns:
        (model, None) when the model was loaded, or (None, error_message)
        when it is absent from the global `models` registry.
    """
    if model_name in models:
        return models[model_name], None
    return None, f"Model {model_name} not loaded."
|
| 92 |
+
|
| 93 |
+
def combine_images(img1, img2):
    """Stack two PIL images vertically on a white canvas.

    Both images are rescaled to the wider of the two widths, preserving
    aspect ratio. If either input is None, the other is returned unchanged.
    """
    if img1 is None:
        return img2
    if img2 is None:
        return img1

    target_w = max(img1.size[0], img2.size[0])

    # Scale heights proportionally so both images share target_w.
    resized = []
    for im in (img1, img2):
        w, h = im.size
        resized.append(im.resize((target_w, int(h * target_w / w))))
    top, bottom = resized

    canvas = PILImage.new('RGB', (target_w, top.size[1] + bottom.size[1]), (255, 255, 255))
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, top.size[1]))
    return canvas
|
| 112 |
+
|
| 113 |
+
def create_text_image(text, size=(384, 384)):
    """Render *text* onto a light-grey placeholder image.

    The text is split on commas, each fragment drawn on its own line,
    followed by a blue "Text Query" caption below.

    Args:
        text: Comma-separated text to render.
        size: (width, height) of the generated image.

    Returns:
        A new RGB PIL image.
    """
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    d = ImageDraw.Draw(img)

    # Prefer a scalable TrueType font; fall back to PIL's built-in bitmap
    # font when DejaVuSans is not installed. Catch only OSError (what
    # ImageFont.truetype raises for a missing/unreadable font) — the
    # original bare `except:` also swallowed KeyboardInterrupt/SystemExit.
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except OSError:
        font = ImageFont.load_default()

    # Simple comma-based line wrapping.
    margin = 20
    offset = 100
    for line in text.split(','):
        d.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50

    d.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img
|
| 133 |
+
|
| 134 |
+
def fetch_top_k_images(top_indices, probs, df_embed, query_text=None):
    """
    Fetches top-k images using actual dataset download (ModelScope) via download_and_process_image.

    Args:
        top_indices: Positional indices (into df_embed) of the tiles to fetch.
        probs: Per-row similarity scores aligned with df_embed.
        df_embed: Catalogue DataFrame with 'product_id', 'centre_lat', 'centre_lon'.
        query_text: Optional query description forwarded to the Esri fallback overlay.

    Returns:
        List of dicts with keys 'image_384', 'image_full', 'score', 'lat',
        'lon', 'id', sorted by descending score.
    """
    results = []

    # We can run this in parallel
    with ThreadPoolExecutor(max_workers=5) as executor:
        future_to_idx = {}
        # NOTE(review): `i` from enumerate is unused; kept for byte-identity.
        for i, idx in enumerate(top_indices):
            row = df_embed.iloc[idx]
            pid = row['product_id']

            # Use download_and_process_image to get real data
            future = executor.submit(download_and_process_image, pid, df_source=df_embed, verbose=False)
            future_to_idx[future] = idx

        # Futures complete in arbitrary order; each failure is logged and
        # the corresponding result simply omitted from the output.
        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                img_384, img_full = future.result()

                if img_384 is None:
                    # Fallback to Esri if download fails
                    print(f"Download failed for idx {idx}, falling back to Esri...")
                    row = df_embed.iloc[idx]
                    img_384 = get_esri_satellite_image(row['centre_lat'], row['centre_lon'], score=probs[idx], rank=0, query=query_text)
                    img_full = img_384

                row = df_embed.iloc[idx]
                results.append({
                    'image_384': img_384,
                    'image_full': img_full,
                    'score': probs[idx],
                    'lat': row['centre_lat'],
                    'lon': row['centre_lon'],
                    'id': row['product_id']
                })
            except Exception as e:
                print(f"Error fetching image for idx {idx}: {e}")

    # Sort results by score descending (since futures complete in random order)
    results.sort(key=lambda x: x['score'], reverse=True)
    return results
|
| 178 |
+
|
| 179 |
+
def get_all_results_metadata(model, filtered_indices, probs):
    """Build a score-sorted metadata list for every retained result.

    Args:
        model: Object exposing a ``df_embed`` DataFrame with columns
            ``product_id``, ``centre_lat`` and ``centre_lon``.
        filtered_indices: Positional indices (into ``df_embed``) of the
            results that passed the similarity threshold.
        probs: Per-row similarity scores aligned with ``df_embed``.

    Returns:
        A list of ``{'id', 'lat', 'lon', 'score'}`` dicts ordered by
        descending score; empty list when nothing passed the filter.
    """
    if len(filtered_indices) == 0:
        return []

    # Rank the surviving rows from best to worst score.
    ranked = filtered_indices[np.argsort(probs[filtered_indices])[::-1]]

    subset = model.df_embed.iloc[ranked].copy()
    subset['score'] = probs[ranked]

    subset = subset.rename(columns={'product_id': 'id', 'centre_lat': 'lat', 'centre_lon': 'lon'})
    return subset[['id', 'lat', 'lon', 'score']].to_dict('records')
|
| 197 |
+
|
| 198 |
+
def search_text(query, threshold, model_name):
    """Generator handler for the "Search by Text" button.

    Encodes *query* with the selected model, retrieves similar tiles,
    downloads the top matches and yields incremental UI updates.

    Each yield is a 7-tuple matching the Gradio outputs:
    (interactive plot update, gallery items, status text, results figure,
     downloadable-figures state, map DataFrame state, static map update).
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if not query:
        yield None, None, "Please enter a query.", None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Text
        yield None, None, "Encoding text...", None, None, None, None
        t0 = time.time()
        text_features = model.encode_text(query)
        timings['Encoding'] = time.time() - t0

        if text_features is None:
            yield None, None, "Model does not support text encoding or is not initialized.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding text... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        # Slider value is in per-mille, hence /1000 for a fraction.
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to "{query}" ({model_name})')

        # 3. Download Images
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:10]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text=query)
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        # threshold/100.0 so that generate_status_msg's internal *100 renders
        # the slider's per-mille value unchanged.
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
|
| 262 |
+
|
| 263 |
+
def search_image(image_input, threshold, model_name):
    """Generator handler for the "Search by Image" button.

    Encodes *image_input* with the selected model, retrieves similar
    tiles, downloads the top 6 matches and yields incremental UI updates.
    Yield shape mirrors search_text: a 7-tuple of Gradio outputs.
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if image_input is None:
        yield None, None, "Please upload an image.", None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Image
        yield None, None, "Encoding image...", None, None, None, None
        t0 = time.time()
        image_features = model.encode_image(image_input)
        timings['Encoding'] = time.time() - t0

        if image_features is None:
            yield None, None, "Model does not support image encoding.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding image... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        # Slider value is in per-mille, hence /1000 for a fraction.
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        # NOTE(review): unlike search_text there is no `probs is None` guard
        # here — a failed search would raise below and hit the except path.

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Input Image ({model_name})')

        # 3. Download Images
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:6]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text="Image Query")
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        # NOTE(review): capped at 50 entries here, but search_text exports
        # all entries — confirm which behaviour is intended.
        results_txt = format_results_to_text(all_results[:50])

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
|
| 323 |
+
|
| 324 |
+
def search_location(lat, lon, threshold):
    """Generator handler for the "Search by Location" button.

    Always uses the SatCLIP location encoder. Encodes (lat, lon),
    retrieves similar tiles, downloads the top 6 matches plus the tile
    nearest to the query point, and yields incremental UI updates
    (same 7-tuple shape as search_text / search_image).
    """
    model_name = "SatCLIP"
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    try:
        timings = {}

        # 1. Encode Location
        yield None, None, "Encoding location...", None, None, None, None
        t0 = time.time()
        loc_features = model.encode_location(float(lat), float(lon))
        timings['Encoding'] = time.time() - t0

        if loc_features is None:
            yield None, None, "Location encoding failed.", None, None, None, None
            return

        # 2. Search
        yield None, None, "Encoding location... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        # NOTE(review): this path passes threshold/100.0 while the text and
        # image paths pass threshold/1000.0 — confirm this is intentional.
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=threshold/100.0)
        timings['Retrieval'] = time.time() - t0

        # 3. Generate Distribution Map (not timed for location distribution)
        yield None, None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map...", None, None, None, None
        df_embed = model.df_embed
        top_10_indices = top_indices[:10]
        # NOTE(review): top_10_results is built but never consumed below.
        top_10_results = []
        for idx in top_10_indices:
            row = df_embed.iloc[idx]
            top_10_results.append({'lat': row['centre_lat'], 'lon': row['centre_lon']})

        # Show geographic distribution (not timed)
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Location ({lat}, {lon})')

        # 4. Download Images
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_6_indices = top_indices[:6]
        results = fetch_top_k_images(top_6_indices, probs, df_embed, query_text=f"Loc: {lat},{lon}")

        # Get query tile: the catalogue entry nearest to the query point,
        # falling back to a placeholder image if the download fails.
        query_tile = None
        try:
            lats = pd.to_numeric(df_embed['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df_embed['centre_lon'], errors='coerce')
            dists = (lats - float(lat))**2 + (lons - float(lon))**2
            nearest_idx = dists.idxmin()
            pid = df_embed.loc[nearest_idx, 'product_id']
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            print(f"Error fetching nearest MajorTOM image: {e}")
        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")
        timings['Download'] = time.time() - t0

        # 5. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 6. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
|
| 403 |
+
|
| 404 |
+
def generate_status_msg(count, threshold, results):
    """Compose the human-readable status summary shown after a search.

    Args:
        count: Number of tiles that passed the similarity threshold.
        threshold: Fraction such that ``threshold * 100`` is the per-mille
            figure to display (callers pass slider_value / 100).
        results: Result dicts with 'id', 'lat', 'lon', 'score'; only the
            first five are listed.

    Returns:
        The multi-line status string.
    """
    lines = [f"Found {count} matches in top {threshold*100:.0f}‰.\n\nTop {len(results)} similar images:\n"]
    for rank, res in enumerate(results[:5], start=1):
        lines.append(f"{rank}. Product ID: {res['id']}, Location: ({res['lat']:.4f}, {res['lon']:.4f}), Score: {res['score']:.4f}\n")
    return "".join(lines)
|
| 409 |
+
|
| 410 |
+
def get_initial_plot():
    """Render the startup global-distribution map.

    Prefers the DINOv2 catalogue for the overview plot and falls back to
    SigLIP when DINOv2 is unavailable or has no embeddings.

    Returns:
        Updates for (static map image, downloadable-figure state,
        map DataFrame state, hidden interactive-plot component).
    """
    preferred = models.get('DINOv2')
    if preferred is not None and preferred.df_embed is not None:
        source_df = preferred.df_embed
    else:
        source_df = models['SigLIP'].df_embed
    img, df_vis = plot_global_map_static(source_df)
    return gr.update(value=img, visible=True), [img], df_vis, gr.update(visible=False)
|
| 420 |
+
|
| 421 |
+
def handle_map_click(evt: gr.SelectData, df_vis):
    """Translate a pixel click on the static world map into lat/lon.

    Inverts the fixed map layout (image size, margins, letterboxing) to
    recover geographic coordinates, then snaps to the nearest catalogue
    tile when one is close enough.

    Args:
        evt: Gradio select event carrying pixel coordinates in evt.index.
        df_vis: Catalogue DataFrame ('centre_lat', 'centre_lon',
            'product_id') used for snapping, or None.

    Returns:
        (lat, lon, product_id, status_message); the first three are None
        on error or when the click falls outside the map area.
    """
    if evt is None:
        return None, None, None, "No point selected."

    try:
        x, y = evt.index[0], evt.index[1]

        # Image dimensions (New)
        # NOTE(review): hard-coded — presumably matches the pixel size of the
        # image produced by plot_global_map_static; confirm if that changes.
        img_width = 3000
        img_height = 1500

        # Scaled Margins (Proportional to 4000x2000)
        left_margin = 110 * 0.75
        right_margin = 110 * 0.75
        top_margin = 100 * 0.75
        bottom_margin = 67 * 0.75

        plot_width = img_width - left_margin - right_margin
        plot_height = img_height - top_margin - bottom_margin

        # Adjust for aspect ratio preservation: the 2:1 world map is
        # letterboxed inside the plot area, so compute its real extent.
        map_aspect = 360.0 / 180.0 # 2.0
        plot_aspect = plot_width / plot_height

        if plot_aspect > map_aspect:
            actual_map_width = plot_height * map_aspect
            actual_map_height = plot_height
            h_offset = (plot_width - actual_map_width) / 2
            v_offset = 0
        else:
            actual_map_width = plot_width
            actual_map_height = plot_width / map_aspect
            h_offset = 0
            v_offset = (plot_height - actual_map_height) / 2

        # Calculate relative position within the plot area
        x_in_plot = x - left_margin
        y_in_plot = y - top_margin

        # Check if click is within the actual map bounds
        if (x_in_plot < h_offset or x_in_plot > h_offset + actual_map_width or
            y_in_plot < v_offset or y_in_plot > v_offset + actual_map_height):
            return None, None, None, "Click outside map area. Please click on the map."

        # Calculate relative position within the map (0 to 1)
        x_rel = (x_in_plot - h_offset) / actual_map_width
        y_rel = (y_in_plot - v_offset) / actual_map_height

        # Clamp to [0, 1]
        x_rel = max(0, min(1, x_rel))
        y_rel = max(0, min(1, y_rel))

        # Convert to geographic coordinates
        lon = x_rel * 360 - 180
        lat = 90 - y_rel * 180

        # Find nearest point in df_vis if available
        pid = ""
        if df_vis is not None:
            dists = (df_vis['centre_lat'] - lat)**2 + (df_vis['centre_lon'] - lon)**2
            min_idx = dists.idxmin()
            nearest_row = df_vis.loc[min_idx]

            # Snap only when the nearest tile is close: squared-degree
            # distance < 25, i.e. within roughly 5 degrees.
            if dists[min_idx] < 25:
                lat = nearest_row['centre_lat']
                lon = nearest_row['centre_lon']
                pid = nearest_row['product_id']

    except Exception as e:
        print(f"Error handling click: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"Error: {e}"

    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"
|
| 496 |
+
|
| 497 |
+
def download_image_by_location(lat, lon, pid, model_name):
    """Download and return the image at the specified location

    Args:
        lat, lon: Query coordinates (coercible to float).
        pid: Product id; when falsy the catalogue entry nearest to
            (lat, lon) is used instead.
        model_name: Which loaded model's catalogue to search.

    Returns:
        (image, status_message); image is None on any failure.
    """
    if lat is None or lon is None:
        return None, "Please specify coordinates first."

    model, error = get_active_model(model_name)
    if error:
        return None, error

    try:
        # Coerce to float so the status strings format cleanly.
        lat, lon = float(lat), float(lon)

        # Resolve the product id from the nearest catalogue entry when the
        # caller did not supply one.
        if not pid:
            df = model.df_embed
            lat_col = pd.to_numeric(df['centre_lat'], errors='coerce')
            lon_col = pd.to_numeric(df['centre_lon'], errors='coerce')
            squared_dist = (lat_col - lat) ** 2 + (lon_col - lon) ** 2
            pid = df.loc[squared_dist.idxmin(), 'product_id']

        # Download image
        img_384, _ = download_and_process_image(pid, df_source=model.df_embed, verbose=True)

        if img_384 is None:
            return None, f"Failed to download image for location ({lat:.4f}, {lon:.4f})"

        return img_384, f"Downloaded image at ({lat:.4f}, {lon:.4f})"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
|
| 532 |
+
|
| 533 |
+
def reset_to_global_map():
    """Reset the map to the initial global distribution view

    Mirrors get_initial_plot's source selection (DINOv2 preferred, SigLIP
    fallback) but returns only the three map-related outputs.
    """
    preferred = models.get('DINOv2')
    if preferred is not None and preferred.df_embed is not None:
        img, df_vis = plot_global_map_static(preferred.df_embed)
    else:
        img, df_vis = plot_global_map_static(models['SigLIP'].df_embed)

    return gr.update(value=img, visible=True), [img], df_vis
|
| 543 |
+
|
| 544 |
+
def format_results_to_text(results):
    """Serialize retrieval results into a plain-text report.

    Args:
        results: List of dicts with 'id', 'lat', 'lon' and 'score'.

    Returns:
        A multi-line report string, or a placeholder when *results* is
        empty or None.
    """
    if not results:
        return "No results found."

    divider = "-" * 30 + "\n"
    pieces = [f"Top {len(results)} Retrieval Results\n", "=" * 30 + "\n\n"]
    for rank, res in enumerate(results, start=1):
        pieces.append(f"Rank: {rank}\n")
        pieces.append(f"Product ID: {res['id']}\n")
        pieces.append(f"Location: Latitude {res['lat']:.6f}, Longitude {res['lon']:.6f}\n")
        pieces.append(f"Similarity Score: {res['score']:.6f}\n")
        pieces.append(divider)
    return "".join(pieces)
|
| 557 |
+
|
| 558 |
+
def save_plot(figs):
    """Persist the current figure(s) to a downloadable temp file.

    Accepts either a single PIL image (saved as PNG), a list/tuple of
    [map_image, results_image, results_text] (bundled into a ZIP), or a
    Plotly-style figure exposing ``write_html`` (saved as HTML).

    Args:
        figs: The figure payload stored in the ``current_fig`` state.

    Returns:
        Path of the created temporary file, or None on failure.
    """
    import io

    if figs is None:
        return None
    try:
        def _png_path(img):
            # Save one PIL image to a uniquely named temporary PNG.
            fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
            os.close(fd)
            img.save(path)
            return path

        def _png_bytes(img):
            # Encode a PIL image as PNG bytes in memory. This replaces the
            # previous fixed-name files in the shared temp dir, which could
            # collide between concurrent sessions and were never cleaned up.
            buf = io.BytesIO()
            img.save(buf, format='PNG')
            return buf.getvalue()

        # If it's a single image (initial state), save as png
        if isinstance(figs, PILImage.Image):
            return _png_path(figs)

        # If it's a list/tuple of images [map_img, results_img, results_txt]
        if isinstance(figs, (list, tuple)):
            # If only one image in list, save as PNG
            if len(figs) == 1 and isinstance(figs[0], PILImage.Image):
                return _png_path(figs[0])

            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)

            with zipfile.ZipFile(zip_path, 'w') as zipf:
                # Save Map
                if figs[0] is not None:
                    zipf.writestr('map_distribution.png', _png_bytes(figs[0]))

                # Save Results
                if len(figs) > 1 and figs[1] is not None:
                    zipf.writestr('retrieval_results.png', _png_bytes(figs[1]))

                # Save Results Text (writestr encodes str as UTF-8)
                if len(figs) > 2 and figs[2] is not None:
                    zipf.writestr('results.txt', figs[2])

            return zip_path

        # Fallback for Plotly figure (if any): export interactive HTML.
        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)
        figs.write_html(path)
        return path
    except Exception as e:
        print(f"Error saving: {e}")
        return None
|
| 614 |
+
|
| 615 |
+
# Gradio Blocks Interface
# Declares the whole UI: three query tabs (text / image / location) on the
# left, map + results panels on the right, then wires all event handlers.
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    gr.HTML("""
    <div style="font-size: 1.2em;">
    EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple a click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
    </div>

    """)

    with gr.Row():
        # Left column: query inputs and shared controls.
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Text Search") as tab_text:
                    model_selector_text = gr.Dropdown(choices=["SigLIP", "FarSLIP"], value="FarSLIP", label="Model")
                    query_input = gr.Textbox(label="Query", placeholder="e.g., rainforest, glacier")

                    gr.Examples(
                        examples=[
                            ["a satellite image of a river around a city"],
                            ["a satellite image of a rainforest"],
                            ["a satellite image of a slum"],
                            ["a satellite image of a glacier"],
                            ["a satellite image of snow covered mountains"]
                        ],
                        inputs=[query_input],
                        label="Text Examples"
                    )

                    search_btn = gr.Button("Search by Text", variant="primary")

                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(choices=["SigLIP", "FarSLIP", "SatCLIP", "DINOv2"], value="FarSLIP", label="Model")

                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")

                    gr.Examples(
                        examples=[
                            ["./examples/example1.png"],
                            ["./examples/example2.png"],
                            ["./examples/example3.png"]
                        ],
                        inputs=[image_input],
                        label="Image Examples"
                    )

                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)

                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")

                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")

                    search_img_btn = gr.Button("Search by Image", variant="primary")

                with gr.TabItem("Location Search") as tab_location:
                    gr.Markdown("Search using **SatCLIP** location encoder.")

                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)

                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")

                    gr.Examples(
                        examples=[
                            [30.32, 120.15],
                            [40.7128, -74.0060],
                            [24.65, 46.71],
                            [-3.4653, -62.2159],
                            [64.4, 16.8]
                        ],
                        inputs=[lat_input, lon_input],
                        label="Location Examples"
                    )

                    search_loc_btn = gr.Button("Search by Location", variant="primary")

            # Controls shared by all three search modes.
            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (‰)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)

        # Right column: map and retrieval result panels.
        with gr.Column(scale=6):
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True
            )
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")

    # Session state: last figures (for download) and the map DataFrame
    # used by handle_map_click for tile snapping.
    current_fig = gr.State()
    map_data_state = gr.State()

    # Initial Load
    demo.load(fn=get_initial_plot, outputs=[plot_map, current_fig, map_data_state, plot_map_interactive])

    # Reset Map Buttons
    btn_reset_map_img.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    btn_reset_map_loc.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    # Map Click Event - updates Image Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[img_lat, img_lon, img_pid, img_click_status]
    )

    # Map Click Event - also updates Location Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[lat_input, lon_input, loc_pid, loc_click_status]
    )

    # Download Image by Geolocation
    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status]
    )

    # Search Event (Text)
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Search Event (Image)
    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Search Event (Location)
    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    # Save Event
    save_btn.click(
        fn=save_plot,
        inputs=[current_fig],
        outputs=[download_file]
    )

    # Tab Selection Events: switching tabs always restores the static map
    # and hides the interactive plot.
    def show_static_map():
        return gr.update(visible=True), gr.update(visible=False)

    tab_text.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_image.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_location.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
compute_embeddings.py
ADDED
|
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#!/usr/bin/env python3
|
| 3 |
+
"""
|
| 4 |
+
Compute Embeddings for Major-TOM Sentinel-2 Images
|
| 5 |
+
|
| 6 |
+
This script generates embeddings for Sentinel-2 imagery using various models:
|
| 7 |
+
- DINOv2: Vision Transformer trained with self-supervised learning
|
| 8 |
+
- SigLIP: Vision-Language model with sigmoid loss
|
| 9 |
+
- FarSLIP: Remote sensing fine-tuned CLIP
|
| 10 |
+
- SatCLIP: Satellite imagery CLIP with location awareness
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python compute_embeddings.py --model dinov2 --device cuda:1
|
| 14 |
+
python compute_embeddings.py --model siglip --device cuda:5
|
| 15 |
+
python compute_embeddings.py --model satclip --device cuda:3
|
| 16 |
+
python compute_embeddings.py --model farslip --device cuda:4
|
| 17 |
+
|
| 18 |
+
Author: Generated by Copilot
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import argparse
|
| 24 |
+
import logging
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import pandas as pd
|
| 30 |
+
import torch
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from tqdm.auto import tqdm
|
| 33 |
+
|
| 34 |
+
# Add project root to path
|
| 35 |
+
PROJECT_ROOT = Path(__file__).parent.absolute()
|
| 36 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 37 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 38 |
+
|
| 39 |
+
from models.load_config import load_and_process_config
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# =============================================================================
|
| 43 |
+
# Configuration
|
| 44 |
+
# =============================================================================
|
| 45 |
+
METADATA_PATH = Path("/data1/zyj/Core-S2L2A-249k/Core_S2L2A_249k_crop_384x384_metadata.parquet")
|
| 46 |
+
IMAGE_PARQUET_DIR = Path("/data1/zyj/Core-S2L2A-249k/images")
|
| 47 |
+
OUTPUT_BASE_DIR = Path("/data1/zyj/EarthEmbeddings/Core-S2L2A-249k")
|
| 48 |
+
|
| 49 |
+
# Columns to remove from output
|
| 50 |
+
COLUMNS_TO_REMOVE = ['cloud_cover', 'nodata', 'geometry_wkt', 'bands', 'image_shape', 'image_dtype']
|
| 51 |
+
|
| 52 |
+
# Columns to rename
|
| 53 |
+
COLUMNS_RENAME = {'crs': 'utm_crs'}
|
| 54 |
+
|
| 55 |
+
# Pixel bbox for center 384x384 crop from 1068x1068 original
|
| 56 |
+
# (1068 - 384) / 2 = 342
|
| 57 |
+
PIXEL_BBOX = [342, 342, 726, 726] # [x_min, y_min, x_max, y_max]
|
| 58 |
+
|
| 59 |
+
# Model output paths
|
| 60 |
+
MODEL_OUTPUT_PATHS = {
|
| 61 |
+
'dinov2': OUTPUT_BASE_DIR / 'dinov2' / 'DINOv2_crop_384x384.parquet',
|
| 62 |
+
'siglip': OUTPUT_BASE_DIR / 'siglip' / 'SigLIP_crop_384x384.parquet',
|
| 63 |
+
'farslip': OUTPUT_BASE_DIR / 'farslip' / 'FarSLIP_crop_384x384.parquet',
|
| 64 |
+
'satclip': OUTPUT_BASE_DIR / 'satclip' / 'SatCLIP_crop_384x384.parquet',
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
# Batch sizes for different models
|
| 68 |
+
BATCH_SIZES = {
|
| 69 |
+
'dinov2': 64,
|
| 70 |
+
'siglip': 64,
|
| 71 |
+
'farslip': 64,
|
| 72 |
+
'satclip': 128,
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# =============================================================================
|
| 77 |
+
# Setup Logging
|
| 78 |
+
# =============================================================================
|
| 79 |
+
def setup_logging(model_name: str):
    """Set up INFO-level logging to stdout and a per-model log file.

    Args:
        model_name: Model identifier; used to name the log file
            (logs/compute_embeddings_<model_name>.log under PROJECT_ROOT).

    Returns:
        A logger bound to this module's name.
    """
    log_directory = PROJECT_ROOT / "logs"
    log_directory.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(log_directory / f"compute_embeddings_{model_name}.log"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    return logging.getLogger(__name__)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# =============================================================================
|
| 97 |
+
# Image Preprocessing Functions
|
| 98 |
+
# =============================================================================
|
| 99 |
+
def decode_image_bytes(row) -> np.ndarray:
    """Reconstruct an image array from its raw-byte form stored in a parquet row.

    Args:
        row: Mapping-like object (e.g. pandas Series) providing 'image_bytes'
            (raw buffer), 'image_shape' (iterable of ints) and 'image_dtype'
            (numpy dtype string).

    Returns:
        np.ndarray with the stored shape and dtype (typically (H, W, 12) uint16).
    """
    target_shape = tuple(int(d) for d in row['image_shape'])
    flat = np.frombuffer(row['image_bytes'], dtype=np.dtype(row['image_dtype']))
    return flat.reshape(target_shape)
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def extract_rgb_image(img_array: np.ndarray, clip_max: float = 4000.0) -> Image.Image:
    """Build an 8-bit RGB preview from a 12-band Sentinel-2 array.

    Sentinel-2 Bands: [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12]
    RGB Mapping: R=B04(idx 3), G=B03(idx 2), B=B02(idx 1)

    Args:
        img_array: numpy array of shape (H, W, 12)
        clip_max: Reflectance value mapped to full brightness (higher values clip)

    Returns:
        PIL.Image: RGB image
    """
    # Pick R=B04(3), G=B03(2), B=B02(1), scale reflectance into [0, 1]...
    scaled = np.clip(img_array[:, :, [3, 2, 1]].astype(np.float32) / clip_max, 0, 1)
    # ...then quantize to 8-bit for display.
    return Image.fromarray((scaled * 255).astype(np.uint8))
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# =============================================================================
|
| 142 |
+
# Model Loading Functions
|
| 143 |
+
# =============================================================================
|
| 144 |
+
def load_model(model_name: str, device: str, config: dict):
    """
    Instantiate the requested embedding model.

    Args:
        model_name: One of 'dinov2', 'siglip', 'farslip', 'satclip'
        device: Device string like 'cuda:0' or 'cpu'
        config: Configuration dictionary from local.yaml

    Returns:
        Model instance

    Raises:
        ValueError: If model_name is not a supported identifier.
    """
    logger = logging.getLogger(__name__)

    if model_name not in ('dinov2', 'siglip', 'farslip', 'satclip'):
        raise ValueError(f"Unknown model: {model_name}")

    # Per-model config section; falls back to hard-coded default paths below.
    cfg = config.get(model_name, {})

    if model_name == 'dinov2':
        from models.dinov2_model import DINOv2Model
        instance = DINOv2Model(
            ckpt_path=cfg.get('ckpt_path', '/data1/zyj/checkpoints/dinov2-large'),
            model_name='facebook/dinov2-large',
            embedding_path=None,  # We're generating, not loading
            device=device,
        )
        logger.info(f"DINOv2 model loaded on {device}")
    elif model_name == 'siglip':
        from models.siglip_model import SigLIPModel
        instance = SigLIPModel(
            ckpt_path=cfg.get('ckpt_path', './checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin'),
            model_name='ViT-SO400M-14-SigLIP-384',
            tokenizer_path=cfg.get('tokenizer_path', './checkpoints/ViT-SO400M-14-SigLIP-384'),
            embedding_path=None,
            device=device,
        )
        # Disable embedding loading since we set path to None
        instance.df_embed = None
        instance.image_embeddings = None
        logger.info(f"SigLIP model loaded on {device}")
    elif model_name == 'farslip':
        from models.farslip_model import FarSLIPModel
        instance = FarSLIPModel(
            ckpt_path=cfg.get('ckpt_path', './checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt'),
            model_name='ViT-B-16',
            embedding_path=None,
            device=device,
        )
        logger.info(f"FarSLIP model loaded on {device}")
    else:  # 'satclip'
        from models.satclip_ms_model import SatCLIPMSModel
        instance = SatCLIPMSModel(
            ckpt_path=cfg.get('ckpt_path', './checkpoints/SatCLIP/satclip-vit16-l40.ckpt'),
            embedding_path=None,
            device=device,
        )
        logger.info(f"SatCLIP-MS model loaded on {device}")

    return instance
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# =============================================================================
|
| 214 |
+
# Embedding Computation Functions
|
| 215 |
+
# =============================================================================
|
| 216 |
+
def compute_embedding_single(model, model_name: str, img_array: np.ndarray) -> np.ndarray:
    """
    Compute the embedding vector for one image.

    Args:
        model: Model instance
        model_name: Model identifier
        img_array: numpy array of shape (H, W, 12)

    Returns:
        np.ndarray: 1D embedding vector, or None if encoding failed or the
        model name is unrecognized.
    """
    if model_name in ('dinov2', 'siglip', 'farslip'):
        # RGB-based encoders: convert the 12-band array to an RGB preview first.
        encoded = model.encode_image(extract_rgb_image(img_array))
    elif model_name == 'satclip':
        # SatCLIP consumes the multi-spectral array directly.
        encoded = model.encode_image(img_array, is_multispectral=True)
    else:
        return None

    if encoded is None:
        return None
    return encoded.cpu().numpy().flatten()
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _encode_batch_with_fallback(inputs: list, batch_encode, single_encode) -> list:
    """Encode a list of inputs, preferring batch encoding, falling back per item.

    Args:
        inputs: List of model inputs (PIL images or numpy arrays).
        batch_encode: Callable taking the full list and returning an indexable
            batch of feature tensors, or None when batch encoding is unavailable.
        single_encode: Callable taking one input and returning a tensor or None.

    Returns:
        List of 1D numpy embedding vectors; None entries for items that failed.
    """
    logger = logging.getLogger(__name__)

    if batch_encode is not None:
        try:
            features = batch_encode(inputs)
            if features is not None:
                return [features[i].cpu().numpy().flatten() for i in range(len(features))]
        except Exception as e:
            # Batch path failed; record why before degrading to per-item mode.
            logger.debug(f"Batch encoding failed, falling back to single-image mode: {e}")

    results = []
    for item in inputs:
        try:
            feature = single_encode(item)
            results.append(feature.cpu().numpy().flatten() if feature is not None else None)
        except Exception as e:
            logger.debug(f"Single-image encoding failed: {e}")
            results.append(None)
    return results


def compute_embedding_batch(model, model_name: str, img_arrays: list) -> list:
    """
    Compute embeddings for a batch of images.
    Falls back to single-image processing if the batch method is unavailable
    or raises.

    Args:
        model: Model instance
        model_name: Model identifier
        img_arrays: List of numpy arrays of shape (H, W, 12)

    Returns:
        List of 1D embedding vectors (numpy arrays), None for failed items
    """
    if model_name in ('dinov2', 'siglip', 'farslip'):
        # RGB-based encoders operate on 8-bit RGB previews.
        rgb_imgs = [extract_rgb_image(arr) for arr in img_arrays]
        batch_fn = model.encode_images if hasattr(model, 'encode_images') else None
        return _encode_batch_with_fallback(rgb_imgs, batch_fn, model.encode_image)

    if model_name == 'satclip':
        # SatCLIP consumes the multi-spectral arrays directly.
        batch_fn = None
        if hasattr(model, 'encode_images'):
            batch_fn = lambda arrs: model.encode_images(arrs, is_multispectral=True)
        single_fn = lambda arr: model.encode_image(arr, is_multispectral=True)
        return _encode_batch_with_fallback(img_arrays, batch_fn, single_fn)

    # Unknown model: no item can be encoded.
    return [None] * len(img_arrays)
|
| 311 |
+
|
| 312 |
+
# def process_parquet_file(
|
| 313 |
+
# file_path: Path,
|
| 314 |
+
# model,
|
| 315 |
+
# model_name: str,
|
| 316 |
+
# batch_size: int = 64
|
| 317 |
+
# ) -> pd.DataFrame:
|
| 318 |
+
# """
|
| 319 |
+
# Process a single parquet file and generate embeddings.
|
| 320 |
+
|
| 321 |
+
# Args:
|
| 322 |
+
# file_path: Path to input parquet file
|
| 323 |
+
# model: Model instance
|
| 324 |
+
# model_name: Model identifier
|
| 325 |
+
# batch_size: Batch size for processing
|
| 326 |
+
|
| 327 |
+
# Returns:
|
| 328 |
+
# DataFrame with embeddings
|
| 329 |
+
# """
|
| 330 |
+
# logger = logging.getLogger(__name__)
|
| 331 |
+
|
| 332 |
+
# # Load data
|
| 333 |
+
# df = pd.read_parquet(file_path)
|
| 334 |
+
|
| 335 |
+
# embeddings_list = []
|
| 336 |
+
# valid_indices = []
|
| 337 |
+
|
| 338 |
+
# # Process in batches (for future batch optimization)
|
| 339 |
+
# for idx, row in df.iterrows():
|
| 340 |
+
# try:
|
| 341 |
+
# # Decode image
|
| 342 |
+
# img_array = decode_image_bytes(row)
|
| 343 |
+
|
| 344 |
+
# # Compute embedding
|
| 345 |
+
# embedding = compute_embedding_single(model, model_name, img_array)
|
| 346 |
+
|
| 347 |
+
# if embedding is not None:
|
| 348 |
+
# embeddings_list.append(embedding)
|
| 349 |
+
# valid_indices.append(idx)
|
| 350 |
+
|
| 351 |
+
# except Exception as e:
|
| 352 |
+
# logger.warning(f"Error processing row {idx}: {e}")
|
| 353 |
+
# continue
|
| 354 |
+
|
| 355 |
+
# if not embeddings_list:
|
| 356 |
+
# logger.warning(f"No valid embeddings for {file_path.name}")
|
| 357 |
+
# return None
|
| 358 |
+
|
| 359 |
+
# # Build result DataFrame
|
| 360 |
+
# result_df = df.loc[valid_indices].copy()
|
| 361 |
+
|
| 362 |
+
# # Remove unwanted columns
|
| 363 |
+
# cols_to_drop = [c for c in COLUMNS_TO_REMOVE if c in result_df.columns]
|
| 364 |
+
# if cols_to_drop:
|
| 365 |
+
# result_df = result_df.drop(columns=cols_to_drop)
|
| 366 |
+
|
| 367 |
+
# # Remove image_bytes (large binary data)
|
| 368 |
+
# if 'image_bytes' in result_df.columns:
|
| 369 |
+
# result_df = result_df.drop(columns=['image_bytes'])
|
| 370 |
+
|
| 371 |
+
# # Remove geometry column (binary)
|
| 372 |
+
# if 'geometry' in result_df.columns:
|
| 373 |
+
# result_df = result_df.drop(columns=['geometry'])
|
| 374 |
+
|
| 375 |
+
# # Rename columns
|
| 376 |
+
# result_df = result_df.rename(columns=COLUMNS_RENAME)
|
| 377 |
+
|
| 378 |
+
# # Add pixel_bbox
|
| 379 |
+
# result_df['pixel_bbox'] = [PIXEL_BBOX] * len(result_df)
|
| 380 |
+
|
| 381 |
+
# # Add embedding
|
| 382 |
+
# result_df['embedding'] = embeddings_list
|
| 383 |
+
|
| 384 |
+
# return result_df
|
| 385 |
+
|
| 386 |
+
def process_parquet_file(
    file_path: Path,
    model,
    model_name: str,
    batch_size: int = 64
) -> pd.DataFrame:
    """
    Process a single parquet file and generate embeddings using batch processing.

    Rows whose image fails to decode or embed are logged and dropped from the
    output rather than aborting the whole file.

    Args:
        file_path: Path to input parquet file
        model: Model instance
        model_name: Model identifier
        batch_size: Batch size for processing

    Returns:
        DataFrame with metadata plus 'pixel_bbox' and 'embedding' columns,
        or None if no row produced a valid embedding.
    """
    logger = logging.getLogger(__name__)

    df = pd.read_parquet(file_path)
    n_rows = len(df)

    # embeddings_by_row[i] holds the embedding for df row i, or None if that
    # row failed; non-None entries double as the validity mask.
    embeddings_by_row = [None] * n_rows

    for batch_start in range(0, n_rows, batch_size):
        batch_end = min(batch_start + batch_size, n_rows)

        # Decode this batch; keep track of which df positions decoded cleanly.
        batch_arrays = []
        batch_valid_indices = []
        for idx in range(batch_start, batch_end):
            try:
                img_array = decode_image_bytes(df.iloc[idx])
            except Exception as e:
                logger.warning(f"Error decoding row {idx}: {e}")
                continue
            batch_arrays.append(img_array)
            batch_valid_indices.append(idx)

        if not batch_arrays:
            continue

        try:
            batch_embeddings = compute_embedding_batch(model, model_name, batch_arrays)
            for i, idx in enumerate(batch_valid_indices):
                if batch_embeddings[i] is not None:
                    embeddings_by_row[idx] = batch_embeddings[i]
        except Exception as e:
            logger.warning(f"Error computing batch embeddings: {e}")
            # Batch path failed entirely: retry each image individually.
            for i, idx in enumerate(batch_valid_indices):
                try:
                    embedding = compute_embedding_single(model, model_name, batch_arrays[i])
                    if embedding is not None:
                        embeddings_by_row[idx] = embedding
                except Exception as inner_e:
                    logger.warning(f"Error processing row {idx}: {inner_e}")

    valid_indices = [i for i, emb in enumerate(embeddings_by_row) if emb is not None]
    if not valid_indices:
        logger.warning(f"No valid embeddings for {file_path.name}")
        return None

    result_df = df.iloc[valid_indices].copy()

    # Strip unwanted metadata plus heavyweight binary columns in one pass.
    cols_to_drop = [c for c in COLUMNS_TO_REMOVE + ['image_bytes', 'geometry']
                    if c in result_df.columns]
    if cols_to_drop:
        result_df = result_df.drop(columns=cols_to_drop)

    result_df = result_df.rename(columns=COLUMNS_RENAME)

    # Record the crop window and attach the embeddings.
    result_df['pixel_bbox'] = [PIXEL_BBOX] * len(result_df)
    result_df['embedding'] = [embeddings_by_row[i] for i in valid_indices]

    return result_df
|
| 492 |
+
|
| 493 |
+
# =============================================================================
|
| 494 |
+
# Main Processing Pipeline
|
| 495 |
+
# =============================================================================
|
| 496 |
+
def main():
    """CLI entry point: embed every input parquet file with the chosen model
    and write one merged parquet of metadata + embeddings."""
    parser = argparse.ArgumentParser(description='Compute embeddings for Major-TOM images')
    parser.add_argument('--model', type=str, required=True,
                        choices=['dinov2', 'siglip', 'farslip', 'satclip'],
                        help='Model to use for embedding computation')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device to run on (e.g., cuda:0, cuda:1, cpu)')
    parser.add_argument('--batch-size', type=int, default=None,
                        help='Batch size for processing (default: model-specific)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum number of files to process (for testing)')

    args = parser.parse_args()

    # Setup logging (must happen before any logger.* call below)
    logger = setup_logging(args.model)

    logger.info("=" * 80)
    logger.info(f"Computing {args.model.upper()} embeddings")
    logger.info(f"Timestamp: {datetime.now().isoformat()}")
    logger.info(f"Device: {args.device}")
    logger.info("=" * 80)

    # Load configuration; an empty dict makes load_model fall back to defaults
    config = load_and_process_config()
    if config is None:
        logger.warning("No config file found, using default paths")
        config = {}

    # Determine batch size: CLI flag wins over the per-model default
    batch_size = args.batch_size or BATCH_SIZES.get(args.model, 64)
    logger.info(f"Batch size: {batch_size}")

    # Get output path
    output_path = MODEL_OUTPUT_PATHS[args.model]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output path: {output_path}")

    # Load model
    logger.info(f"Loading {args.model} model...")
    model = load_model(args.model, args.device, config)

    # Get input files (sorted for deterministic processing order)
    parquet_files = sorted(IMAGE_PARQUET_DIR.glob("batch_*.parquet"))
    if args.max_files:
        parquet_files = parquet_files[:args.max_files]

    logger.info(f"Found {len(parquet_files)} input files")

    # Process files; a failure in one file is logged and does not stop the run
    all_results = []
    total_rows = 0  # NOTE(review): accumulated but only reported via final_df length

    for file_path in tqdm(parquet_files, desc=f"Processing {args.model}"):
        try:
            result_df = process_parquet_file(file_path, model, args.model, batch_size)

            if result_df is not None:
                all_results.append(result_df)
                total_rows += len(result_df)
                logger.info(f"[{file_path.name}] Processed {len(result_df)} rows")

        except Exception as e:
            logger.error(f"Error processing {file_path.name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Merge and save
    if all_results:
        logger.info("Merging all results...")
        final_df = pd.concat(all_results, ignore_index=True)

        # Validate columns
        logger.info(f"Final columns: {list(final_df.columns)}")

        # Check for removed columns (sanity check on process_parquet_file)
        removed = [c for c in COLUMNS_TO_REMOVE if c in final_df.columns]
        if removed:
            logger.warning(f"Columns still present that should be removed: {removed}")
        else:
            logger.info("✓ All unwanted columns removed")

        # Check for renamed columns
        if 'utm_crs' in final_df.columns and 'crs' not in final_df.columns:
            logger.info("✓ Column 'crs' renamed to 'utm_crs'")

        # Check for pixel_bbox
        if 'pixel_bbox' in final_df.columns:
            logger.info("✓ Column 'pixel_bbox' added")

        # Save
        logger.info(f"Saving to {output_path}...")
        final_df.to_parquet(output_path, index=False)

        logger.info(f"=" * 80)
        logger.info(f"Processing complete!")
        logger.info(f"Total rows: {len(final_df):,}")
        logger.info(f"Embedding dimension: {len(final_df['embedding'].iloc[0])}")
        logger.info(f"Output file: {output_path}")
        logger.info(f"=" * 80)

    else:
        # No file yielded any embeddings: signal failure to the shell
        logger.error("No data processed!")
        return 1

    return 0
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
if __name__ == "__main__":
|
| 606 |
+
sys.exit(main())
|
configs/huggingface.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
siglip:
|
| 2 |
+
ckpt_path: "hf"
|
| 3 |
+
model_name: "ViT-SO400M-14-SigLIP-384"
|
| 4 |
+
tokenizer_path: "hf"
|
| 5 |
+
embedding_path: "hf://ML4RS-Anonymous/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet.parquet"
|
| 6 |
+
farslip:
|
| 7 |
+
ckpt_path: "hf"
|
| 8 |
+
model_name: "ViT-B-16"
|
| 9 |
+
embedding_path: "hf://ML4RS-Anonymous/EarthEmbeddings/Core-S2L2A-249k/farslip/FarSLIP_crop_384x384.parquet.parquet"
|
| 10 |
+
satclip:
|
| 11 |
+
ckpt_path: "hf"
|
| 12 |
+
embedding_path: "hf://ML4RS-Anonymous/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet.parquet"
|
| 13 |
+
dinov2:
|
| 14 |
+
ckpt_path: "hf"
|
| 15 |
+
embedding_path: "hf://ML4RS-Anonymous/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet.parquet"
|
countries.geo.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data_utils.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fsspec
|
| 2 |
+
import pyarrow.parquet as pq
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from rasterio.io import MemoryFile
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import cartopy.crs as ccrs
|
| 9 |
+
import cartopy.io.img_tiles as cimgt
|
| 10 |
+
from matplotlib.patches import Rectangle
|
| 11 |
+
import math
|
| 12 |
+
from matplotlib.figure import Figure
|
| 13 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def crop_center(img_array, cropx, cropy):
|
| 17 |
+
y, x, c = img_array.shape
|
| 18 |
+
startx = x // 2 - (cropx // 2)
|
| 19 |
+
starty = y // 2 - (cropy // 2)
|
| 20 |
+
return img_array[starty:starty+cropy, startx:startx+cropx]
|
| 21 |
+
|
| 22 |
+
def read_tif_bytes(tif_bytes):
    """Decode an in-memory GeoTIFF byte string into a squeezed numpy array."""
    with MemoryFile(tif_bytes) as memfile, memfile.open(driver='GTiff') as dataset:
        return dataset.read().squeeze()
|
| 26 |
+
|
| 27 |
+
def read_row_memory(row_dict, columns=None):
    """Fetch one row group from a remote parquet file and decode its columns.

    Args:
        row_dict: Mapping with 'parquet_url' (remote file location) and
            'parquet_row' (row-group index within the file).
        columns: Column names to read; defaults to ["thumbnail"]. The
            'thumbnail' column is decoded as a PIL image, every other
            column as a GeoTIFF array.

    Returns:
        Dict mapping each requested column name to its decoded value.
    """
    # Fix mutable-default-argument anti-pattern: build the default per call.
    if columns is None:
        columns = ["thumbnail"]

    url = row_dict['parquet_url']
    row_idx = row_dict['parquet_row']

    # Readahead caching keeps the number of remote range requests low.
    fs_options = {
        "cache_type": "readahead",
        "block_size": 5 * 1024 * 1024
    }

    with fsspec.open(url, mode='rb', **fs_options) as f:
        with pq.ParquetFile(f) as pf:
            table = pf.read_row_group(row_idx, columns=columns)

    row_output = {}
    for col in columns:
        col_data = table[col][0].as_py()

        if col != 'thumbnail':
            row_output[col] = read_tif_bytes(col_data)
        else:
            stream = BytesIO(col_data)
            row_output[col] = Image.open(stream)

    return row_output
|
| 51 |
+
|
| 52 |
+
def download_and_process_image(product_id, df_source=None, verbose=True):
|
| 53 |
+
if df_source is None:
|
| 54 |
+
if verbose: print("❌ Error: No DataFrame provided.")
|
| 55 |
+
return None, None
|
| 56 |
+
|
| 57 |
+
row_subset = df_source[df_source['product_id'] == product_id]
|
| 58 |
+
if len(row_subset) == 0:
|
| 59 |
+
if verbose: print(f"❌ Error: Product ID {product_id} not found in DataFrame.")
|
| 60 |
+
return None, None
|
| 61 |
+
|
| 62 |
+
row_dict = row_subset.iloc[0].to_dict()
|
| 63 |
+
|
| 64 |
+
if 'parquet_url' in row_dict:
|
| 65 |
+
url = row_dict['parquet_url']
|
| 66 |
+
if 'huggingface.co' in url:
|
| 67 |
+
row_dict['parquet_url'] = url.replace('https://huggingface.co', 'https://modelscope.cn').replace('resolve/main', 'resolve/master')
|
| 68 |
+
elif 'hf-mirror.com' in url:
|
| 69 |
+
row_dict['parquet_url'] = url.replace('https://hf-mirror.com', 'https://modelscope.cn').replace('resolve/main', 'resolve/master')
|
| 70 |
+
else:
|
| 71 |
+
if verbose: print("❌ Error: 'parquet_url' missing in metadata.")
|
| 72 |
+
return None, None
|
| 73 |
+
|
| 74 |
+
if verbose: print(f"⬇️ Fetching data for {product_id} from {row_dict['parquet_url']}...")
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
bands_data = read_row_memory(row_dict, columns=['B04', 'B03', 'B02'])
|
| 78 |
+
|
| 79 |
+
if not all(b in bands_data for b in ['B04', 'B03', 'B02']):
|
| 80 |
+
if verbose: print(f"❌ Error: Missing bands in fetched data for {product_id}")
|
| 81 |
+
return None, None
|
| 82 |
+
|
| 83 |
+
rgb_img = np.stack([bands_data['B04'], bands_data['B03'], bands_data['B02']], axis=-1)
|
| 84 |
+
|
| 85 |
+
if verbose:
|
| 86 |
+
print(f"Raw RGB stats: Min={rgb_img.min()}, Max={rgb_img.max()}, Mean={rgb_img.mean()}, Dtype={rgb_img.dtype}")
|
| 87 |
+
|
| 88 |
+
# Check if data is already 0-255 or 0-1
|
| 89 |
+
if rgb_img.max() <= 255:
|
| 90 |
+
# Assume it might be uint8 or scaled
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
rgb_norm = (2.5 * (rgb_img.astype(float) / 10000.0)).clip(0, 1)
|
| 94 |
+
rgb_uint8 = (rgb_norm * 255).astype(np.uint8)
|
| 95 |
+
|
| 96 |
+
if verbose:
|
| 97 |
+
print(f"Processed RGB stats: Min={rgb_uint8.min()}, Max={rgb_uint8.max()}, Mean={rgb_uint8.mean()}")
|
| 98 |
+
|
| 99 |
+
img_full = Image.fromarray(rgb_uint8)
|
| 100 |
+
|
| 101 |
+
if rgb_uint8.shape[0] >= 384 and rgb_uint8.shape[1] >= 384:
|
| 102 |
+
cropped_array = crop_center(rgb_uint8, 384, 384)
|
| 103 |
+
img_384 = Image.fromarray(cropped_array)
|
| 104 |
+
else:
|
| 105 |
+
if verbose: print(f"⚠️ Image too small {rgb_uint8.shape}, resizing to 384x384.")
|
| 106 |
+
img_384 = img_full.resize((384, 384))
|
| 107 |
+
|
| 108 |
+
if verbose: print(f"✅ Successfully processed {product_id}")
|
| 109 |
+
return img_384, img_full
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
if verbose: print(f"❌ Error processing {product_id}: {e}")
|
| 113 |
+
import traceback
|
| 114 |
+
traceback.print_exc()
|
| 115 |
+
return None, None
|
| 116 |
+
|
| 117 |
+
# Define Esri Imagery Class
|
| 118 |
+
class EsriImagery(cimgt.GoogleTiles):
|
| 119 |
+
def _image_url(self, tile):
|
| 120 |
+
x, y, z = tile
|
| 121 |
+
return f'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}'
|
| 122 |
+
|
| 123 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 124 |
+
|
| 125 |
+
def get_placeholder_image(text="Image Unavailable", size=(384, 384)):
|
| 126 |
+
img = Image.new('RGB', size, color=(200, 200, 200))
|
| 127 |
+
d = ImageDraw.Draw(img)
|
| 128 |
+
try:
|
| 129 |
+
# Try to load a default font
|
| 130 |
+
font = ImageFont.load_default()
|
| 131 |
+
except:
|
| 132 |
+
font = None
|
| 133 |
+
|
| 134 |
+
# Draw text in center (rough approximation)
|
| 135 |
+
# For better centering we would need font metrics, but simple is fine here
|
| 136 |
+
d.text((20, size[1]//2), text, fill=(0, 0, 0), font=font)
|
| 137 |
+
return img
|
| 138 |
+
|
| 139 |
+
def get_esri_satellite_image(lat, lon, score=None, rank=None, query=None):
|
| 140 |
+
"""
|
| 141 |
+
Generates a satellite image visualization using Esri World Imagery via Cartopy.
|
| 142 |
+
Matches the style of the provided notebook.
|
| 143 |
+
Uses OO Matplotlib API for thread safety.
|
| 144 |
+
"""
|
| 145 |
+
try:
|
| 146 |
+
imagery = EsriImagery()
|
| 147 |
+
|
| 148 |
+
# Create figure using OO API
|
| 149 |
+
fig = Figure(figsize=(5, 5), dpi=100)
|
| 150 |
+
canvas = FigureCanvasAgg(fig)
|
| 151 |
+
ax = fig.add_subplot(1, 1, 1, projection=imagery.crs)
|
| 152 |
+
|
| 153 |
+
# Set extent to approx 10km x 10km around the point
|
| 154 |
+
extent_deg = 0.05
|
| 155 |
+
ax.set_extent([lon - extent_deg, lon + extent_deg, lat - extent_deg, lat + extent_deg], crs=ccrs.PlateCarree())
|
| 156 |
+
|
| 157 |
+
# Add the imagery
|
| 158 |
+
ax.add_image(imagery, 14)
|
| 159 |
+
|
| 160 |
+
# Add a marker for the center
|
| 161 |
+
ax.plot(lon, lat, marker='+', color='yellow', markersize=12, markeredgewidth=2, transform=ccrs.PlateCarree())
|
| 162 |
+
|
| 163 |
+
# Add Bounding Box (3840m x 3840m)
|
| 164 |
+
box_size_m = 384 * 10 # 3840m
|
| 165 |
+
|
| 166 |
+
# Convert meters to degrees (approx)
|
| 167 |
+
# 1 deg lat = 111320m
|
| 168 |
+
# 1 deg lon = 111320m * cos(lat)
|
| 169 |
+
dlat = (box_size_m / 111320)
|
| 170 |
+
dlon = (box_size_m / (111320 * math.cos(math.radians(lat))))
|
| 171 |
+
|
| 172 |
+
# Bottom-Left corner
|
| 173 |
+
rect_lon = lon - dlon / 2
|
| 174 |
+
rect_lat = lat - dlat / 2
|
| 175 |
+
|
| 176 |
+
# Add Rectangle
|
| 177 |
+
rect = Rectangle((rect_lon, rect_lat), dlon, dlat,
|
| 178 |
+
linewidth=2, edgecolor='red', facecolor='none', transform=ccrs.PlateCarree())
|
| 179 |
+
ax.add_patch(rect)
|
| 180 |
+
|
| 181 |
+
# Title
|
| 182 |
+
title_parts = []
|
| 183 |
+
if query: title_parts.append(f"{query}")
|
| 184 |
+
if rank is not None: title_parts.append(f"Rank {rank}")
|
| 185 |
+
if score is not None: title_parts.append(f"Score: {score:.4f}")
|
| 186 |
+
|
| 187 |
+
ax.set_title("\n".join(title_parts), fontsize=10)
|
| 188 |
+
|
| 189 |
+
# Save to buffer
|
| 190 |
+
buf = BytesIO()
|
| 191 |
+
fig.savefig(buf, format='png', bbox_inches='tight')
|
| 192 |
+
buf.seek(0)
|
| 193 |
+
|
| 194 |
+
return Image.open(buf)
|
| 195 |
+
|
| 196 |
+
except Exception as e:
|
| 197 |
+
# Suppress full traceback for network errors to avoid log spam
|
| 198 |
+
error_msg = str(e)
|
| 199 |
+
if "Connection reset by peer" in error_msg or "Network is unreachable" in error_msg or "urlopen error" in error_msg:
|
| 200 |
+
print(f"⚠️ Network warning: Could not fetch Esri satellite map for ({lat:.4f}, {lon:.4f}). Server might be offline.")
|
| 201 |
+
else:
|
| 202 |
+
print(f"Error generating Esri image for {lat}, {lon}: {e}")
|
| 203 |
+
# Only print traceback for non-network errors
|
| 204 |
+
# import traceback
|
| 205 |
+
# traceback.print_exc()
|
| 206 |
+
|
| 207 |
+
# Return a placeholder image with text
|
| 208 |
+
return get_placeholder_image(f"Map Unavailable\n({lat:.2f}, {lon:.2f})")
|
| 209 |
+
|
| 210 |
+
def get_esri_satellite_image_url(lat, lon, zoom=14):
|
| 211 |
+
"""
|
| 212 |
+
Returns the URL for the Esri World Imagery tile at the given location.
|
| 213 |
+
"""
|
| 214 |
+
try:
|
| 215 |
+
imagery = EsriImagery()
|
| 216 |
+
# Calculate tile coordinates
|
| 217 |
+
# This is a simplification, cimgt handles this internally usually
|
| 218 |
+
# But for direct URL we might need more logic or just use the static map approach above
|
| 219 |
+
# For now, let's stick to the static map generation which works
|
| 220 |
+
pass
|
| 221 |
+
except:
|
| 222 |
+
pass
|
| 223 |
+
return None
|
examples/example1.png
ADDED
|
Git LFS Details
|
examples/example2.png
ADDED
|
Git LFS Details
|
examples/example3.png
ADDED
|
Git LFS Details
|
logs/compute_embeddings_dinov2.log
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-02-01 09:07:55,115 [INFO] ================================================================================
|
| 2 |
+
2026-02-01 09:07:55,115 [INFO] Computing DINOV2 embeddings
|
| 3 |
+
2026-02-01 09:07:55,115 [INFO] Timestamp: 2026-02-01T09:07:55.115269
|
| 4 |
+
2026-02-01 09:07:55,115 [INFO] Device: cuda:0
|
| 5 |
+
2026-02-01 09:07:55,115 [INFO] ================================================================================
|
| 6 |
+
2026-02-01 09:07:55,116 [INFO] Batch size: 64
|
| 7 |
+
2026-02-01 09:07:55,116 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet
|
| 8 |
+
2026-02-01 09:07:55,116 [INFO] Loading dinov2 model...
|
| 9 |
+
2026-02-01 09:07:58,665 [INFO] DINOv2 model loaded on cuda:0
|
| 10 |
+
2026-02-01 09:07:58,666 [INFO] Found 1 input files
|
| 11 |
+
2026-02-01 09:08:48,122 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
|
| 12 |
+
2026-02-01 09:08:48,122 [INFO] Merging all results...
|
| 13 |
+
2026-02-01 09:08:48,122 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
|
| 14 |
+
2026-02-01 09:08:48,122 [INFO] ✓ All unwanted columns removed
|
| 15 |
+
2026-02-01 09:08:48,122 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
|
| 16 |
+
2026-02-01 09:08:48,122 [INFO] ✓ Column 'pixel_bbox' added
|
| 17 |
+
2026-02-01 09:08:48,122 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet...
|
| 18 |
+
2026-02-01 09:08:48,228 [INFO] ================================================================================
|
| 19 |
+
2026-02-01 09:08:48,228 [INFO] Processing complete!
|
| 20 |
+
2026-02-01 09:08:48,228 [INFO] Total rows: 1,996
|
| 21 |
+
2026-02-01 09:08:48,228 [INFO] Embedding dimension: 1024
|
| 22 |
+
2026-02-01 09:08:48,228 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet
|
| 23 |
+
2026-02-01 09:08:48,228 [INFO] ================================================================================
|
| 24 |
+
2026-02-01 09:43:06,596 [INFO] ================================================================================
|
| 25 |
+
2026-02-01 09:43:06,596 [INFO] Computing DINOV2 embeddings
|
| 26 |
+
2026-02-01 09:43:06,596 [INFO] Timestamp: 2026-02-01T09:43:06.596521
|
| 27 |
+
2026-02-01 09:43:06,596 [INFO] Device: cuda:1
|
| 28 |
+
2026-02-01 09:43:06,596 [INFO] ================================================================================
|
| 29 |
+
2026-02-01 09:43:06,597 [INFO] Batch size: 64
|
| 30 |
+
2026-02-01 09:43:06,597 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet
|
| 31 |
+
2026-02-01 09:43:06,597 [INFO] Loading dinov2 model...
|
| 32 |
+
2026-02-01 09:43:08,665 [INFO] DINOv2 model loaded on cuda:1
|
| 33 |
+
2026-02-01 09:43:08,666 [INFO] Found 125 input files
|
| 34 |
+
2026-02-01 09:43:59,600 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
|
| 35 |
+
2026-02-01 09:44:50,531 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
|
| 36 |
+
2026-02-01 09:45:40,104 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
|
| 37 |
+
2026-02-01 09:46:31,203 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
|
| 38 |
+
2026-02-01 09:47:22,240 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
|
| 39 |
+
2026-02-01 09:48:17,789 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
|
| 40 |
+
2026-02-01 09:49:12,206 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
|
| 41 |
+
2026-02-01 09:50:04,633 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
|
| 42 |
+
2026-02-01 09:51:01,688 [INFO] [batch_0009_384x384.parquet] Processed 1993 rows
|
| 43 |
+
2026-02-01 09:51:52,258 [INFO] [batch_0010_384x384.parquet] Processed 1993 rows
|
| 44 |
+
2026-02-01 09:52:43,385 [INFO] [batch_0011_384x384.parquet] Processed 1990 rows
|
| 45 |
+
2026-02-01 09:53:33,664 [INFO] [batch_0012_384x384.parquet] Processed 1993 rows
|
| 46 |
+
2026-02-01 09:54:23,450 [INFO] [batch_0013_384x384.parquet] Processed 1993 rows
|
| 47 |
+
2026-02-01 09:55:14,741 [INFO] [batch_0014_384x384.parquet] Processed 1991 rows
|
| 48 |
+
2026-02-01 09:56:05,637 [INFO] [batch_0015_384x384.parquet] Processed 1993 rows
|
| 49 |
+
2026-02-01 09:57:02,579 [INFO] [batch_0016_384x384.parquet] Processed 1990 rows
|
| 50 |
+
2026-02-01 09:57:59,164 [INFO] [batch_0017_384x384.parquet] Processed 1991 rows
|
| 51 |
+
2026-02-01 09:58:54,668 [INFO] [batch_0018_384x384.parquet] Processed 1991 rows
|
| 52 |
+
2026-02-01 09:59:50,748 [INFO] [batch_0019_384x384.parquet] Processed 1996 rows
|
| 53 |
+
2026-02-01 10:00:44,987 [INFO] [batch_0020_384x384.parquet] Processed 1994 rows
|
| 54 |
+
2026-02-01 10:01:41,422 [INFO] [batch_0021_384x384.parquet] Processed 1996 rows
|
| 55 |
+
2026-02-01 10:02:39,884 [INFO] [batch_0022_384x384.parquet] Processed 1995 rows
|
| 56 |
+
2026-02-01 10:03:41,408 [INFO] [batch_0023_384x384.parquet] Processed 1992 rows
|
| 57 |
+
2026-02-01 10:04:44,392 [INFO] [batch_0024_384x384.parquet] Processed 1989 rows
|
| 58 |
+
2026-02-01 10:05:47,970 [INFO] [batch_0025_384x384.parquet] Processed 1993 rows
|
| 59 |
+
2026-02-01 10:06:47,594 [INFO] [batch_0026_384x384.parquet] Processed 1995 rows
|
| 60 |
+
2026-02-01 10:07:46,292 [INFO] [batch_0027_384x384.parquet] Processed 1997 rows
|
| 61 |
+
2026-02-01 10:08:43,976 [INFO] [batch_0028_384x384.parquet] Processed 1991 rows
|
| 62 |
+
2026-02-01 10:09:43,099 [INFO] [batch_0029_384x384.parquet] Processed 1992 rows
|
| 63 |
+
2026-02-01 10:10:40,183 [INFO] [batch_0030_384x384.parquet] Processed 1993 rows
|
| 64 |
+
2026-02-01 10:11:44,485 [INFO] [batch_0031_384x384.parquet] Processed 1988 rows
|
| 65 |
+
2026-02-01 10:12:39,796 [INFO] [batch_0032_384x384.parquet] Processed 1994 rows
|
| 66 |
+
2026-02-01 10:13:45,836 [INFO] [batch_0033_384x384.parquet] Processed 1992 rows
|
| 67 |
+
2026-02-01 10:14:44,908 [INFO] [batch_0034_384x384.parquet] Processed 1994 rows
|
| 68 |
+
2026-02-01 10:15:44,326 [INFO] [batch_0035_384x384.parquet] Processed 1994 rows
|
| 69 |
+
2026-02-01 10:16:43,931 [INFO] [batch_0036_384x384.parquet] Processed 1996 rows
|
| 70 |
+
2026-02-01 10:17:41,513 [INFO] [batch_0037_384x384.parquet] Processed 1995 rows
|
| 71 |
+
2026-02-01 10:18:39,810 [INFO] [batch_0038_384x384.parquet] Processed 1993 rows
|
| 72 |
+
2026-02-01 10:19:36,710 [INFO] [batch_0039_384x384.parquet] Processed 1989 rows
|
| 73 |
+
2026-02-01 10:20:31,841 [INFO] [batch_0040_384x384.parquet] Processed 1990 rows
|
| 74 |
+
2026-02-01 10:21:29,236 [INFO] [batch_0041_384x384.parquet] Processed 1998 rows
|
| 75 |
+
2026-02-01 10:22:32,483 [INFO] [batch_0042_384x384.parquet] Processed 1997 rows
|
| 76 |
+
2026-02-01 10:23:28,852 [INFO] [batch_0043_384x384.parquet] Processed 1988 rows
|
| 77 |
+
2026-02-01 10:24:24,324 [INFO] [batch_0044_384x384.parquet] Processed 1991 rows
|
| 78 |
+
2026-02-01 10:25:22,097 [INFO] [batch_0045_384x384.parquet] Processed 1993 rows
|
| 79 |
+
2026-02-01 10:26:18,196 [INFO] [batch_0046_384x384.parquet] Processed 1994 rows
|
| 80 |
+
2026-02-01 10:27:34,649 [INFO] [batch_0047_384x384.parquet] Processed 1995 rows
|
| 81 |
+
2026-02-01 10:28:30,976 [INFO] [batch_0048_384x384.parquet] Processed 1992 rows
|
| 82 |
+
2026-02-01 10:29:41,715 [INFO] [batch_0049_384x384.parquet] Processed 1996 rows
|
| 83 |
+
2026-02-01 10:30:45,082 [INFO] [batch_0050_384x384.parquet] Processed 1991 rows
|
| 84 |
+
2026-02-01 10:31:46,711 [INFO] [batch_0051_384x384.parquet] Processed 1997 rows
|
| 85 |
+
2026-02-01 10:32:45,127 [INFO] [batch_0052_384x384.parquet] Processed 1993 rows
|
| 86 |
+
2026-02-01 10:33:48,960 [INFO] [batch_0053_384x384.parquet] Processed 1995 rows
|
| 87 |
+
2026-02-01 10:35:01,705 [INFO] [batch_0054_384x384.parquet] Processed 1997 rows
|
| 88 |
+
2026-02-01 10:36:11,677 [INFO] [batch_0055_384x384.parquet] Processed 1995 rows
|
| 89 |
+
2026-02-01 10:37:17,746 [INFO] [batch_0056_384x384.parquet] Processed 1997 rows
|
| 90 |
+
2026-02-01 10:38:28,458 [INFO] [batch_0057_384x384.parquet] Processed 1991 rows
|
| 91 |
+
2026-02-01 10:39:38,673 [INFO] [batch_0058_384x384.parquet] Processed 1994 rows
|
| 92 |
+
2026-02-01 10:40:48,784 [INFO] [batch_0059_384x384.parquet] Processed 1993 rows
|
| 93 |
+
2026-02-01 10:41:47,477 [INFO] [batch_0060_384x384.parquet] Processed 1995 rows
|
| 94 |
+
2026-02-01 10:42:55,595 [INFO] [batch_0061_384x384.parquet] Processed 1995 rows
|
| 95 |
+
2026-02-01 10:44:08,413 [INFO] [batch_0062_384x384.parquet] Processed 1998 rows
|
| 96 |
+
2026-02-01 10:45:27,616 [INFO] [batch_0063_384x384.parquet] Processed 1997 rows
|
| 97 |
+
2026-02-01 10:46:40,936 [INFO] [batch_0064_384x384.parquet] Processed 1992 rows
|
| 98 |
+
2026-02-01 10:47:38,737 [INFO] [batch_0065_384x384.parquet] Processed 1994 rows
|
| 99 |
+
2026-02-01 10:48:46,233 [INFO] [batch_0066_384x384.parquet] Processed 1992 rows
|
| 100 |
+
2026-02-01 10:49:56,228 [INFO] [batch_0067_384x384.parquet] Processed 1993 rows
|
| 101 |
+
2026-02-01 10:51:12,380 [INFO] [batch_0068_384x384.parquet] Processed 1994 rows
|
| 102 |
+
2026-02-01 10:52:27,369 [INFO] [batch_0069_384x384.parquet] Processed 1992 rows
|
| 103 |
+
2026-02-01 10:53:42,056 [INFO] [batch_0070_384x384.parquet] Processed 1997 rows
|
| 104 |
+
2026-02-01 10:54:50,573 [INFO] [batch_0071_384x384.parquet] Processed 1996 rows
|
| 105 |
+
2026-02-01 10:56:03,974 [INFO] [batch_0072_384x384.parquet] Processed 1992 rows
|
| 106 |
+
2026-02-01 10:57:09,742 [INFO] [batch_0073_384x384.parquet] Processed 1995 rows
|
| 107 |
+
2026-02-01 10:58:22,365 [INFO] [batch_0074_384x384.parquet] Processed 1992 rows
|
| 108 |
+
2026-02-01 10:59:33,712 [INFO] [batch_0075_384x384.parquet] Processed 1991 rows
|
| 109 |
+
2026-02-01 11:00:48,387 [INFO] [batch_0076_384x384.parquet] Processed 1998 rows
|
| 110 |
+
2026-02-01 11:01:47,919 [INFO] [batch_0077_384x384.parquet] Processed 1996 rows
|
| 111 |
+
2026-02-01 11:03:01,336 [INFO] [batch_0078_384x384.parquet] Processed 1992 rows
|
| 112 |
+
2026-02-01 11:04:04,437 [INFO] [batch_0079_384x384.parquet] Processed 1995 rows
|
| 113 |
+
2026-02-01 11:05:15,344 [INFO] [batch_0080_384x384.parquet] Processed 1993 rows
|
| 114 |
+
2026-02-01 11:06:26,434 [INFO] [batch_0081_384x384.parquet] Processed 1995 rows
|
| 115 |
+
2026-02-01 11:07:29,500 [INFO] [batch_0082_384x384.parquet] Processed 1989 rows
|
| 116 |
+
2026-02-01 11:08:41,452 [INFO] [batch_0083_384x384.parquet] Processed 1995 rows
|
| 117 |
+
2026-02-01 11:09:52,372 [INFO] [batch_0084_384x384.parquet] Processed 1996 rows
|
| 118 |
+
2026-02-01 11:10:54,102 [INFO] [batch_0085_384x384.parquet] Processed 1997 rows
|
| 119 |
+
2026-02-01 11:12:05,011 [INFO] [batch_0086_384x384.parquet] Processed 1996 rows
|
| 120 |
+
2026-02-01 11:13:18,046 [INFO] [batch_0087_384x384.parquet] Processed 1994 rows
|
| 121 |
+
2026-02-01 11:14:28,554 [INFO] [batch_0088_384x384.parquet] Processed 1992 rows
|
| 122 |
+
2026-02-01 11:15:30,371 [INFO] [batch_0089_384x384.parquet] Processed 1993 rows
|
| 123 |
+
2026-02-01 11:16:36,098 [INFO] [batch_0090_384x384.parquet] Processed 1993 rows
|
| 124 |
+
2026-02-01 11:17:47,559 [INFO] [batch_0091_384x384.parquet] Processed 1995 rows
|
| 125 |
+
2026-02-01 11:18:59,181 [INFO] [batch_0092_384x384.parquet] Processed 1994 rows
|
| 126 |
+
2026-02-01 11:20:10,040 [INFO] [batch_0093_384x384.parquet] Processed 1998 rows
|
| 127 |
+
2026-02-01 11:21:11,780 [INFO] [batch_0094_384x384.parquet] Processed 1993 rows
|
| 128 |
+
2026-02-01 11:22:13,323 [INFO] [batch_0095_384x384.parquet] Processed 1995 rows
|
| 129 |
+
2026-02-01 11:23:13,963 [INFO] [batch_0096_384x384.parquet] Processed 1997 rows
|
| 130 |
+
2026-02-01 11:24:11,380 [INFO] [batch_0097_384x384.parquet] Processed 1990 rows
|
| 131 |
+
2026-02-01 11:25:16,113 [INFO] [batch_0098_384x384.parquet] Processed 1995 rows
|
| 132 |
+
2026-02-01 11:26:15,319 [INFO] [batch_0099_384x384.parquet] Processed 1992 rows
|
| 133 |
+
2026-02-01 11:27:09,846 [INFO] [batch_0100_384x384.parquet] Processed 1993 rows
|
| 134 |
+
2026-02-01 11:28:13,634 [INFO] [batch_0101_384x384.parquet] Processed 1994 rows
|
| 135 |
+
2026-02-01 11:29:19,508 [INFO] [batch_0102_384x384.parquet] Processed 1991 rows
|
| 136 |
+
2026-02-01 11:30:27,321 [INFO] [batch_0103_384x384.parquet] Processed 1990 rows
|
| 137 |
+
2026-02-01 11:31:38,038 [INFO] [batch_0104_384x384.parquet] Processed 1995 rows
|
| 138 |
+
2026-02-01 11:32:55,342 [INFO] [batch_0105_384x384.parquet] Processed 1993 rows
|
| 139 |
+
2026-02-01 11:34:02,868 [INFO] [batch_0106_384x384.parquet] Processed 1988 rows
|
| 140 |
+
2026-02-01 11:35:08,481 [INFO] [batch_0107_384x384.parquet] Processed 1996 rows
|
| 141 |
+
2026-02-01 11:36:17,025 [INFO] [batch_0108_384x384.parquet] Processed 1992 rows
|
| 142 |
+
2026-02-01 11:37:26,799 [INFO] [batch_0109_384x384.parquet] Processed 1993 rows
|
| 143 |
+
2026-02-01 11:38:39,274 [INFO] [batch_0110_384x384.parquet] Processed 1996 rows
|
| 144 |
+
2026-02-01 11:39:49,743 [INFO] [batch_0111_384x384.parquet] Processed 1991 rows
|
| 145 |
+
2026-02-01 11:40:47,923 [INFO] [batch_0112_384x384.parquet] Processed 1994 rows
|
| 146 |
+
2026-02-01 11:41:53,376 [INFO] [batch_0113_384x384.parquet] Processed 1996 rows
|
| 147 |
+
2026-02-01 11:42:53,847 [INFO] [batch_0114_384x384.parquet] Processed 1997 rows
|
| 148 |
+
2026-02-01 11:43:47,456 [INFO] [batch_0115_384x384.parquet] Processed 1997 rows
|
| 149 |
+
2026-02-01 11:44:47,188 [INFO] [batch_0116_384x384.parquet] Processed 1992 rows
|
| 150 |
+
2026-02-01 11:45:44,350 [INFO] [batch_0117_384x384.parquet] Processed 1992 rows
|
| 151 |
+
2026-02-01 11:46:51,765 [INFO] [batch_0118_384x384.parquet] Processed 1993 rows
|
| 152 |
+
2026-02-01 11:47:54,777 [INFO] [batch_0119_384x384.parquet] Processed 1998 rows
|
| 153 |
+
2026-02-01 11:48:58,907 [INFO] [batch_0120_384x384.parquet] Processed 1995 rows
|
| 154 |
+
2026-02-01 11:49:59,917 [INFO] [batch_0121_384x384.parquet] Processed 1995 rows
|
| 155 |
+
2026-02-01 11:51:00,476 [INFO] [batch_0122_384x384.parquet] Processed 1992 rows
|
| 156 |
+
2026-02-01 11:52:05,414 [INFO] [batch_0123_384x384.parquet] Processed 1994 rows
|
| 157 |
+
2026-02-01 11:53:06,075 [INFO] [batch_0124_384x384.parquet] Processed 1995 rows
|
| 158 |
+
2026-02-01 11:53:54,915 [INFO] [batch_0125_384x384.parquet] Processed 1505 rows
|
| 159 |
+
2026-02-01 11:53:54,915 [INFO] Merging all results...
|
| 160 |
+
2026-02-01 11:53:54,970 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
|
| 161 |
+
2026-02-01 11:53:54,971 [INFO] ✓ All unwanted columns removed
|
| 162 |
+
2026-02-01 11:53:54,971 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
|
| 163 |
+
2026-02-01 11:53:54,971 [INFO] ✓ Column 'pixel_bbox' added
|
| 164 |
+
2026-02-01 11:53:54,971 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet...
|
| 165 |
+
2026-02-01 11:54:03,559 [INFO] ================================================================================
|
| 166 |
+
2026-02-01 11:54:03,559 [INFO] Processing complete!
|
| 167 |
+
2026-02-01 11:54:03,559 [INFO] Total rows: 248,719
|
| 168 |
+
2026-02-01 11:54:03,560 [INFO] Embedding dimension: 1024
|
| 169 |
+
2026-02-01 11:54:03,560 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/dinov2/DINOv2_crop_384x384.parquet
|
| 170 |
+
2026-02-01 11:54:03,560 [INFO] ================================================================================
|
logs/compute_embeddings_farslip.log
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-02-01 09:54:48,604 [INFO] ================================================================================
|
| 2 |
+
2026-02-01 09:54:48,605 [INFO] Computing FARSLIP embeddings
|
| 3 |
+
2026-02-01 09:54:48,605 [INFO] Timestamp: 2026-02-01T09:54:48.605134
|
| 4 |
+
2026-02-01 09:54:48,605 [INFO] Device: cuda:4
|
| 5 |
+
2026-02-01 09:54:48,605 [INFO] ================================================================================
|
| 6 |
+
2026-02-01 09:54:48,606 [INFO] Batch size: 64
|
| 7 |
+
2026-02-01 09:54:48,607 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/farslip/FarSLIP_crop_384x384.parquet
|
| 8 |
+
2026-02-01 09:54:48,607 [INFO] Loading farslip model...
|
| 9 |
+
2026-02-01 09:54:48,613 [INFO] Loaded ViT-B-16 model config.
|
| 10 |
+
2026-02-01 09:54:50,536 [INFO] Loading pretrained ViT-B-16 weights (/data1/zyj/checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt).
|
| 11 |
+
2026-02-01 09:54:51,666 [INFO] Missing keys: []
|
| 12 |
+
2026-02-01 09:54:51,745 [INFO] FarSLIP model loaded on cuda:4
|
| 13 |
+
2026-02-01 09:54:51,745 [INFO] Found 125 input files
|
| 14 |
+
2026-02-01 09:55:38,785 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
|
| 15 |
+
2026-02-01 09:56:18,239 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
|
| 16 |
+
2026-02-01 09:57:17,259 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
|
| 17 |
+
2026-02-01 09:58:08,339 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
|
| 18 |
+
2026-02-01 09:59:00,302 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
|
| 19 |
+
2026-02-01 10:00:15,416 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
|
| 20 |
+
2026-02-01 10:01:22,601 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
|
| 21 |
+
2026-02-01 10:02:25,131 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
|
| 22 |
+
2026-02-01 10:03:31,735 [INFO] [batch_0009_384x384.parquet] Processed 1993 rows
|
| 23 |
+
2026-02-01 10:04:47,342 [INFO] [batch_0010_384x384.parquet] Processed 1993 rows
|
| 24 |
+
2026-02-01 10:05:54,617 [INFO] [batch_0011_384x384.parquet] Processed 1990 rows
|
| 25 |
+
2026-02-01 10:06:58,372 [INFO] [batch_0012_384x384.parquet] Processed 1993 rows
|
| 26 |
+
2026-02-01 10:08:16,301 [INFO] [batch_0013_384x384.parquet] Processed 1993 rows
|
| 27 |
+
2026-02-01 10:09:11,722 [INFO] [batch_0014_384x384.parquet] Processed 1991 rows
|
| 28 |
+
2026-02-01 10:10:23,603 [INFO] [batch_0015_384x384.parquet] Processed 1993 rows
|
| 29 |
+
2026-02-01 10:11:38,047 [INFO] [batch_0016_384x384.parquet] Processed 1990 rows
|
| 30 |
+
2026-02-01 10:12:22,943 [INFO] [batch_0017_384x384.parquet] Processed 1991 rows
|
| 31 |
+
2026-02-01 10:13:41,095 [INFO] [batch_0018_384x384.parquet] Processed 1991 rows
|
| 32 |
+
2026-02-01 10:14:47,596 [INFO] [batch_0019_384x384.parquet] Processed 1996 rows
|
| 33 |
+
2026-02-01 10:15:40,983 [INFO] [batch_0020_384x384.parquet] Processed 1994 rows
|
| 34 |
+
2026-02-01 10:16:52,878 [INFO] [batch_0021_384x384.parquet] Processed 1996 rows
|
| 35 |
+
2026-02-01 10:17:43,460 [INFO] [batch_0022_384x384.parquet] Processed 1995 rows
|
| 36 |
+
2026-02-01 10:18:41,479 [INFO] [batch_0023_384x384.parquet] Processed 1992 rows
|
| 37 |
+
2026-02-01 10:19:40,728 [INFO] [batch_0024_384x384.parquet] Processed 1989 rows
|
| 38 |
+
2026-02-01 10:20:25,503 [INFO] [batch_0025_384x384.parquet] Processed 1993 rows
|
| 39 |
+
2026-02-01 10:21:27,428 [INFO] [batch_0026_384x384.parquet] Processed 1995 rows
|
| 40 |
+
2026-02-01 10:22:23,776 [INFO] [batch_0027_384x384.parquet] Processed 1997 rows
|
| 41 |
+
2026-02-01 10:23:16,992 [INFO] [batch_0028_384x384.parquet] Processed 1991 rows
|
| 42 |
+
2026-02-01 10:24:14,634 [INFO] [batch_0029_384x384.parquet] Processed 1992 rows
|
| 43 |
+
2026-02-01 10:24:55,464 [INFO] [batch_0030_384x384.parquet] Processed 1993 rows
|
| 44 |
+
2026-02-01 10:25:56,600 [INFO] [batch_0031_384x384.parquet] Processed 1988 rows
|
| 45 |
+
2026-02-01 10:26:40,392 [INFO] [batch_0032_384x384.parquet] Processed 1994 rows
|
| 46 |
+
2026-02-01 10:27:49,696 [INFO] [batch_0033_384x384.parquet] Processed 1992 rows
|
| 47 |
+
2026-02-01 10:28:49,831 [INFO] [batch_0034_384x384.parquet] Processed 1994 rows
|
| 48 |
+
2026-02-01 10:29:42,378 [INFO] [batch_0035_384x384.parquet] Processed 1994 rows
|
| 49 |
+
2026-02-01 10:30:48,969 [INFO] [batch_0036_384x384.parquet] Processed 1996 rows
|
| 50 |
+
2026-02-01 10:32:01,922 [INFO] [batch_0037_384x384.parquet] Processed 1995 rows
|
| 51 |
+
2026-02-01 10:32:47,057 [INFO] [batch_0038_384x384.parquet] Processed 1993 rows
|
| 52 |
+
2026-02-01 10:34:01,196 [INFO] [batch_0039_384x384.parquet] Processed 1989 rows
|
| 53 |
+
2026-02-01 10:35:19,501 [INFO] [batch_0040_384x384.parquet] Processed 1990 rows
|
| 54 |
+
2026-02-01 10:36:09,997 [INFO] [batch_0041_384x384.parquet] Processed 1998 rows
|
| 55 |
+
2026-02-01 10:37:25,589 [INFO] [batch_0042_384x384.parquet] Processed 1997 rows
|
| 56 |
+
2026-02-01 10:38:42,876 [INFO] [batch_0043_384x384.parquet] Processed 1988 rows
|
| 57 |
+
2026-02-01 10:39:31,979 [INFO] [batch_0044_384x384.parquet] Processed 1991 rows
|
| 58 |
+
2026-02-01 10:40:43,745 [INFO] [batch_0045_384x384.parquet] Processed 1993 rows
|
| 59 |
+
2026-02-01 10:41:59,576 [INFO] [batch_0046_384x384.parquet] Processed 1994 rows
|
| 60 |
+
2026-02-01 10:42:53,620 [INFO] [batch_0047_384x384.parquet] Processed 1995 rows
|
| 61 |
+
2026-02-01 10:44:25,584 [INFO] [batch_0048_384x384.parquet] Processed 1992 rows
|
| 62 |
+
2026-02-01 10:46:13,258 [INFO] [batch_0049_384x384.parquet] Processed 1996 rows
|
| 63 |
+
2026-02-01 10:47:13,109 [INFO] [batch_0050_384x384.parquet] Processed 1991 rows
|
| 64 |
+
2026-02-01 10:48:13,385 [INFO] [batch_0051_384x384.parquet] Processed 1997 rows
|
| 65 |
+
2026-02-01 10:49:48,140 [INFO] [batch_0052_384x384.parquet] Processed 1993 rows
|
| 66 |
+
2026-02-01 10:51:22,710 [INFO] [batch_0053_384x384.parquet] Processed 1995 rows
|
| 67 |
+
2026-02-01 10:52:23,823 [INFO] [batch_0054_384x384.parquet] Processed 1997 rows
|
| 68 |
+
2026-02-01 10:53:48,669 [INFO] [batch_0055_384x384.parquet] Processed 1995 rows
|
| 69 |
+
2026-02-01 10:55:03,785 [INFO] [batch_0056_384x384.parquet] Processed 1997 rows
|
| 70 |
+
2026-02-01 10:55:56,653 [INFO] [batch_0057_384x384.parquet] Processed 1991 rows
|
| 71 |
+
2026-02-01 10:56:50,364 [INFO] [batch_0058_384x384.parquet] Processed 1994 rows
|
| 72 |
+
2026-02-01 10:57:33,268 [INFO] [batch_0059_384x384.parquet] Processed 1993 rows
|
| 73 |
+
2026-02-01 10:58:36,103 [INFO] [batch_0060_384x384.parquet] Processed 1995 rows
|
| 74 |
+
2026-02-01 10:59:43,156 [INFO] [batch_0061_384x384.parquet] Processed 1995 rows
|
| 75 |
+
2026-02-01 11:00:45,280 [INFO] [batch_0062_384x384.parquet] Processed 1998 rows
|
| 76 |
+
2026-02-01 11:02:03,960 [INFO] [batch_0063_384x384.parquet] Processed 1997 rows
|
| 77 |
+
2026-02-01 11:03:01,993 [INFO] [batch_0064_384x384.parquet] Processed 1992 rows
|
| 78 |
+
2026-02-01 11:04:18,812 [INFO] [batch_0065_384x384.parquet] Processed 1994 rows
|
| 79 |
+
2026-02-01 11:05:34,954 [INFO] [batch_0066_384x384.parquet] Processed 1992 rows
|
| 80 |
+
2026-02-01 11:06:26,502 [INFO] [batch_0067_384x384.parquet] Processed 1993 rows
|
| 81 |
+
2026-02-01 11:07:42,754 [INFO] [batch_0068_384x384.parquet] Processed 1994 rows
|
| 82 |
+
2026-02-01 11:09:01,751 [INFO] [batch_0069_384x384.parquet] Processed 1992 rows
|
| 83 |
+
2026-02-01 11:09:49,394 [INFO] [batch_0070_384x384.parquet] Processed 1997 rows
|
| 84 |
+
2026-02-01 11:11:06,518 [INFO] [batch_0071_384x384.parquet] Processed 1996 rows
|
| 85 |
+
2026-02-01 11:12:22,688 [INFO] [batch_0072_384x384.parquet] Processed 1992 rows
|
| 86 |
+
2026-02-01 11:13:14,831 [INFO] [batch_0073_384x384.parquet] Processed 1995 rows
|
| 87 |
+
2026-02-01 11:14:14,879 [INFO] [batch_0074_384x384.parquet] Processed 1992 rows
|
| 88 |
+
2026-02-01 11:14:58,098 [INFO] [batch_0075_384x384.parquet] Processed 1991 rows
|
| 89 |
+
2026-02-01 11:15:43,764 [INFO] [batch_0076_384x384.parquet] Processed 1998 rows
|
| 90 |
+
2026-02-01 11:16:53,710 [INFO] [batch_0077_384x384.parquet] Processed 1996 rows
|
| 91 |
+
2026-02-01 11:17:51,040 [INFO] [batch_0078_384x384.parquet] Processed 1992 rows
|
| 92 |
+
2026-02-01 11:18:57,871 [INFO] [batch_0079_384x384.parquet] Processed 1995 rows
|
| 93 |
+
2026-02-01 11:20:06,930 [INFO] [batch_0080_384x384.parquet] Processed 1993 rows
|
| 94 |
+
2026-02-01 11:20:51,630 [INFO] [batch_0081_384x384.parquet] Processed 1995 rows
|
| 95 |
+
2026-02-01 11:21:43,270 [INFO] [batch_0082_384x384.parquet] Processed 1989 rows
|
| 96 |
+
2026-02-01 11:22:29,228 [INFO] [batch_0083_384x384.parquet] Processed 1995 rows
|
| 97 |
+
2026-02-01 11:23:23,236 [INFO] [batch_0084_384x384.parquet] Processed 1996 rows
|
| 98 |
+
2026-02-01 11:24:32,532 [INFO] [batch_0085_384x384.parquet] Processed 1997 rows
|
| 99 |
+
2026-02-01 11:25:20,336 [INFO] [batch_0086_384x384.parquet] Processed 1996 rows
|
| 100 |
+
2026-02-01 11:26:33,616 [INFO] [batch_0087_384x384.parquet] Processed 1994 rows
|
| 101 |
+
2026-02-01 11:27:24,449 [INFO] [batch_0088_384x384.parquet] Processed 1992 rows
|
| 102 |
+
2026-02-01 11:28:20,047 [INFO] [batch_0089_384x384.parquet] Processed 1993 rows
|
| 103 |
+
2026-02-01 11:29:43,109 [INFO] [batch_0090_384x384.parquet] Processed 1993 rows
|
| 104 |
+
2026-02-01 11:30:41,652 [INFO] [batch_0091_384x384.parquet] Processed 1995 rows
|
| 105 |
+
2026-02-01 11:31:43,751 [INFO] [batch_0092_384x384.parquet] Processed 1994 rows
|
| 106 |
+
2026-02-01 11:33:10,661 [INFO] [batch_0093_384x384.parquet] Processed 1998 rows
|
| 107 |
+
2026-02-01 11:34:12,721 [INFO] [batch_0094_384x384.parquet] Processed 1993 rows
|
| 108 |
+
2026-02-01 11:35:09,887 [INFO] [batch_0095_384x384.parquet] Processed 1995 rows
|
| 109 |
+
2026-02-01 11:36:36,141 [INFO] [batch_0096_384x384.parquet] Processed 1997 rows
|
| 110 |
+
2026-02-01 11:37:41,740 [INFO] [batch_0097_384x384.parquet] Processed 1990 rows
|
| 111 |
+
2026-02-01 11:38:40,066 [INFO] [batch_0098_384x384.parquet] Processed 1995 rows
|
| 112 |
+
2026-02-01 11:39:45,765 [INFO] [batch_0099_384x384.parquet] Processed 1992 rows
|
| 113 |
+
2026-02-01 11:40:40,739 [INFO] [batch_0100_384x384.parquet] Processed 1993 rows
|
| 114 |
+
2026-02-01 11:41:41,583 [INFO] [batch_0101_384x384.parquet] Processed 1994 rows
|
| 115 |
+
2026-02-01 11:42:47,504 [INFO] [batch_0102_384x384.parquet] Processed 1991 rows
|
| 116 |
+
2026-02-01 11:43:31,148 [INFO] [batch_0103_384x384.parquet] Processed 1990 rows
|
| 117 |
+
2026-02-01 11:44:38,070 [INFO] [batch_0104_384x384.parquet] Processed 1995 rows
|
| 118 |
+
2026-02-01 11:45:48,089 [INFO] [batch_0105_384x384.parquet] Processed 1993 rows
|
| 119 |
+
2026-02-01 11:46:47,156 [INFO] [batch_0106_384x384.parquet] Processed 1988 rows
|
| 120 |
+
2026-02-01 11:48:06,340 [INFO] [batch_0107_384x384.parquet] Processed 1996 rows
|
| 121 |
+
2026-02-01 11:49:08,016 [INFO] [batch_0108_384x384.parquet] Processed 1992 rows
|
| 122 |
+
2026-02-01 11:50:27,665 [INFO] [batch_0109_384x384.parquet] Processed 1993 rows
|
| 123 |
+
2026-02-01 11:51:38,073 [INFO] [batch_0110_384x384.parquet] Processed 1996 rows
|
| 124 |
+
2026-02-01 11:52:26,956 [INFO] [batch_0111_384x384.parquet] Processed 1991 rows
|
| 125 |
+
2026-02-01 11:53:44,395 [INFO] [batch_0112_384x384.parquet] Processed 1994 rows
|
| 126 |
+
2026-02-01 11:54:23,803 [INFO] [batch_0113_384x384.parquet] Processed 1996 rows
|
| 127 |
+
2026-02-01 11:55:07,867 [INFO] [batch_0114_384x384.parquet] Processed 1997 rows
|
| 128 |
+
2026-02-01 11:55:54,834 [INFO] [batch_0115_384x384.parquet] Processed 1997 rows
|
| 129 |
+
2026-02-01 11:56:36,849 [INFO] [batch_0116_384x384.parquet] Processed 1992 rows
|
| 130 |
+
2026-02-01 11:57:20,506 [INFO] [batch_0117_384x384.parquet] Processed 1992 rows
|
| 131 |
+
2026-02-01 11:57:58,985 [INFO] [batch_0118_384x384.parquet] Processed 1993 rows
|
| 132 |
+
2026-02-01 11:58:38,965 [INFO] [batch_0119_384x384.parquet] Processed 1998 rows
|
| 133 |
+
2026-02-01 11:59:16,459 [INFO] [batch_0120_384x384.parquet] Processed 1995 rows
|
| 134 |
+
2026-02-01 11:59:55,497 [INFO] [batch_0121_384x384.parquet] Processed 1995 rows
|
| 135 |
+
2026-02-01 12:00:33,857 [INFO] [batch_0122_384x384.parquet] Processed 1992 rows
|
| 136 |
+
2026-02-01 12:01:15,871 [INFO] [batch_0123_384x384.parquet] Processed 1994 rows
|
| 137 |
+
2026-02-01 12:01:53,537 [INFO] [batch_0124_384x384.parquet] Processed 1995 rows
|
| 138 |
+
2026-02-01 12:02:22,334 [INFO] [batch_0125_384x384.parquet] Processed 1505 rows
|
| 139 |
+
2026-02-01 12:02:22,334 [INFO] Merging all results...
|
| 140 |
+
2026-02-01 12:02:22,384 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
|
| 141 |
+
2026-02-01 12:02:22,384 [INFO] ✓ All unwanted columns removed
|
| 142 |
+
2026-02-01 12:02:22,384 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
|
| 143 |
+
2026-02-01 12:02:22,384 [INFO] ✓ Column 'pixel_bbox' added
|
| 144 |
+
2026-02-01 12:02:22,384 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/farslip/FarSLIP_crop_384x384.parquet...
|
| 145 |
+
2026-02-01 12:02:25,588 [INFO] ================================================================================
|
| 146 |
+
2026-02-01 12:02:25,588 [INFO] Processing complete!
|
| 147 |
+
2026-02-01 12:02:25,588 [INFO] Total rows: 248,719
|
| 148 |
+
2026-02-01 12:02:25,589 [INFO] Embedding dimension: 512
|
| 149 |
+
2026-02-01 12:02:25,589 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/farslip/FarSLIP_crop_384x384.parquet
|
| 150 |
+
2026-02-01 12:02:25,589 [INFO] ================================================================================
|
logs/compute_embeddings_satclip.log
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-02-01 09:09:57,720 [INFO] ================================================================================
|
| 2 |
+
2026-02-01 09:09:57,720 [INFO] Computing SATCLIP embeddings
|
| 3 |
+
2026-02-01 09:09:57,720 [INFO] Timestamp: 2026-02-01T09:09:57.720447
|
| 4 |
+
2026-02-01 09:09:57,720 [INFO] Device: cuda:1
|
| 5 |
+
2026-02-01 09:09:57,720 [INFO] ================================================================================
|
| 6 |
+
2026-02-01 09:09:57,721 [INFO] Batch size: 128
|
| 7 |
+
2026-02-01 09:09:57,721 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
|
| 8 |
+
2026-02-01 09:09:57,721 [INFO] Loading satclip model...
|
| 9 |
+
2026-02-01 09:09:57,727 [INFO] SatCLIP-MS model loaded on cuda:1
|
| 10 |
+
2026-02-01 09:09:57,728 [INFO] Found 1 input files
|
| 11 |
+
2026-02-01 09:10:21,830 [WARNING] No valid embeddings for batch_0001_384x384.parquet
|
| 12 |
+
2026-02-01 09:10:22,107 [ERROR] No data processed!
|
| 13 |
+
2026-02-01 09:39:17,993 [INFO] ================================================================================
|
| 14 |
+
2026-02-01 09:39:17,993 [INFO] Computing SATCLIP embeddings
|
| 15 |
+
2026-02-01 09:39:17,993 [INFO] Timestamp: 2026-02-01T09:39:17.993775
|
| 16 |
+
2026-02-01 09:39:17,993 [INFO] Device: cuda:1
|
| 17 |
+
2026-02-01 09:39:17,993 [INFO] ================================================================================
|
| 18 |
+
2026-02-01 09:39:17,994 [INFO] Batch size: 128
|
| 19 |
+
2026-02-01 09:39:17,994 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
|
| 20 |
+
2026-02-01 09:39:17,994 [INFO] Loading satclip model...
|
| 21 |
+
2026-02-01 09:39:20,179 [INFO] SatCLIP-MS model loaded on cuda:1
|
| 22 |
+
2026-02-01 09:39:20,180 [INFO] Found 1 input files
|
| 23 |
+
2026-02-01 09:40:01,084 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
|
| 24 |
+
2026-02-01 09:40:01,084 [INFO] Merging all results...
|
| 25 |
+
2026-02-01 09:40:01,085 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
|
| 26 |
+
2026-02-01 09:40:01,085 [INFO] ✓ All unwanted columns removed
|
| 27 |
+
2026-02-01 09:40:01,085 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
|
| 28 |
+
2026-02-01 09:40:01,085 [INFO] ✓ Column 'pixel_bbox' added
|
| 29 |
+
2026-02-01 09:40:01,085 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet...
|
| 30 |
+
2026-02-01 09:40:01,134 [INFO] ================================================================================
|
| 31 |
+
2026-02-01 09:40:01,134 [INFO] Processing complete!
|
| 32 |
+
2026-02-01 09:40:01,134 [INFO] Total rows: 1,996
|
| 33 |
+
2026-02-01 09:40:01,134 [INFO] Embedding dimension: 256
|
| 34 |
+
2026-02-01 09:40:01,134 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
|
| 35 |
+
2026-02-01 09:40:01,134 [INFO] ================================================================================
|
| 36 |
+
2026-02-01 09:43:19,666 [INFO] ================================================================================
|
| 37 |
+
2026-02-01 09:43:19,666 [INFO] Computing SATCLIP embeddings
|
| 38 |
+
2026-02-01 09:43:19,666 [INFO] Timestamp: 2026-02-01T09:43:19.666577
|
| 39 |
+
2026-02-01 09:43:19,666 [INFO] Device: cuda:3
|
| 40 |
+
2026-02-01 09:43:19,666 [INFO] ================================================================================
|
| 41 |
+
2026-02-01 09:43:19,668 [INFO] Batch size: 128
|
| 42 |
+
2026-02-01 09:43:19,668 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
|
| 43 |
+
2026-02-01 09:43:19,668 [INFO] Loading satclip model...
|
| 44 |
+
2026-02-01 09:43:21,344 [INFO] SatCLIP-MS model loaded on cuda:3
|
| 45 |
+
2026-02-01 09:43:21,345 [INFO] Found 125 input files
|
| 46 |
+
2026-02-01 09:44:03,000 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
|
| 47 |
+
2026-02-01 09:44:46,041 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
|
| 48 |
+
2026-02-01 09:45:27,652 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
|
| 49 |
+
2026-02-01 09:46:15,446 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
|
| 50 |
+
2026-02-01 09:47:09,769 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
|
| 51 |
+
2026-02-01 09:47:59,773 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
|
| 52 |
+
2026-02-01 09:48:51,057 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
|
| 53 |
+
2026-02-01 09:49:34,202 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
|
| 54 |
+
2026-02-01 09:50:25,944 [INFO] [batch_0009_384x384.parquet] Processed 1993 rows
|
| 55 |
+
2026-02-01 09:51:09,586 [INFO] [batch_0010_384x384.parquet] Processed 1993 rows
|
| 56 |
+
2026-02-01 09:51:56,545 [INFO] [batch_0011_384x384.parquet] Processed 1990 rows
|
| 57 |
+
2026-02-01 09:52:44,526 [INFO] [batch_0012_384x384.parquet] Processed 1993 rows
|
| 58 |
+
2026-02-01 09:53:32,729 [INFO] [batch_0013_384x384.parquet] Processed 1993 rows
|
| 59 |
+
2026-02-01 09:54:14,312 [INFO] [batch_0014_384x384.parquet] Processed 1991 rows
|
| 60 |
+
2026-02-01 09:55:05,975 [INFO] [batch_0015_384x384.parquet] Processed 1993 rows
|
| 61 |
+
2026-02-01 09:55:57,268 [INFO] [batch_0016_384x384.parquet] Processed 1990 rows
|
| 62 |
+
2026-02-01 09:57:00,591 [INFO] [batch_0017_384x384.parquet] Processed 1991 rows
|
| 63 |
+
2026-02-01 09:57:48,464 [INFO] [batch_0018_384x384.parquet] Processed 1991 rows
|
| 64 |
+
2026-02-01 09:58:52,420 [INFO] [batch_0019_384x384.parquet] Processed 1996 rows
|
| 65 |
+
2026-02-01 10:00:04,202 [INFO] [batch_0020_384x384.parquet] Processed 1994 rows
|
| 66 |
+
2026-02-01 10:01:10,309 [INFO] [batch_0021_384x384.parquet] Processed 1996 rows
|
| 67 |
+
2026-02-01 10:02:15,265 [INFO] [batch_0022_384x384.parquet] Processed 1995 rows
|
| 68 |
+
2026-02-01 10:03:31,554 [INFO] [batch_0023_384x384.parquet] Processed 1992 rows
|
| 69 |
+
2026-02-01 10:04:40,240 [INFO] [batch_0024_384x384.parquet] Processed 1989 rows
|
| 70 |
+
2026-02-01 10:05:55,812 [INFO] [batch_0025_384x384.parquet] Processed 1993 rows
|
| 71 |
+
2026-02-01 10:07:00,366 [INFO] [batch_0026_384x384.parquet] Processed 1995 rows
|
| 72 |
+
2026-02-01 10:08:10,532 [INFO] [batch_0027_384x384.parquet] Processed 1997 rows
|
| 73 |
+
2026-02-01 10:09:11,505 [INFO] [batch_0028_384x384.parquet] Processed 1991 rows
|
| 74 |
+
2026-02-01 10:10:21,951 [INFO] [batch_0029_384x384.parquet] Processed 1992 rows
|
| 75 |
+
2026-02-01 10:11:30,988 [INFO] [batch_0030_384x384.parquet] Processed 1993 rows
|
| 76 |
+
2026-02-01 10:12:26,034 [INFO] [batch_0031_384x384.parquet] Processed 1988 rows
|
| 77 |
+
2026-02-01 10:13:36,732 [INFO] [batch_0032_384x384.parquet] Processed 1994 rows
|
| 78 |
+
2026-02-01 10:14:36,787 [INFO] [batch_0033_384x384.parquet] Processed 1992 rows
|
| 79 |
+
2026-02-01 10:15:36,921 [INFO] [batch_0034_384x384.parquet] Processed 1994 rows
|
| 80 |
+
2026-02-01 10:16:38,623 [INFO] [batch_0035_384x384.parquet] Processed 1994 rows
|
| 81 |
+
2026-02-01 10:17:27,583 [INFO] [batch_0036_384x384.parquet] Processed 1996 rows
|
| 82 |
+
2026-02-01 10:18:29,976 [INFO] [batch_0037_384x384.parquet] Processed 1995 rows
|
| 83 |
+
2026-02-01 10:19:26,843 [INFO] [batch_0038_384x384.parquet] Processed 1993 rows
|
| 84 |
+
2026-02-01 10:20:14,532 [INFO] [batch_0039_384x384.parquet] Processed 1989 rows
|
| 85 |
+
2026-02-01 10:21:13,694 [INFO] [batch_0040_384x384.parquet] Processed 1990 rows
|
| 86 |
+
2026-02-01 10:22:05,858 [INFO] [batch_0041_384x384.parquet] Processed 1998 rows
|
| 87 |
+
2026-02-01 10:23:04,226 [INFO] [batch_0042_384x384.parquet] Processed 1997 rows
|
| 88 |
+
2026-02-01 10:23:56,641 [INFO] [batch_0043_384x384.parquet] Processed 1988 rows
|
| 89 |
+
2026-02-01 10:24:38,594 [INFO] [batch_0044_384x384.parquet] Processed 1991 rows
|
| 90 |
+
2026-02-01 10:25:42,517 [INFO] [batch_0045_384x384.parquet] Processed 1993 rows
|
| 91 |
+
2026-02-01 10:26:23,732 [INFO] [batch_0046_384x384.parquet] Processed 1994 rows
|
| 92 |
+
2026-02-01 10:27:39,298 [INFO] [batch_0047_384x384.parquet] Processed 1995 rows
|
| 93 |
+
2026-02-01 10:28:34,546 [INFO] [batch_0048_384x384.parquet] Processed 1992 rows
|
| 94 |
+
2026-02-01 10:29:35,568 [INFO] [batch_0049_384x384.parquet] Processed 1996 rows
|
| 95 |
+
2026-02-01 10:30:38,004 [INFO] [batch_0050_384x384.parquet] Processed 1991 rows
|
| 96 |
+
2026-02-01 10:31:50,544 [INFO] [batch_0051_384x384.parquet] Processed 1997 rows
|
| 97 |
+
2026-02-01 10:32:38,165 [INFO] [batch_0052_384x384.parquet] Processed 1993 rows
|
| 98 |
+
2026-02-01 10:33:54,330 [INFO] [batch_0053_384x384.parquet] Processed 1995 rows
|
| 99 |
+
2026-02-01 10:35:11,070 [INFO] [batch_0054_384x384.parquet] Processed 1997 rows
|
| 100 |
+
2026-02-01 10:36:06,495 [INFO] [batch_0055_384x384.parquet] Processed 1995 rows
|
| 101 |
+
2026-02-01 10:37:26,449 [INFO] [batch_0056_384x384.parquet] Processed 1997 rows
|
| 102 |
+
2026-02-01 10:38:40,433 [INFO] [batch_0057_384x384.parquet] Processed 1991 rows
|
| 103 |
+
2026-02-01 10:39:36,229 [INFO] [batch_0058_384x384.parquet] Processed 1994 rows
|
| 104 |
+
2026-02-01 10:40:50,558 [INFO] [batch_0059_384x384.parquet] Processed 1993 rows
|
| 105 |
+
2026-02-01 10:42:00,100 [INFO] [batch_0060_384x384.parquet] Processed 1995 rows
|
| 106 |
+
2026-02-01 10:42:53,440 [INFO] [batch_0061_384x384.parquet] Processed 1995 rows
|
| 107 |
+
2026-02-01 10:44:21,706 [INFO] [batch_0062_384x384.parquet] Processed 1998 rows
|
| 108 |
+
2026-02-01 10:45:56,656 [INFO] [batch_0063_384x384.parquet] Processed 1997 rows
|
| 109 |
+
2026-02-01 10:46:53,942 [INFO] [batch_0064_384x384.parquet] Processed 1992 rows
|
| 110 |
+
2026-02-01 10:47:47,760 [INFO] [batch_0065_384x384.parquet] Processed 1994 rows
|
| 111 |
+
2026-02-01 10:48:37,571 [INFO] [batch_0066_384x384.parquet] Processed 1992 rows
|
| 112 |
+
2026-02-01 10:50:00,819 [INFO] [batch_0067_384x384.parquet] Processed 1993 rows
|
| 113 |
+
2026-02-01 10:51:30,799 [INFO] [batch_0068_384x384.parquet] Processed 1994 rows
|
| 114 |
+
2026-02-01 10:52:28,413 [INFO] [batch_0069_384x384.parquet] Processed 1992 rows
|
| 115 |
+
2026-02-01 10:53:50,597 [INFO] [batch_0070_384x384.parquet] Processed 1997 rows
|
| 116 |
+
2026-02-01 10:55:01,173 [INFO] [batch_0071_384x384.parquet] Processed 1996 rows
|
| 117 |
+
2026-02-01 10:56:03,395 [INFO] [batch_0072_384x384.parquet] Processed 1992 rows
|
| 118 |
+
2026-02-01 10:57:10,601 [INFO] [batch_0073_384x384.parquet] Processed 1995 rows
|
| 119 |
+
2026-02-01 10:58:22,789 [INFO] [batch_0074_384x384.parquet] Processed 1992 rows
|
| 120 |
+
2026-02-01 10:59:39,697 [INFO] [batch_0075_384x384.parquet] Processed 1991 rows
|
| 121 |
+
2026-02-01 11:00:48,962 [INFO] [batch_0076_384x384.parquet] Processed 1998 rows
|
| 122 |
+
2026-02-01 11:01:59,729 [INFO] [batch_0077_384x384.parquet] Processed 1996 rows
|
| 123 |
+
2026-02-01 11:03:01,575 [INFO] [batch_0078_384x384.parquet] Processed 1992 rows
|
| 124 |
+
2026-02-01 11:04:15,721 [INFO] [batch_0079_384x384.parquet] Processed 1995 rows
|
| 125 |
+
2026-02-01 11:05:26,147 [INFO] [batch_0080_384x384.parquet] Processed 1993 rows
|
| 126 |
+
2026-02-01 11:06:21,742 [INFO] [batch_0081_384x384.parquet] Processed 1995 rows
|
| 127 |
+
2026-02-01 11:07:34,071 [INFO] [batch_0082_384x384.parquet] Processed 1989 rows
|
| 128 |
+
2026-02-01 11:08:51,443 [INFO] [batch_0083_384x384.parquet] Processed 1995 rows
|
| 129 |
+
2026-02-01 11:09:45,289 [INFO] [batch_0084_384x384.parquet] Processed 1996 rows
|
| 130 |
+
2026-02-01 11:10:59,507 [INFO] [batch_0085_384x384.parquet] Processed 1997 rows
|
| 131 |
+
2026-02-01 11:12:12,671 [INFO] [batch_0086_384x384.parquet] Processed 1996 rows
|
| 132 |
+
2026-02-01 11:13:16,945 [INFO] [batch_0087_384x384.parquet] Processed 1994 rows
|
| 133 |
+
2026-02-01 11:14:26,324 [INFO] [batch_0088_384x384.parquet] Processed 1992 rows
|
| 134 |
+
2026-02-01 11:15:25,871 [INFO] [batch_0089_384x384.parquet] Processed 1993 rows
|
| 135 |
+
2026-02-01 11:16:43,653 [INFO] [batch_0090_384x384.parquet] Processed 1993 rows
|
| 136 |
+
2026-02-01 11:17:52,205 [INFO] [batch_0091_384x384.parquet] Processed 1995 rows
|
| 137 |
+
2026-02-01 11:19:02,073 [INFO] [batch_0092_384x384.parquet] Processed 1994 rows
|
| 138 |
+
2026-02-01 11:20:14,843 [INFO] [batch_0093_384x384.parquet] Processed 1998 rows
|
| 139 |
+
2026-02-01 11:21:09,193 [INFO] [batch_0094_384x384.parquet] Processed 1993 rows
|
| 140 |
+
2026-02-01 11:22:03,303 [INFO] [batch_0095_384x384.parquet] Processed 1995 rows
|
| 141 |
+
2026-02-01 11:23:12,708 [INFO] [batch_0096_384x384.parquet] Processed 1997 rows
|
| 142 |
+
2026-02-01 11:24:18,831 [INFO] [batch_0097_384x384.parquet] Processed 1990 rows
|
| 143 |
+
2026-02-01 11:25:07,701 [INFO] [batch_0098_384x384.parquet] Processed 1995 rows
|
| 144 |
+
2026-02-01 11:26:18,306 [INFO] [batch_0099_384x384.parquet] Processed 1992 rows
|
| 145 |
+
2026-02-01 11:27:02,698 [INFO] [batch_0100_384x384.parquet] Processed 1993 rows
|
| 146 |
+
2026-02-01 11:28:08,644 [INFO] [batch_0101_384x384.parquet] Processed 1994 rows
|
| 147 |
+
2026-02-01 11:29:33,678 [INFO] [batch_0102_384x384.parquet] Processed 1991 rows
|
| 148 |
+
2026-02-01 11:30:25,760 [INFO] [batch_0103_384x384.parquet] Processed 1990 rows
|
| 149 |
+
2026-02-01 11:31:38,365 [INFO] [batch_0104_384x384.parquet] Processed 1995 rows
|
| 150 |
+
2026-02-01 11:33:06,206 [INFO] [batch_0105_384x384.parquet] Processed 1993 rows
|
| 151 |
+
2026-02-01 11:33:59,497 [INFO] [batch_0106_384x384.parquet] Processed 1988 rows
|
| 152 |
+
2026-02-01 11:35:04,565 [INFO] [batch_0107_384x384.parquet] Processed 1996 rows
|
| 153 |
+
2026-02-01 11:36:30,898 [INFO] [batch_0108_384x384.parquet] Processed 1992 rows
|
| 154 |
+
2026-02-01 11:37:34,766 [INFO] [batch_0109_384x384.parquet] Processed 1993 rows
|
| 155 |
+
2026-02-01 11:38:36,780 [INFO] [batch_0110_384x384.parquet] Processed 1996 rows
|
| 156 |
+
2026-02-01 11:39:53,826 [INFO] [batch_0111_384x384.parquet] Processed 1991 rows
|
| 157 |
+
2026-02-01 11:40:48,014 [INFO] [batch_0112_384x384.parquet] Processed 1994 rows
|
| 158 |
+
2026-02-01 11:41:49,113 [INFO] [batch_0113_384x384.parquet] Processed 1996 rows
|
| 159 |
+
2026-02-01 11:42:56,188 [INFO] [batch_0114_384x384.parquet] Processed 1997 rows
|
| 160 |
+
2026-02-01 11:43:43,288 [INFO] [batch_0115_384x384.parquet] Processed 1997 rows
|
| 161 |
+
2026-02-01 11:44:48,748 [INFO] [batch_0116_384x384.parquet] Processed 1992 rows
|
| 162 |
+
2026-02-01 11:45:54,394 [INFO] [batch_0117_384x384.parquet] Processed 1992 rows
|
| 163 |
+
2026-02-01 11:46:53,275 [INFO] [batch_0118_384x384.parquet] Processed 1993 rows
|
| 164 |
+
2026-02-01 11:48:08,611 [INFO] [batch_0119_384x384.parquet] Processed 1998 rows
|
| 165 |
+
2026-02-01 11:49:07,195 [INFO] [batch_0120_384x384.parquet] Processed 1995 rows
|
| 166 |
+
2026-02-01 11:50:22,347 [INFO] [batch_0121_384x384.parquet] Processed 1995 rows
|
| 167 |
+
2026-02-01 11:51:26,391 [INFO] [batch_0122_384x384.parquet] Processed 1992 rows
|
| 168 |
+
2026-02-01 11:52:22,734 [INFO] [batch_0123_384x384.parquet] Processed 1994 rows
|
| 169 |
+
2026-02-01 11:53:34,357 [INFO] [batch_0124_384x384.parquet] Processed 1995 rows
|
| 170 |
+
2026-02-01 11:54:05,024 [INFO] [batch_0125_384x384.parquet] Processed 1505 rows
|
| 171 |
+
2026-02-01 11:54:05,024 [INFO] Merging all results...
|
| 172 |
+
2026-02-01 11:54:05,057 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
|
| 173 |
+
2026-02-01 11:54:05,058 [INFO] ✓ All unwanted columns removed
|
| 174 |
+
2026-02-01 11:54:05,058 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
|
| 175 |
+
2026-02-01 11:54:05,058 [INFO] ✓ Column 'pixel_bbox' added
|
| 176 |
+
2026-02-01 11:54:05,058 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet...
|
| 177 |
+
2026-02-01 11:54:06,861 [INFO] ================================================================================
|
| 178 |
+
2026-02-01 11:54:06,861 [INFO] Processing complete!
|
| 179 |
+
2026-02-01 11:54:06,861 [INFO] Total rows: 248,719
|
| 180 |
+
2026-02-01 11:54:06,862 [INFO] Embedding dimension: 256
|
| 181 |
+
2026-02-01 11:54:06,862 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/satclip/SatCLIP_crop_384x384.parquet
|
| 182 |
+
2026-02-01 11:54:06,862 [INFO] ================================================================================
|
logs/compute_embeddings_siglip.log
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-02-01 09:43:14,001 [INFO] ================================================================================
|
| 2 |
+
2026-02-01 09:43:14,002 [INFO] Computing SIGLIP embeddings
|
| 3 |
+
2026-02-01 09:43:14,002 [INFO] Timestamp: 2026-02-01T09:43:14.002069
|
| 4 |
+
2026-02-01 09:43:14,002 [INFO] Device: cuda:2
|
| 5 |
+
2026-02-01 09:43:14,002 [INFO] ================================================================================
|
| 6 |
+
2026-02-01 09:43:14,003 [INFO] Batch size: 64
|
| 7 |
+
2026-02-01 09:43:14,003 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet
|
| 8 |
+
2026-02-01 09:43:14,004 [INFO] Loading siglip model...
|
| 9 |
+
2026-02-01 09:43:14,196 [INFO] Parsing model identifier. Schema: None, Identifier: ViT-SO400M-14-SigLIP-384
|
| 10 |
+
2026-02-01 09:43:14,196 [INFO] Loaded built-in ViT-SO400M-14-SigLIP-384 model config.
|
| 11 |
+
2026-02-01 09:43:14,197 [INFO] `pretrained` specifies file path: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
|
| 12 |
+
2026-02-01 09:43:14,197 [INFO] Instantiating model architecture: CustomTextCLIP
|
| 13 |
+
2026-02-01 09:43:22,955 [INFO] Loading full pretrained weights from: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
|
| 14 |
+
2026-02-01 09:43:24,815 [INFO] Final image preprocessing configuration set: {'size': (384, 384), 'mode': 'RGB', 'mean': (0.48145466, 0.4578275, 0.40821073), 'std': (0.26862954, 0.26130258, 0.27577711), 'interpolation': 'bicubic', 'resize_mode': 'shortest', 'fill_color': 0}
|
| 15 |
+
2026-02-01 09:43:24,815 [INFO] Model ViT-SO400M-14-SigLIP-384 creation process complete.
|
| 16 |
+
2026-02-01 09:43:25,908 [INFO] SigLIP model loaded on cuda:2
|
| 17 |
+
2026-02-01 09:43:25,909 [INFO] Found 125 input files
|
| 18 |
+
2026-02-01 09:44:47,927 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
|
| 19 |
+
2026-02-01 09:46:05,633 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
|
| 20 |
+
2026-02-01 09:47:28,903 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
|
| 21 |
+
2026-02-01 09:48:39,715 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
|
| 22 |
+
2026-02-01 09:49:56,387 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
|
| 23 |
+
2026-02-01 09:51:18,436 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
|
| 24 |
+
2026-02-01 09:52:45,064 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
|
| 25 |
+
2026-02-01 09:54:13,231 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
|
| 26 |
+
2026-02-01 09:55:40,342 [INFO] ================================================================================
|
| 27 |
+
2026-02-01 09:55:40,343 [INFO] Computing SIGLIP embeddings
|
| 28 |
+
2026-02-01 09:55:40,343 [INFO] Timestamp: 2026-02-01T09:55:40.343045
|
| 29 |
+
2026-02-01 09:55:40,343 [INFO] Device: cuda:2
|
| 30 |
+
2026-02-01 09:55:40,343 [INFO] ================================================================================
|
| 31 |
+
2026-02-01 09:55:40,344 [INFO] Batch size: 256
|
| 32 |
+
2026-02-01 09:55:40,344 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet
|
| 33 |
+
2026-02-01 09:55:40,344 [INFO] Loading siglip model...
|
| 34 |
+
2026-02-01 09:55:40,494 [INFO] Parsing model identifier. Schema: None, Identifier: ViT-SO400M-14-SigLIP-384
|
| 35 |
+
2026-02-01 09:55:40,494 [INFO] Loaded built-in ViT-SO400M-14-SigLIP-384 model config.
|
| 36 |
+
2026-02-01 09:55:40,494 [INFO] `pretrained` specifies file path: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
|
| 37 |
+
2026-02-01 09:55:40,494 [INFO] Instantiating model architecture: CustomTextCLIP
|
| 38 |
+
2026-02-01 09:55:50,054 [INFO] Loading full pretrained weights from: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
|
| 39 |
+
2026-02-01 09:55:52,457 [INFO] Final image preprocessing configuration set: {'size': (384, 384), 'mode': 'RGB', 'mean': (0.48145466, 0.4578275, 0.40821073), 'std': (0.26862954, 0.26130258, 0.27577711), 'interpolation': 'bicubic', 'resize_mode': 'shortest', 'fill_color': 0}
|
| 40 |
+
2026-02-01 09:55:52,457 [INFO] Model ViT-SO400M-14-SigLIP-384 creation process complete.
|
| 41 |
+
2026-02-01 09:55:53,533 [INFO] SigLIP model loaded on cuda:2
|
| 42 |
+
2026-02-01 09:55:53,534 [INFO] Found 125 input files
|
| 43 |
+
2026-02-01 09:57:15,361 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
|
| 44 |
+
2026-02-01 09:58:38,916 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
|
| 45 |
+
2026-02-01 10:00:13,289 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
|
| 46 |
+
2026-02-01 10:01:38,351 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
|
| 47 |
+
2026-02-01 10:03:13,561 [INFO] [batch_0005_384x384.parquet] Processed 1992 rows
|
| 48 |
+
2026-02-01 10:04:55,295 [INFO] [batch_0006_384x384.parquet] Processed 1992 rows
|
| 49 |
+
2026-02-01 10:06:42,957 [INFO] [batch_0007_384x384.parquet] Processed 1998 rows
|
| 50 |
+
2026-02-01 10:08:27,547 [INFO] [batch_0008_384x384.parquet] Processed 1994 rows
|
| 51 |
+
2026-02-01 10:10:15,515 [INFO] [batch_0009_384x384.parquet] Processed 1993 rows
|
| 52 |
+
2026-02-01 10:11:54,632 [INFO] [batch_0010_384x384.parquet] Processed 1993 rows
|
| 53 |
+
2026-02-01 10:13:42,862 [INFO] [batch_0011_384x384.parquet] Processed 1990 rows
|
| 54 |
+
2026-02-01 10:15:23,412 [INFO] [batch_0012_384x384.parquet] Processed 1993 rows
|
| 55 |
+
2026-02-01 10:16:55,431 [INFO] [batch_0013_384x384.parquet] Processed 1993 rows
|
| 56 |
+
2026-02-01 10:18:30,326 [INFO] [batch_0014_384x384.parquet] Processed 1991 rows
|
| 57 |
+
2026-02-01 10:19:54,738 [INFO] [batch_0015_384x384.parquet] Processed 1993 rows
|
| 58 |
+
2026-02-01 10:21:25,001 [INFO] [batch_0016_384x384.parquet] Processed 1990 rows
|
| 59 |
+
2026-02-01 10:23:00,423 [INFO] [batch_0017_384x384.parquet] Processed 1991 rows
|
| 60 |
+
2026-02-01 10:24:21,837 [INFO] [batch_0018_384x384.parquet] Processed 1991 rows
|
| 61 |
+
2026-02-01 10:26:00,517 [INFO] [batch_0019_384x384.parquet] Processed 1996 rows
|
| 62 |
+
2026-02-01 10:27:39,553 [INFO] [batch_0020_384x384.parquet] Processed 1994 rows
|
| 63 |
+
2026-02-01 10:29:02,772 [INFO] [batch_0021_384x384.parquet] Processed 1996 rows
|
| 64 |
+
2026-02-01 10:30:43,286 [INFO] [batch_0022_384x384.parquet] Processed 1995 rows
|
| 65 |
+
2026-02-01 10:32:18,498 [INFO] [batch_0023_384x384.parquet] Processed 1992 rows
|
| 66 |
+
2026-02-01 10:33:59,552 [INFO] [batch_0024_384x384.parquet] Processed 1989 rows
|
| 67 |
+
2026-02-01 10:35:36,652 [INFO] [batch_0025_384x384.parquet] Processed 1993 rows
|
| 68 |
+
2026-02-01 10:37:22,505 [INFO] [batch_0026_384x384.parquet] Processed 1995 rows
|
| 69 |
+
2026-02-01 10:39:04,911 [INFO] [batch_0027_384x384.parquet] Processed 1997 rows
|
| 70 |
+
2026-02-01 10:40:47,184 [INFO] [batch_0028_384x384.parquet] Processed 1991 rows
|
| 71 |
+
2026-02-01 10:42:27,627 [INFO] [batch_0029_384x384.parquet] Processed 1992 rows
|
| 72 |
+
2026-02-01 10:43:40,600 [INFO] ================================================================================
|
| 73 |
+
2026-02-01 10:43:40,600 [INFO] Computing SIGLIP embeddings
|
| 74 |
+
2026-02-01 10:43:40,600 [INFO] Timestamp: 2026-02-01T10:43:40.600706
|
| 75 |
+
2026-02-01 10:43:40,600 [INFO] Device: cuda:5
|
| 76 |
+
2026-02-01 10:43:40,600 [INFO] ================================================================================
|
| 77 |
+
2026-02-01 10:43:40,602 [INFO] Batch size: 64
|
| 78 |
+
2026-02-01 10:43:40,602 [INFO] Output path: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet
|
| 79 |
+
2026-02-01 10:43:40,602 [INFO] Loading siglip model...
|
| 80 |
+
2026-02-01 10:43:40,778 [INFO] Parsing model identifier. Schema: None, Identifier: ViT-SO400M-14-SigLIP-384
|
| 81 |
+
2026-02-01 10:43:40,778 [INFO] Loaded built-in ViT-SO400M-14-SigLIP-384 model config.
|
| 82 |
+
2026-02-01 10:43:40,778 [INFO] `pretrained` specifies file path: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
|
| 83 |
+
2026-02-01 10:43:40,778 [INFO] Instantiating model architecture: CustomTextCLIP
|
| 84 |
+
2026-02-01 10:43:59,641 [INFO] Loading full pretrained weights from: /data1/zyj/checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin
|
| 85 |
+
2026-02-01 10:44:04,702 [INFO] Final image preprocessing configuration set: {'size': (384, 384), 'mode': 'RGB', 'mean': (0.48145466, 0.4578275, 0.40821073), 'std': (0.26862954, 0.26130258, 0.27577711), 'interpolation': 'bicubic', 'resize_mode': 'shortest', 'fill_color': 0}
|
| 86 |
+
2026-02-01 10:44:04,702 [INFO] Model ViT-SO400M-14-SigLIP-384 creation process complete.
|
| 87 |
+
2026-02-01 10:44:06,271 [INFO] SigLIP model loaded on cuda:5
|
| 88 |
+
2026-02-01 10:44:06,272 [INFO] Found 125 input files
|
| 89 |
+
2026-02-01 10:44:20,369 [INFO] [batch_0030_384x384.parquet] Processed 1993 rows
|
| 90 |
+
2026-02-01 10:45:59,867 [INFO] [batch_0001_384x384.parquet] Processed 1996 rows
|
| 91 |
+
2026-02-01 10:46:32,133 [INFO] [batch_0031_384x384.parquet] Processed 1988 rows
|
| 92 |
+
2026-02-01 10:47:08,397 [INFO] [batch_0002_384x384.parquet] Processed 1990 rows
|
| 93 |
+
2026-02-01 10:48:03,827 [INFO] [batch_0032_384x384.parquet] Processed 1994 rows
|
| 94 |
+
2026-02-01 10:48:20,770 [INFO] [batch_0003_384x384.parquet] Processed 1997 rows
|
| 95 |
+
2026-02-01 10:50:02,578 [INFO] [batch_0033_384x384.parquet] Processed 1992 rows
|
| 96 |
+
2026-02-01 10:50:06,189 [INFO] [batch_0004_384x384.parquet] Processed 1992 rows
|
| 97 |
+
2026-02-01 10:52:02,296 [INFO] [batch_0034_384x384.parquet] Processed 1994 rows
|
| 98 |
+
2026-02-01 10:53:52,804 [INFO] [batch_0035_384x384.parquet] Processed 1994 rows
|
| 99 |
+
2026-02-01 10:55:40,379 [INFO] [batch_0036_384x384.parquet] Processed 1996 rows
|
| 100 |
+
2026-02-01 10:57:08,912 [INFO] [batch_0037_384x384.parquet] Processed 1995 rows
|
| 101 |
+
2026-02-01 10:58:42,083 [INFO] [batch_0038_384x384.parquet] Processed 1993 rows
|
| 102 |
+
2026-02-01 11:00:31,963 [INFO] [batch_0039_384x384.parquet] Processed 1989 rows
|
| 103 |
+
2026-02-01 11:02:16,803 [INFO] [batch_0040_384x384.parquet] Processed 1990 rows
|
| 104 |
+
2026-02-01 11:04:12,580 [INFO] [batch_0041_384x384.parquet] Processed 1998 rows
|
| 105 |
+
2026-02-01 11:05:52,695 [INFO] [batch_0042_384x384.parquet] Processed 1997 rows
|
| 106 |
+
2026-02-01 11:07:38,215 [INFO] [batch_0043_384x384.parquet] Processed 1988 rows
|
| 107 |
+
2026-02-01 11:09:18,740 [INFO] [batch_0044_384x384.parquet] Processed 1991 rows
|
| 108 |
+
2026-02-01 11:10:59,852 [INFO] [batch_0045_384x384.parquet] Processed 1993 rows
|
| 109 |
+
2026-02-01 11:12:35,695 [INFO] [batch_0046_384x384.parquet] Processed 1994 rows
|
| 110 |
+
2026-02-01 11:14:12,998 [INFO] [batch_0047_384x384.parquet] Processed 1995 rows
|
| 111 |
+
2026-02-01 11:15:30,214 [INFO] [batch_0048_384x384.parquet] Processed 1992 rows
|
| 112 |
+
2026-02-01 11:17:05,225 [INFO] [batch_0049_384x384.parquet] Processed 1996 rows
|
| 113 |
+
2026-02-01 11:18:50,252 [INFO] [batch_0050_384x384.parquet] Processed 1991 rows
|
| 114 |
+
2026-02-01 11:20:25,931 [INFO] [batch_0051_384x384.parquet] Processed 1997 rows
|
| 115 |
+
2026-02-01 11:21:43,527 [INFO] [batch_0052_384x384.parquet] Processed 1993 rows
|
| 116 |
+
2026-02-01 11:23:12,150 [INFO] [batch_0053_384x384.parquet] Processed 1995 rows
|
| 117 |
+
2026-02-01 11:24:47,385 [INFO] [batch_0054_384x384.parquet] Processed 1997 rows
|
| 118 |
+
2026-02-01 11:26:31,520 [INFO] [batch_0055_384x384.parquet] Processed 1995 rows
|
| 119 |
+
2026-02-01 11:28:03,476 [INFO] [batch_0056_384x384.parquet] Processed 1997 rows
|
| 120 |
+
2026-02-01 11:29:48,548 [INFO] [batch_0057_384x384.parquet] Processed 1991 rows
|
| 121 |
+
2026-02-01 11:31:29,605 [INFO] [batch_0058_384x384.parquet] Processed 1994 rows
|
| 122 |
+
2026-02-01 11:33:17,760 [INFO] [batch_0059_384x384.parquet] Processed 1993 rows
|
| 123 |
+
2026-02-01 11:34:50,684 [INFO] [batch_0060_384x384.parquet] Processed 1995 rows
|
| 124 |
+
2026-02-01 11:36:38,080 [INFO] [batch_0061_384x384.parquet] Processed 1995 rows
|
| 125 |
+
2026-02-01 11:38:19,287 [INFO] [batch_0062_384x384.parquet] Processed 1998 rows
|
| 126 |
+
2026-02-01 11:40:01,382 [INFO] [batch_0063_384x384.parquet] Processed 1997 rows
|
| 127 |
+
2026-02-01 11:41:28,396 [INFO] [batch_0064_384x384.parquet] Processed 1992 rows
|
| 128 |
+
2026-02-01 11:43:07,187 [INFO] [batch_0065_384x384.parquet] Processed 1994 rows
|
| 129 |
+
2026-02-01 11:44:47,035 [INFO] [batch_0066_384x384.parquet] Processed 1992 rows
|
| 130 |
+
2026-02-01 11:46:38,657 [INFO] [batch_0067_384x384.parquet] Processed 1993 rows
|
| 131 |
+
2026-02-01 11:48:25,045 [INFO] [batch_0068_384x384.parquet] Processed 1994 rows
|
| 132 |
+
2026-02-01 11:50:24,090 [INFO] [batch_0069_384x384.parquet] Processed 1992 rows
|
| 133 |
+
2026-02-01 11:52:05,360 [INFO] [batch_0070_384x384.parquet] Processed 1997 rows
|
| 134 |
+
2026-02-01 11:53:51,383 [INFO] [batch_0071_384x384.parquet] Processed 1996 rows
|
| 135 |
+
2026-02-01 11:55:00,188 [INFO] [batch_0072_384x384.parquet] Processed 1992 rows
|
| 136 |
+
2026-02-01 11:56:16,122 [INFO] [batch_0073_384x384.parquet] Processed 1995 rows
|
| 137 |
+
2026-02-01 11:57:30,601 [INFO] [batch_0074_384x384.parquet] Processed 1992 rows
|
| 138 |
+
2026-02-01 11:58:47,717 [INFO] [batch_0075_384x384.parquet] Processed 1991 rows
|
| 139 |
+
2026-02-01 12:00:01,207 [INFO] [batch_0076_384x384.parquet] Processed 1998 rows
|
| 140 |
+
2026-02-01 12:01:14,471 [INFO] [batch_0077_384x384.parquet] Processed 1996 rows
|
| 141 |
+
2026-02-01 12:02:31,575 [INFO] [batch_0078_384x384.parquet] Processed 1992 rows
|
| 142 |
+
2026-02-01 12:03:52,303 [INFO] [batch_0079_384x384.parquet] Processed 1995 rows
|
| 143 |
+
2026-02-01 12:05:06,370 [INFO] [batch_0080_384x384.parquet] Processed 1993 rows
|
| 144 |
+
2026-02-01 12:06:16,989 [INFO] [batch_0081_384x384.parquet] Processed 1995 rows
|
| 145 |
+
2026-02-01 12:07:32,029 [INFO] [batch_0082_384x384.parquet] Processed 1989 rows
|
| 146 |
+
2026-02-01 12:08:47,568 [INFO] [batch_0083_384x384.parquet] Processed 1995 rows
|
| 147 |
+
2026-02-01 12:10:03,544 [INFO] [batch_0084_384x384.parquet] Processed 1996 rows
|
| 148 |
+
2026-02-01 12:11:20,376 [INFO] [batch_0085_384x384.parquet] Processed 1997 rows
|
| 149 |
+
2026-02-01 12:12:38,318 [INFO] [batch_0086_384x384.parquet] Processed 1996 rows
|
| 150 |
+
2026-02-01 12:13:56,314 [INFO] [batch_0087_384x384.parquet] Processed 1994 rows
|
| 151 |
+
2026-02-01 12:15:14,513 [INFO] [batch_0088_384x384.parquet] Processed 1992 rows
|
| 152 |
+
2026-02-01 12:16:32,334 [INFO] [batch_0089_384x384.parquet] Processed 1993 rows
|
| 153 |
+
2026-02-01 12:17:52,186 [INFO] [batch_0090_384x384.parquet] Processed 1993 rows
|
| 154 |
+
2026-02-01 12:19:10,443 [INFO] [batch_0091_384x384.parquet] Processed 1995 rows
|
| 155 |
+
2026-02-01 12:20:24,543 [INFO] [batch_0092_384x384.parquet] Processed 1994 rows
|
| 156 |
+
2026-02-01 12:21:42,150 [INFO] [batch_0093_384x384.parquet] Processed 1998 rows
|
| 157 |
+
2026-02-01 12:22:50,203 [INFO] [batch_0094_384x384.parquet] Processed 1993 rows
|
| 158 |
+
2026-02-01 12:24:08,849 [INFO] [batch_0095_384x384.parquet] Processed 1995 rows
|
| 159 |
+
2026-02-01 12:25:14,387 [INFO] [batch_0096_384x384.parquet] Processed 1997 rows
|
| 160 |
+
2026-02-01 12:26:27,496 [INFO] [batch_0097_384x384.parquet] Processed 1990 rows
|
| 161 |
+
2026-02-01 12:27:38,051 [INFO] [batch_0098_384x384.parquet] Processed 1995 rows
|
| 162 |
+
2026-02-01 12:28:46,151 [INFO] [batch_0099_384x384.parquet] Processed 1992 rows
|
| 163 |
+
2026-02-01 12:29:56,731 [INFO] [batch_0100_384x384.parquet] Processed 1993 rows
|
| 164 |
+
2026-02-01 12:31:13,328 [INFO] [batch_0101_384x384.parquet] Processed 1994 rows
|
| 165 |
+
2026-02-01 12:32:22,428 [INFO] [batch_0102_384x384.parquet] Processed 1991 rows
|
| 166 |
+
2026-02-01 12:33:34,185 [INFO] [batch_0103_384x384.parquet] Processed 1990 rows
|
| 167 |
+
2026-02-01 12:34:42,817 [INFO] [batch_0104_384x384.parquet] Processed 1995 rows
|
| 168 |
+
2026-02-01 12:35:53,075 [INFO] [batch_0105_384x384.parquet] Processed 1993 rows
|
| 169 |
+
2026-02-01 12:37:03,504 [INFO] [batch_0106_384x384.parquet] Processed 1988 rows
|
| 170 |
+
2026-02-01 12:38:12,118 [INFO] [batch_0107_384x384.parquet] Processed 1996 rows
|
| 171 |
+
2026-02-01 12:39:26,579 [INFO] [batch_0108_384x384.parquet] Processed 1992 rows
|
| 172 |
+
2026-02-01 12:40:38,968 [INFO] [batch_0109_384x384.parquet] Processed 1993 rows
|
| 173 |
+
2026-02-01 12:41:50,225 [INFO] [batch_0110_384x384.parquet] Processed 1996 rows
|
| 174 |
+
2026-02-01 12:42:59,782 [INFO] [batch_0111_384x384.parquet] Processed 1991 rows
|
| 175 |
+
2026-02-01 12:44:13,255 [INFO] [batch_0112_384x384.parquet] Processed 1994 rows
|
| 176 |
+
2026-02-01 12:45:25,863 [INFO] [batch_0113_384x384.parquet] Processed 1996 rows
|
| 177 |
+
2026-02-01 12:46:42,753 [INFO] [batch_0114_384x384.parquet] Processed 1997 rows
|
| 178 |
+
2026-02-01 12:47:54,480 [INFO] [batch_0115_384x384.parquet] Processed 1997 rows
|
| 179 |
+
2026-02-01 12:49:00,711 [INFO] [batch_0116_384x384.parquet] Processed 1992 rows
|
| 180 |
+
2026-02-01 12:50:12,844 [INFO] [batch_0117_384x384.parquet] Processed 1992 rows
|
| 181 |
+
2026-02-01 12:51:27,205 [INFO] [batch_0118_384x384.parquet] Processed 1993 rows
|
| 182 |
+
2026-02-01 12:52:36,479 [INFO] [batch_0119_384x384.parquet] Processed 1998 rows
|
| 183 |
+
2026-02-01 12:53:54,416 [INFO] [batch_0120_384x384.parquet] Processed 1995 rows
|
| 184 |
+
2026-02-01 12:55:03,501 [INFO] [batch_0121_384x384.parquet] Processed 1995 rows
|
| 185 |
+
2026-02-01 12:56:14,997 [INFO] [batch_0122_384x384.parquet] Processed 1992 rows
|
| 186 |
+
2026-02-01 12:57:29,495 [INFO] [batch_0123_384x384.parquet] Processed 1994 rows
|
| 187 |
+
2026-02-01 12:58:37,341 [INFO] [batch_0124_384x384.parquet] Processed 1995 rows
|
| 188 |
+
2026-02-01 12:59:35,927 [INFO] [batch_0125_384x384.parquet] Processed 1505 rows
|
| 189 |
+
2026-02-01 12:59:35,927 [INFO] Merging all results...
|
| 190 |
+
2026-02-01 12:59:35,965 [INFO] Final columns: ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'centre_lat', 'centre_lon', 'utm_crs', 'parquet_url', 'parquet_row', 'pixel_bbox', 'embedding']
|
| 191 |
+
2026-02-01 12:59:35,965 [INFO] ✓ All unwanted columns removed
|
| 192 |
+
2026-02-01 12:59:35,965 [INFO] ✓ Column 'crs' renamed to 'utm_crs'
|
| 193 |
+
2026-02-01 12:59:35,965 [INFO] ✓ Column 'pixel_bbox' added
|
| 194 |
+
2026-02-01 12:59:35,965 [INFO] Saving to /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet...
|
| 195 |
+
2026-02-01 12:59:44,647 [INFO] ================================================================================
|
| 196 |
+
2026-02-01 12:59:44,648 [INFO] Processing complete!
|
| 197 |
+
2026-02-01 12:59:44,648 [INFO] Total rows: 248,719
|
| 198 |
+
2026-02-01 12:59:44,648 [INFO] Embedding dimension: 1152
|
| 199 |
+
2026-02-01 12:59:44,648 [INFO] Output file: /data1/zyj/EarthEmbeddings/Core-S2L2A-249k/siglip/SigLIP_crop_384x384.parquet
|
| 200 |
+
2026-02-01 12:59:44,648 [INFO] ================================================================================
|
models/FarSLIP/.gitignore
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**/logs/
|
| 2 |
+
**/wandb/
|
| 3 |
+
models/
|
| 4 |
+
features/
|
| 5 |
+
results/
|
| 6 |
+
src/open_clip_train/config.py
|
| 7 |
+
src/open_clip_train/output_samples/
|
| 8 |
+
**/results_retrieval/
|
| 9 |
+
**/results_classification/
|
| 10 |
+
checkpoints/
|
| 11 |
+
|
| 12 |
+
tests/data/
|
| 13 |
+
*.pt
|
| 14 |
+
|
| 15 |
+
# Byte-compiled / optimized / DLL files
|
| 16 |
+
__pycache__/
|
| 17 |
+
*.py[cod]
|
| 18 |
+
*$py.class
|
| 19 |
+
|
| 20 |
+
# C extensions
|
| 21 |
+
*.so
|
| 22 |
+
|
| 23 |
+
# Distribution / packaging
|
| 24 |
+
.Python
|
| 25 |
+
build/
|
| 26 |
+
develop-eggs/
|
| 27 |
+
dist/
|
| 28 |
+
downloads/
|
| 29 |
+
eggs/
|
| 30 |
+
.eggs/
|
| 31 |
+
lib/
|
| 32 |
+
lib64/
|
| 33 |
+
parts/
|
| 34 |
+
sdist/
|
| 35 |
+
var/
|
| 36 |
+
wheels/
|
| 37 |
+
pip-wheel-metadata/
|
| 38 |
+
share/python-wheels/
|
| 39 |
+
*.egg-info/
|
| 40 |
+
.installed.cfg
|
| 41 |
+
*.egg
|
| 42 |
+
MANIFEST
|
| 43 |
+
|
| 44 |
+
# PyInstaller
|
| 45 |
+
# Usually these files are written by a python script from a template
|
| 46 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 47 |
+
*.manifest
|
| 48 |
+
*.spec
|
| 49 |
+
|
| 50 |
+
# Installer logs
|
| 51 |
+
pip-log.txt
|
| 52 |
+
pip-delete-this-directory.txt
|
| 53 |
+
|
| 54 |
+
# Unit test / coverage reports
|
| 55 |
+
htmlcov/
|
| 56 |
+
.tox/
|
| 57 |
+
.nox/
|
| 58 |
+
.coverage
|
| 59 |
+
.coverage.*
|
| 60 |
+
.cache
|
| 61 |
+
nosetests.xml
|
| 62 |
+
coverage.xml
|
| 63 |
+
*.cover
|
| 64 |
+
*.py,cover
|
| 65 |
+
.hypothesis/
|
| 66 |
+
.pytest_cache/
|
| 67 |
+
|
| 68 |
+
# Translations
|
| 69 |
+
*.mo
|
| 70 |
+
*.pot
|
| 71 |
+
|
| 72 |
+
# Django stuff:
|
| 73 |
+
*.log
|
| 74 |
+
local_settings.py
|
| 75 |
+
db.sqlite3
|
| 76 |
+
db.sqlite3-journal
|
| 77 |
+
|
| 78 |
+
# Flask stuff:
|
| 79 |
+
instance/
|
| 80 |
+
.webassets-cache
|
| 81 |
+
|
| 82 |
+
# Scrapy stuff:
|
| 83 |
+
.scrapy
|
| 84 |
+
|
| 85 |
+
# Sphinx documentation
|
| 86 |
+
docs/_build/
|
| 87 |
+
|
| 88 |
+
# PyBuilder
|
| 89 |
+
target/
|
| 90 |
+
|
| 91 |
+
# Jupyter Notebook
|
| 92 |
+
.ipynb_checkpoints
|
| 93 |
+
|
| 94 |
+
# IPython
|
| 95 |
+
profile_default/
|
| 96 |
+
ipython_config.py
|
| 97 |
+
|
| 98 |
+
# pyenv
|
| 99 |
+
.python-version
|
| 100 |
+
|
| 101 |
+
# pipenv
|
| 102 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 103 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 104 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 105 |
+
# install all needed dependencies.
|
| 106 |
+
#Pipfile.lock
|
| 107 |
+
|
| 108 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 109 |
+
__pypackages__/
|
| 110 |
+
|
| 111 |
+
# Celery stuff
|
| 112 |
+
celerybeat-schedule
|
| 113 |
+
celerybeat.pid
|
| 114 |
+
|
| 115 |
+
# SageMath parsed files
|
| 116 |
+
*.sage.py
|
| 117 |
+
|
| 118 |
+
# Environments
|
| 119 |
+
.env
|
| 120 |
+
.venv
|
| 121 |
+
env/
|
| 122 |
+
venv/
|
| 123 |
+
ENV/
|
| 124 |
+
env.bak/
|
| 125 |
+
venv.bak/
|
| 126 |
+
|
| 127 |
+
# Spyder project settings
|
| 128 |
+
.spyderproject
|
| 129 |
+
.spyproject
|
| 130 |
+
|
| 131 |
+
# Rope project settings
|
| 132 |
+
.ropeproject
|
| 133 |
+
|
| 134 |
+
# mkdocs documentation
|
| 135 |
+
/site
|
| 136 |
+
|
| 137 |
+
# mypy
|
| 138 |
+
.mypy_cache/
|
| 139 |
+
.dmypy.json
|
| 140 |
+
dmypy.json
|
| 141 |
+
|
| 142 |
+
# Pyre type checker
|
| 143 |
+
.pyre/
|
| 144 |
+
sync.sh
|
| 145 |
+
gpu1sync.sh
|
| 146 |
+
.idea
|
| 147 |
+
*.pdf
|
| 148 |
+
**/._*
|
| 149 |
+
**/*DS_*
|
| 150 |
+
**.jsonl
|
| 151 |
+
src/sbatch
|
| 152 |
+
src/misc
|
| 153 |
+
.vscode
|
| 154 |
+
src/debug
|
| 155 |
+
core.*
|
| 156 |
+
|
| 157 |
+
*.out
|
| 158 |
+
|
| 159 |
+
# Allow
|
| 160 |
+
!src/evaluation/misc/results_dbs/*
|
models/FarSLIP/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 LHRS
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
models/FarSLIP/README.md
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<h1 align="center"> FarSLIP: Discovering Effective CLIP Adaptation for Fine-Grained Remote Sensing Understanding </h1>
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<a href="https://huggingface.co/datasets/ZhenShiL/MGRS-200k">
|
| 5 |
+
<img alt="Hugging Face Dataset" src="https://img.shields.io/badge/🤗%20Hugging%20Face-Dataset-blue">
|
| 6 |
+
</a>
|
| 7 |
+
<a href="https://huggingface.co/ZhenShiL/FarSLIP">
|
| 8 |
+
<img alt="Hugging Face Model" src="https://img.shields.io/badge/🤗%20Hugging%20Face-Model-yellow">
|
| 9 |
+
</a>
|
| 10 |
+
<a href="https://arxiv.org/abs/2511.14901">
|
| 11 |
+
<img alt="arXiv" src="https://img.shields.io/badge/arXiv-2511.14901-b31b1b">
|
| 12 |
+
</a>
|
| 13 |
+
</p>
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
## Introduction
|
| 17 |
+
We introduce FarSLIP, a vision-language foundation model for remote sensing (RS) that achieves fine-grained vision-language alignment. FarSLIP demonstrates state-of-the-art performance on both fine-grained and image-level tasks, including open-vocabulary semantic segmentation, zero-shot classification, and image-text retrieval.
|
| 18 |
+
We also construct MGRS-200k, the first multi-granularity image-text dataset for RS. Each image is annotated with both short and long global-level captions, along with multiple object-category pairs.
|
| 19 |
+
|
| 20 |
+
<figure>
|
| 21 |
+
<div align="center">
|
| 22 |
+
<img src=assets/model.png width="60%">
|
| 23 |
+
</div>
|
| 24 |
+
</figure>
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
## Table of Contents
|
| 28 |
+
- [Introduction](#Introduction)
|
| 29 |
+
- [Preparation](#Preparation)
|
| 30 |
+
- [Installation](#Installation)
|
| 31 |
+
- [Checkpoints](#Checkpoints)
|
| 32 |
+
- [Dataset](#Dataset)
|
| 33 |
+
- [Training](#Training)
|
| 34 |
+
- [Testing](#Testing)
|
| 35 |
+
- [Open-vocabulary semantic segmentation](#open-vocabulary-semantic-segmentation)
|
| 36 |
+
- [Zero-shot scene classification](#zero-shot-scene-classification)
|
| 37 |
+
- [Zero-shot image-text retrieval](#zero-shot-image-text-retrieval)
|
| 38 |
+
- [Acknowledgement](#Acknowledgement)
|
| 39 |
+
- [Citing](#Citing)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
## Preparation
|
| 46 |
+
|
| 47 |
+
### Installation
|
| 48 |
+
|
| 49 |
+
1. Clone this repository.
|
| 50 |
+
|
| 51 |
+
~~~shell
|
| 52 |
+
git clone git@github.com:NJU-LHRS/FarSLIP.git
|
| 53 |
+
cd FarSLIP
|
| 54 |
+
~~~
|
| 55 |
+
|
| 56 |
+
2. Create a new virtual environment.
|
| 57 |
+
|
| 58 |
+
~~~shell
|
| 59 |
+
conda create -n farslip python=3.10
|
| 60 |
+
conda activate farslip
|
| 61 |
+
~~~
|
| 62 |
+
|
| 63 |
+
3. Install dependencies.
|
| 64 |
+
|
| 65 |
+
~~~shell
|
| 66 |
+
pip install -r requirements.txt
|
| 67 |
+
~~~
|
| 68 |
+
|
| 69 |
+
### Checkpoints
|
| 70 |
+
You can download all our checkpoints from [Huggingface](https://huggingface.co/ZhenShiL/FarSLIP), or selectively download them through the links below.
|
| 71 |
+
|
| 72 |
+
| Model name | ViT-arch. | Test encoder | OVSS mIoU (%) | ZSC top-1 acc. (%) | Download |
|
| 73 |
+
|-------------|-----------|--------------|----------------|--------------------|----------------|
|
| 74 |
+
| FarSLIP-s1 | ViT-B-32 | Vanilla | 29.87 | 58.64 | [FarSLIP1_ViT-B-32](https://huggingface.co/ZhenShiL/FarSLIP/resolve/main/FarSLIP1_ViT-B-32.pt?download=true) |
|
| 75 |
+
| FarSLIP-s1 | ViT-B-16 | LongCLIP | 35.44 | 61.89 | [FarSLIP1_ViT-B-16](https://huggingface.co/ZhenShiL/FarSLIP/resolve/main/FarSLIP1_ViT-B-16.pt?download=true) |
|
| 76 |
+
| FarSLIP-s2 | ViT-B-32 | Vanilla | 30.49 | 60.12 | [FarSLIP2_ViT-B-32](https://huggingface.co/ZhenShiL/FarSLIP/resolve/main/FarSLIP2_ViT-B-32.pt?download=true) |
|
| 77 |
+
| FarSLIP-s2 | ViT-B-16 | LongCLIP | 35.41 | 62.24 | [FarSLIP2_ViT-B-16](https://huggingface.co/ZhenShiL/FarSLIP/resolve/main/FarSLIP2_ViT-B-16.pt?download=true) |
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
### Dataset
|
| 81 |
+
FarSLIP is trained in two stages.
|
| 82 |
+
+ In the first stage, we use the [RS5M](https://github.com/om-ai-lab/RS5M) dataset. A quick portal to the RS5M dataset: [link](https://huggingface.co/datasets/omlab/RS5M).
|
| 83 |
+
+ In the second stage, we use the proposed MGRS-200k dataset, which is available on [Huggingface](https://huggingface.co/datasets/ZhenShiL/MGRS-200k).
|
| 84 |
+
|
| 85 |
+
[//]: # (<figure>)
|
| 86 |
+
|
| 87 |
+
[//]: # (<div align="center">)
|
| 88 |
+
|
| 89 |
+
[//]: # (<img src=assets/dataset.png width="80%">)
|
| 90 |
+
|
| 91 |
+
[//]: # (</div>)
|
| 92 |
+
|
| 93 |
+
[//]: # (<figcaption align="center"><em>Examples from MGRS-200k</em></figcaption>)
|
| 94 |
+
|
| 95 |
+
[//]: # (</figure>)
|
| 96 |
+
|
| 97 |
+
<p align="center">
|
| 98 |
+
<img src="assets/dataset.png" width="100%">
|
| 99 |
+
<br>
|
| 100 |
+
<em>Examples from MGRS-200k</em>
|
| 101 |
+
</p>
|
| 102 |
+
|
| 103 |
+
## Training
|
| 104 |
+
|
| 105 |
+
+ Validation data preparation
|
| 106 |
+
+ Replace --root-val-img-dir and --val-data in [config.py](./open_clip_train/config.py) with the paths to your [SkyScript](https://github.com/wangzhecheng/SkyScript?tab=readme-ov-file#download) validation dataset ('SkyScript_val_5K_filtered_by_CLIP_openai').
|
| 107 |
+
+ Stage1
|
| 108 |
+
~~~shell
|
| 109 |
+
torchrun --nproc_per_node=4 -m open_clip_train.main \
|
| 110 |
+
--train-dataset-name RS5M \
|
| 111 |
+
--train-data '/your/path/to/rs5m/{pub11,rs3}-train-{0000..0031}.tar' \
|
| 112 |
+
--train-dataset-type webdataset \
|
| 113 |
+
--train-num-samples 5070186 \
|
| 114 |
+
--method farslip1 \
|
| 115 |
+
--use-imagecrop-aug \
|
| 116 |
+
--local-method randomcrops \
|
| 117 |
+
--warmup 1000 \
|
| 118 |
+
--batch-size 40 \
|
| 119 |
+
--lr 1e-6 \
|
| 120 |
+
--wd 1.0 \
|
| 121 |
+
--epochs 1 \
|
| 122 |
+
--model ViT-B-16 \
|
| 123 |
+
--loss-type global_itc distill \
|
| 124 |
+
--distill-align roi2pooled
|
| 125 |
+
~~~
|
| 126 |
+
|
| 127 |
+
+ Stage2
|
| 128 |
+
~~~shell
|
| 129 |
+
torchrun --nproc_per_node=4 -m open_clip_train.main \
|
| 130 |
+
--train-dataset-name MGRS \
|
| 131 |
+
--root-train-img-dir '/your/path/to/mgrs/global_imgs/' \
|
| 132 |
+
--train-data '/your/path/to/mgrs/text_info.json' \
|
| 133 |
+
--train-dataset-type json \
|
| 134 |
+
--method farslip2 \
|
| 135 |
+
--warmup 250 \
|
| 136 |
+
--batch-size 40 \
|
| 137 |
+
--lr 4e-9 \
|
| 138 |
+
--wd 1.0 \
|
| 139 |
+
--epochs 10 \
|
| 140 |
+
--model ViT-B-16 \
|
| 141 |
+
--loss-type global_itc local_itc \
|
| 142 |
+
--local-itc-align cls
|
| 143 |
+
~~~
|
| 144 |
+
|
| 145 |
+
## Testing
|
| 146 |
+
### Open-vocabulary semantic segmentation
|
| 147 |
+
+ Please check out [FarSLIP-OVSS](https://github.com/NJU-LHRS/FarSLIP-OVSS) for evaluation of open-vocabulary semantic segmentation in RS images.
|
| 148 |
+
|
| 149 |
+
<p align="center">
|
| 150 |
+
<img src="assets/ovss.png" width="100%">
|
| 151 |
+
<br>
|
| 152 |
+
<em>
|
| 153 |
+
OVSS accuracies across RS benchmarks (mIoU, %). G denotes general-domain models, and RS refers to RS-specific models.
|
| 154 |
+
f. indicates models specifically designed with fine-grained optimization. All models use an input image size of 224, except TIPS (448)
|
| 155 |
+
</em>
|
| 156 |
+
</p>
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
### Zero-shot scene classification
|
| 161 |
+
+ Please refer to [SkyScript](https://github.com/wangzhecheng/SkyScript?tab=readme-ov-file#download-benchmark-datasets) for scene classification dataset preparation, including 'SkyScript_cls', 'aid', 'eurosat', 'fmow', 'millionaid', 'patternnet', 'rsicb', 'nwpu'.
|
| 162 |
+
+ Replace the BENCHMARK_DATASET_ROOT_DIR in [tests/test_scene_classification.py](./tests/test_scene_classification.py) with your own path.
|
| 163 |
+
|
| 164 |
+
+ Run testing:
|
| 165 |
+
+ FarSLIP-s1
|
| 166 |
+
```
|
| 167 |
+
python -m tests.test_scene_classification --model-arch $VIT --model-name FarSLIP1 --force-quick-gelu --pretrained checkpoints/FarSLIP1_$VIT.pt
|
| 168 |
+
```
|
| 169 |
+
<!-- + FarSLIP-s2 with vanilla CLIP text encoder
|
| 170 |
+
```
|
| 171 |
+
python -m tests.test_scene_classification --model-arch $VIT --model-name FarSLIP2_VC --force-quick-gelu --pretrained checkpoints/FarSLIP2_VC_$VIT.pt
|
| 172 |
+
``` -->
|
| 173 |
+
+ FarSLIP-s2 with LongCLIP text encoder (supporting long text)
|
| 174 |
+
```
|
| 175 |
+
python -m tests.test_scene_classification --model-arch $VIT --model-name FarSLIP2 --force-quick-gelu --pretrained checkpoints/FarSLIP2_$VIT.pt --use-long-clip
|
| 176 |
+
```
|
| 177 |
+
- `$VIT` options: `ViT-B-16`, `ViT-B-32`
|
| 178 |
+
|
| 179 |
+
<figure>
|
| 180 |
+
<div align="center">
|
| 181 |
+
<img src=assets/classification.png width="100%">
|
| 182 |
+
</div>
|
| 183 |
+
<figcaption align="center">
|
| 184 |
+
<em>Comparison of zero-shot classification accuracies (Top-1 acc., %) of different RS-specific CLIP variants across multiple benchmarks.</em>
|
| 185 |
+
</figcaption>
|
| 186 |
+
</figure>
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
### Zero-shot image-text retrieval
|
| 190 |
+
+ Please refer to [SkyScript](https://github.com/wangzhecheng/SkyScript?tab=readme-ov-file#download-benchmark-datasets) for image-text retrieval dataset preparation, including 'RSICD', 'RSITMD', 'ucmcaptions', and ['SkyScript-retrieval'](https://github.com/wangzhecheng/SkyScript?tab=readme-ov-file#download) ('SkyScript_test_30K_filtered_by_CLIP_openai.csv').
|
| 191 |
+
+ Replace the DATA_CSV_PATH_DICT, SKYSCRIPT_IMAGE_DIR, RETRIEVAL_IMAGE_DIR in [tests/test_retrieval.py](./tests/test_retrieval.py) with your own paths.
|
| 192 |
+
|
| 193 |
+
+ Run testing:
|
| 194 |
+
+ FarSLIP-s1
|
| 195 |
+
```
|
| 196 |
+
python -m tests.test_retrieval --model-arch $VIT --model-name FarSLIP1 --force-quick-gelu --pretrained checkpoints/FarSLIP1_$VIT.pt
|
| 197 |
+
```
|
| 198 |
+
<!-- + FarSLIP-s2 with vanilla CLIP text encoder
|
| 199 |
+
```
|
| 200 |
+
python -m tests.test_retrieval --model-arch $VIT --model-name FarSLIP2_VC --force-quick-gelu --pretrained checkpoints/FarSLIP2_VC_$VIT.pt
|
| 201 |
+
``` -->
|
| 202 |
+
+ FarSLIP-s2 with LongCLIP text encoder (supporting long text)
|
| 203 |
+
```
|
| 204 |
+
python -m tests.test_retrieval --model-arch $VIT --model-name FarSLIP2 --force-quick-gelu --pretrained checkpoints/FarSLIP2_$VIT.pt --use-long-clip
|
| 205 |
+
```
|
| 206 |
+
- `$VIT` options: `ViT-B-16`, `ViT-B-32`
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
<div align="center">
|
| 210 |
+
<img src=assets/retrieval.png width="50%">
|
| 211 |
+
</div>
|
| 212 |
+
<figcaption align="center">
|
| 213 |
+
<em>Comparison of cross-modal retrieval accuracies (%) of different RS-specific CLIP variants across multiple benchmarks. *
|
| 214 |
+
indicates models trained with in-hold supervision.</em>
|
| 215 |
+
</figcaption>
|
| 216 |
+
</figure>
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
## Acknowledgement
|
| 222 |
+
|
| 223 |
+
+ We are grateful to the following repositories for their wonderful work: [Open-CLIP](https://github.com/mlfoundations/open_clip), [CLIPSelf](https://github.com/wusize/CLIPSelf), [FineCLIP](https://github.com/Timsty1/FineCLIP), [Long-CLIP](https://github.com/beichenzbc/Long-CLIP), [SkyScript](https://github.com/wangzhecheng/SkyScript), [SegEarth](https://github.com/likyoo/SegEarth-OV).
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
## Citing
|
| 227 |
+
|
| 228 |
+
+ If you find our work useful, please give us a 🌟 on GitHub and consider citing our paper:
|
| 229 |
+
|
| 230 |
+
~~~tex
|
| 231 |
+
@article{li2025farslip,
|
| 232 |
+
title={FarSLIP: Discovering Effective CLIP Adaptation for Fine-Grained Remote Sensing Understanding},
|
| 233 |
+
author={Zhenshi Li and Weikang Yu and Dilxat Muhtar and Xueliang Zhang and Pengfeng Xiao and Pedram Ghamisi and Xiao Xiang Zhu},
|
| 234 |
+
journal={arXiv preprint arXiv:2511.14901},
|
| 235 |
+
year={2025}
|
| 236 |
+
}
|
| 237 |
+
~~~
|
models/FarSLIP/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .open_clip import *
|
models/FarSLIP/open_clip/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .version import __version__
|
| 2 |
+
|
| 3 |
+
from .coca_model import CoCa
|
| 4 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
| 5 |
+
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
|
| 6 |
+
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
| 7 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
| 8 |
+
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
|
| 9 |
+
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
|
| 10 |
+
get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
|
| 11 |
+
from .openai import load_openai_model, list_openai_models
|
| 12 |
+
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
| 13 |
+
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
| 14 |
+
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
| 15 |
+
from .tokenizer import SimpleTokenizer, tokenize, decode
|
| 16 |
+
from .transform import image_transform, AugmentationCfg
|
| 17 |
+
from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
|
| 18 |
+
from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
|
models/FarSLIP/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
| 3 |
+
size 1356917
|
models/FarSLIP/open_clip/coca_model.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
import numpy as np
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
from .transformer import (
|
| 10 |
+
LayerNormFp32,
|
| 11 |
+
LayerNorm,
|
| 12 |
+
QuickGELU,
|
| 13 |
+
MultimodalTransformer,
|
| 14 |
+
)
|
| 15 |
+
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
|
| 16 |
+
|
| 17 |
+
# The `generate()` path relies on HuggingFace `transformers` for logits
# processing, beam search, and stopping criteria. The dependency is optional:
# when it is missing we still define the module-level names so imports of this
# file succeed, and `generate()` asserts on `_has_transformers` at call time.
try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StopStringCriteria,
        EosTokenCriteria,
        StoppingCriteriaList
    )

    # Maps a `generation_type` string to the logits-warper class used for
    # sampling; "beam_search" is a sentinel handled separately in `generate()`.
    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
    _has_transformers = True
except ImportError as e:
    # Same keys as above so error messages in `generate()` can list the valid
    # options, but the sampling entries are unusable (None) without transformers.
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
    _has_transformers = False
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Configuration for CoCa's multimodal text-decoder tower.

    Extends :class:`CLIPTextCfg` with decoder-specific hyperparameters.
    """
    mlp_ratio: int = 4          # MLP hidden size as a multiple of `width`
    dim_head: int = 64          # per-head dimension of attention
    heads: int = 8              # number of attention heads in the decoder
    n_queries: int = 256        # number of learned queries for attentional pooling
    attn_pooler_heads: int = 8  # heads used by the attention pooler
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Construct the CoCa multimodal text decoder from a config.

    Args:
        embed_dim: Output (vocabulary/projection) dimension of the decoder.
        multimodal_cfg: A ``MultimodalCfg`` instance or a dict of its fields.
        quick_gelu: Use the QuickGELU activation instead of ``nn.GELU``.
        cast_dtype: When fp16/bf16, use the fp32-upcasting LayerNorm variant.

    Returns:
        A configured ``MultimodalTransformer`` decoder.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    # LayerNormFp32 keeps normalization numerically stable under half precision.
    if cast_dtype in (torch.float16, torch.bfloat16):
        norm_layer = LayerNormFp32
    else:
        norm_layer = LayerNorm
    act_layer = QuickGELU if quick_gelu else nn.GELU

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
|
| 82 |
+
if not isinstance(token_id, torch.Tensor):
|
| 83 |
+
if isinstance(token_id, int):
|
| 84 |
+
token_id = [token_id]
|
| 85 |
+
token_id = torch.tensor(token_id, device=device)
|
| 86 |
+
return token_id
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class CoCa(nn.Module):
    """Contrastive Captioner: a CLIP-style dual encoder plus a text decoder.

    Combines a vision tower and a text tower (for contrastive features) with a
    multimodal transformer decoder (for captioning logits). `forward` returns
    both contrastive features and decoder logits; `generate` produces captions
    via beam search or top-k/top-p sampling (requires `transformers`).
    """

    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        """Build towers from configs (dicts are promoted to dataclasses).

        Args:
            embed_dim: Shared contrastive embedding dimension.
            multimodal_cfg: Decoder tower config (or dict of its fields).
            text_cfg: Text tower config (or dict).
            vision_cfg: Vision tower config (or dict).
            quick_gelu: Use QuickGELU activation in all towers.
            init_logit_scale: Initial value for the (log) contrastive logit scale.
            init_logit_bias: If set, adds a learnable logit bias (SigLIP-style).
            nonscalar_logit_scale: Store logit scale/bias with shape [1] instead of [].
            cast_dtype: Optional low-precision cast dtype for tower construction.
            pad_id: Token id used for padding during generation.
        """
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # NOTE(review): both branches of this conditional evaluate to the same
        # expression (`text_cfg.vocab_size`); the hf_model_name check is a
        # no-op here — presumably a leftover from handling HF tokenizer vocab
        # sizes differently. Confirm before simplifying.
        vocab_size = (
            text_cfg.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # The decoder projects to vocab_size: its "embed_dim" is the vocabulary.
        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        # Toggle gradient checkpointing on all three towers at once.
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize: bool = True):
        # Returns (pooled latent, per-token embeddings) from the vision tower.
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        # Returns (pooled latent, per-token embeddings) from the text tower.
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize: bool = True):
        """Return the (optionally L2-normalized) pooled image features."""
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize: bool = True):
        """Return the (optionally L2-normalized) pooled text features."""
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize: L2 Normalize final image and text features (if present)
            normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict mapping feature names to tensors / lists of intermediate tensors.
        """
        output = {}
        if intermediates_only:
            # intermediates only disables final feature normalization, and include logits
            normalize = False
            output_logits = False
        if output_logits:
            assert False, 'FIXME, needs implementing'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            text_output = self.text.forward_intermediates(
                text,
                indices=text_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=text_output_fmt,
                output_extra_tokens=text_output_extra_tokens,
            )
            if normalize and "text_features" in text_output:
                text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
            output.update(text_output)

        # FIXME text decoder
        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image,
            text: Optional[torch.Tensor] = None,
            image_latent: Optional[torch.Tensor] = None,
            image_embs: Optional[torch.Tensor] = None,
            output_labels: bool = True,
    ):
        """Joint forward pass: contrastive features plus captioning logits.

        Args:
            image: Input images (ignored for encoding if both `image_latent`
                and `image_embs` are supplied, e.g. cached during generation).
            text: Token ids; when None, only image features are returned.
            image_latent: Precomputed pooled image features.
            image_embs: Precomputed per-token image embeddings.
            output_labels: If True, also return shifted text as teacher-forcing
                labels and drop the final text token from the decoder input.
        """
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        if text is None:
            return {"image_features": image_latent, "image_embs": image_embs}

        text_latent, token_embs = self._encode_text(text)

        # FIXME this isn't an ideal solution, would like to improve -RW
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # align text_embs and thus logits with labels for teacher-forcing caption loss
            token_embs = token_embs[:, :-1]

        logits = self.text_decoder(image_embs, token_embs)
        out_dict = {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias
        return out_dict

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,  # keep tokens in the 1 - top_p quantile
        top_k=1,  # keeps the top_k most probable tokens
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False  # if True output.shape == (batch_size, seq_len)
    ):
        """Autoregressively generate caption token ids for `image`.

        Supports "beam_search" (via `_generate_beamsearch`) and "top_k"/"top_p"
        sampling. Default SOT/EOS ids (49406/49407) are the CLIP BPE specials.
        Requires the `transformers` package. Runs under `torch.no_grad()` and
        restores the module's training mode before returning.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = image.device

        with torch.no_grad():
            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                # Right-pad beam-search output to exactly seq_len if requested.
                if fixed_output_length and output.shape[1] < seq_len:
                    pad_len = seq_len - output.shape[1]
                    return torch.cat((
                            output,
                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                        ),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            # Encode the image once; features are reused for every decode step.
            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            while True:
                # Decode conditioned on at most the last max_seq_len tokens.
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    image,
                    x,
                    image_latent=image_latent,
                    image_embs=image_embs,
                    output_labels=False,
                )["logits"][:, -1]
                # Rows already finished (ended with EOS or padding) get padded.
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if (cur_len + 1 == seq_len):
                        # Force EOS on the final step so sequences terminate.
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        """Diverse (grouped) beam search decoding using HF's BeamSearchScorer.

        NOTE(review): `logit_warper` is accepted but never used in this body —
        presumably a leftover; confirm before removing the parameter.
        """
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        # Duplicate each image num_beams times so every beam sees its image.
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        # Every beam starts from the start-of-text token.
        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
        # the same group don't produce same tokens everytime.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                image_latent=image_latent,
                image_embs=image_embs,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # Keep 2*group_size candidates so finished (EOS) hypotheses
                # don't starve the continuing beams.
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Assemble the model-input dict for one generation step.

    Args:
        input_ids: Token ids decoded so far; when `past` is truthy only the
            last token is kept (cached key/values cover the rest).
        image_inputs: Image tensor(s) to condition on.
        past: Cached past key/values from previous steps, if any.
        **kwargs: Optional `attention_mask` and `position_ids`.

    Returns:
        Dict with keys "text", "images", "past_key_values", "position_ids",
        "attention_mask".
    """
    if past:
        # With a KV cache only the newest token needs to be fed to the decoder.
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    # BUGFIX: the original `else: position_ids = None` discarded a
    # caller-supplied position_ids whenever attention_mask was absent;
    # now an explicitly provided position_ids is passed through unchanged.
    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
|
models/FarSLIP/open_clip/constants.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Per-channel RGB normalization statistics used by OpenAI's original CLIP
# preprocessing pipeline.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
# Standard ImageNet per-channel RGB statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Inception-style normalization: maps pixel values into [-1, 1].
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)

# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin"  # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"  # safetensors version
HF_CONFIG_NAME = 'open_clip_config.json'
|
models/FarSLIP/open_clip/convert.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
|
| 2 |
+
"""
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from .model import CLIP, CustomTextCLIP
|
| 9 |
+
from .transformer import TextTransformer, Transformer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models

    Currently, the SigLIP source models are supported and a CustomTextCLIP destination model
    w/ timm image encoder.

    Tensors are copied in-place into ``model``; gradient tracking is disabled by the
    decorator.
    """
    from timm.layers import resample_patch_embed, resample_abs_pos_embed

    def _n2p(w, t=True, idx=None):
        # Convert one numpy array from the checkpoint into a torch tensor.
        # idx (when given) selects a single layer from a stacked per-layer array;
        # t=True transposes checkpoint layouts to torch layouts (4D conv kernels,
        # 3D stacks, 2D dense kernels).
        if idx is not None:
            w = w[idx]
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            # (1, 1, 1, C) degenerate kernels collapse to a flat vector.
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    interpolation = 'bilinear'
    antialias = False

    def _convert_timm_img(module, prefix):
        # Copy the image tower weights into the timm vision trunk under `prefix`.
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            # Patch size differs from the checkpoint: resample the patch embed kernel.
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            # NOTE(review): the assert below makes the resample code unreachable —
            # a pos-embed shape mismatch currently hard-fails instead of resizing.
            assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)

        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            # Checkpoints come in two layouts: a single stacked 'encoderblock'
            # array indexed per layer, or one 'encoderblock_{i}' entry per layer.
            if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
                block_prefix = f'{prefix}Transformer/encoderblock/'
                idx = i
            else:
                block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
                idx = None
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
            # Concatenate separate query/key/value projections into the fused qkv layer.
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
            for r in range(2):
                # MLP fc1/fc2 map to Dense_0/Dense_1 in the checkpoint.
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        if module.attn_pool is not None:
            # Attention-pooling (MAP) head weights.
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        # Copy text encoder block weights into an open_clip Transformer.
        for i, block in enumerate(module.resblocks.children()):
            # Same stacked-vs-per-layer layout detection as the image tower above.
            if f'{prefix}encoderblock/LayerNorm_0/scale' in w:
                block_prefix = f'{prefix}encoderblock/'
                idx = i
            else:
                block_prefix = f'{prefix}encoderblock_{i}/'
                idx = None
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
            # Fuse q/k/v into nn.MultiheadAttention's in_proj parameters.
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'], idx=idx))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'], idx=idx))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'], idx=idx))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'], idx=idx))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'], idx=idx))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'], idx=idx))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        # Copy the full text tower: embeddings, encoder blocks, final norm, projection.
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        if module.text_projection is not None:
            module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
            module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

    # Some checkpoints nest all arrays under a 'params/' root; detect via the 'b' scalar.
    root_prefix = 'params/' if 'params/b' in w else ''
    _convert_timm_img(model.visual.trunk, f'{root_prefix}img/')
    _convert_openclip_txt(model.text, f'{root_prefix}txt/')
    # SigLIP stores the scalar logit bias ('b') and scale ('t') as 1-element arrays.
    model.logit_bias.copy_(_n2p(w[f'{root_prefix}b'])[0])
    model.logit_scale.copy_(_n2p(w[f'{root_prefix}t'])[0])
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
    """Convert an Apple MobileCLIP checkpoint into open_clip key layout.

    Args:
        model: Destination model; its timm visual trunk drives the timm
            checkpoint filter for the image tower.
        state_dict: Source MobileCLIP state_dict.
        fastvit: True to filter the image tower as a timm FastViT,
            False as a timm ViT-hybrid.

    Returns:
        A new state_dict with 'visual.trunk.*', 'text.*' and 'logit_scale' keys.
    """

    def _convert_timm_img(state_dict):
        # Run the source image-tower weights through the matching timm filter,
        # then namespace them under the open_clip visual trunk.
        if fastvit:
            from timm.models.fastvit import checkpoint_filter_fn
        else:
            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
        timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
        timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
        return timm_state_dict

    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
        # Rename MobileCLIP text-encoder keys to open_clip TextTransformer keys.
        # NOTE: replacements are order-sensitive — earlier renames feed later ones.
        text_dict = {}
        for k, v in state_dict.items():
            if not k.startswith(prefix):
                continue
            k = k.replace(prefix, '')
            k = k.replace('projection_layer', 'text_projection')
            k = k.replace('embedding_layer', 'token_embedding')
            if k.startswith('positional_embedding.pos_embed.pos_embed'):
                k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
                # Drop the extra leading dims on the positional embedding tensor.
                v = v.squeeze()
            k = k.replace('final_layer_norm', 'ln_final')
            k = k.replace('pre_norm_mha.0', 'ln_1')
            k = k.replace('pre_norm_mha.1', 'attn')
            k = k.replace('pre_norm_ffn.0', 'ln_2')
            k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
            k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
            k = k.replace('qkv_proj.weight', 'in_proj_weight')
            k = k.replace('qkv_proj.bias', 'in_proj_bias')
            k = k.replace('transformer.', 'transformer.resblocks.')
            text_dict['text.' + k] = v
        return text_dict

    image_dict = _convert_timm_img(state_dict)
    text_dict = _convert_openclip_txt(state_dict)
    out_dict = {**image_dict, **text_dict}
    # logit_scale carries over unchanged.
    out_dict['logit_scale'] = state_dict['logit_scale']
    return out_dict
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
    """Detect known third-party checkpoint formats and convert them to open_clip layout.

    State_dicts that match no known format are returned unchanged.
    """
    # Marker keys identifying Apple MobileCLIP checkpoints by their image tower.
    fastvit_marker = 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight'
    hybrid_marker = 'image_encoder.model.patch_emb.0.block.conv.weight'
    if fastvit_marker in state_dict:
        # MobileCLIP s1 & s2 (FastViT image tower); s0 not currently supported.
        state_dict = convert_mobile_clip_state_dict(model, state_dict)
    if hybrid_marker in state_dict:
        # MobileCLIP b (ViT-hybrid image tower).
        state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
    return state_dict
|
models/FarSLIP/open_clip/factory.py
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import warnings
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from dataclasses import asdict
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from .convert import convert_state_dict
|
| 14 |
+
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
| 15 |
+
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
|
| 16 |
+
from .coca_model import CoCa
|
| 17 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss, MultiPosConLossMM
|
| 18 |
+
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
|
| 19 |
+
list_pretrained_tags_by_model, download_pretrained_from_hf
|
| 20 |
+
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
|
| 21 |
+
from .tokenizer import HFTokenizer, SimpleTokenizer, SigLipTokenizer, DEFAULT_CONTEXT_LENGTH
|
| 22 |
+
|
| 23 |
+
# Prefix marking a model name as a HuggingFace Hub id rather than a builtin config name.
HF_HUB_PREFIX = 'hf-hub:'
# Directories scanned for builtin model architecture config files (*.json).
# Fix: dropped the pointless f-string prefix on a literal with no placeholders.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary of (model_name: config) model architecture configs
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _natural_key(string_):
    """Sort key that orders embedded integer runs numerically (natural sort)."""
    pieces = re.split(r'(\d+)', string_.lower())
    return [int(piece) if piece.isdigit() else piece for piece in pieces]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _rescan_model_configs():
    """Scan _MODEL_CONFIG_PATHS for *.json model configs and rebuild the registry."""
    global _MODEL_CONFIGS

    extensions = ('.json',)
    candidates = []
    for path in _MODEL_CONFIG_PATHS:
        if path.is_file() and path.suffix in extensions:
            candidates.append(path)
        elif path.is_dir():
            for ext in extensions:
                candidates.extend(path.glob(f'*{ext}'))

    for candidate in candidates:
        with open(candidate, 'r') as f:
            cfg = json.load(f)
            # Only register configs that look like complete model architectures.
            if all(key in cfg for key in ('embed_dim', 'vision_cfg', 'text_cfg')):
                _MODEL_CONFIGS[candidate.stem] = cfg

    # Keep the registry in natural-sort order by model name.
    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0])))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Populate the registry at import time so list_models()/get_model_config() work immediately.
_rescan_model_configs()  # initial populate of model config registry
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def list_models():
    """ enumerate available model architectures based on config files """
    return [name for name in _MODEL_CONFIGS]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def add_model_config(path):
    """Register an extra model config file or directory, then refresh the registry."""
    path = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_model_config(model_name):
    """ Fetch model config from builtin (local library) configs.

    Returns a deep copy so callers may mutate it, or None if unknown.
    """
    cfg = _MODEL_CONFIGS.get(model_name)
    if cfg is None:
        return None
    return deepcopy(cfg)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _get_hf_config(
        model_id: str,
        cache_dir: Optional[str] = None,
):
    """Download and parse the open_clip config JSON for a HuggingFace Hub model."""
    config_path = download_pretrained_from_hf(
        model_id,
        filename='open_clip_config.json',
        cache_dir=cache_dir,
    )
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        cache_dir: Optional[str] = None,
        **kwargs,
):
    """Create the tokenizer matching a model config.

    model_name may be a builtin config name or an 'hf-hub:'-prefixed Hub id.
    Extra kwargs are forwarded to the tokenizer constructor, merged over any
    'tokenizer_kwargs' found in the model's text config.
    """
    if model_name.startswith(HF_HUB_PREFIX):
        model_name = model_name[len(HF_HUB_PREFIX):]
        try:
            config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg']
        except Exception:
            # No usable open_clip config on the Hub; fall back to a plain HF tokenizer.
            tokenizer = HFTokenizer(
                model_name,
                context_length=context_length or DEFAULT_CONTEXT_LENGTH,
                cache_dir=cache_dir,
                **kwargs,
            )
            return tokenizer
    else:
        config = get_model_config(model_name)
        assert config is not None, f"No valid model config found for {model_name}."

    text_config = config.get('text_cfg', {})
    if 'tokenizer_kwargs' in text_config:
        # Caller kwargs take precedence over config-provided tokenizer kwargs.
        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
    else:
        tokenizer_kwargs = kwargs

    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

    model_name = model_name.lower()
    if text_config.get('hf_tokenizer_name', ''):
        tokenizer = HFTokenizer(
            text_config['hf_tokenizer_name'],
            context_length=context_length,
            cache_dir=cache_dir,
            **tokenizer_kwargs,
        )
    elif 'siglip' in model_name:
        # SigLIP vocab is picked by model-name heuristics: 'gemma' for SigLIP2,
        # 'mc4' for multilingual (i18n) variants, 'c4-en' otherwise.
        tn = 'gemma' if 'siglip2' in model_name else 'mc4' if 'i18n' in model_name else 'c4-en'
        tokenizer = SigLipTokenizer(
            tn,
            context_length=context_length,
            # **tokenizer_kwargs,
        )
    else:
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )

    return tokenizer
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def load_state_dict(
        checkpoint_path: str,
        device='cpu',
        weights_only=True,
):
    """Load a checkpoint file and return its state_dict.

    Handles .safetensors files, regular torch checkpoints (optionally wrapped in a
    {'state_dict': ...} dict), and TorchScript archives. DistributedDataParallel
    'module.' key prefixes are stripped.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Device to map loaded tensors onto.
        weights_only: Restrict torch.load unpickling to tensors for safety; falls
            back to a full (unsafe) load if the checkpoint requires it.

    Returns:
        Mapping of parameter names to tensors.
    """
    # Check if safetensors or not and load weights accordingly
    if str(checkpoint_path).endswith(".safetensors"):
        from safetensors.torch import load_file
        checkpoint = load_file(checkpoint_path, device=device)
    else:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
        except Exception:
            # Older checkpoints (or older torch without weights_only support)
            # need an unrestricted load.
            checkpoint = torch.load(checkpoint_path, map_location=device)

    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
        state_dict = checkpoint
    elif isinstance(checkpoint, torch.jit.ScriptModule):
        state_dict = checkpoint.state_dict()
        # JIT archives carry extra metadata entries that are not model weights.
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = checkpoint
    # Strip the DDP 'module.' prefix. Fix: guard against an empty state_dict —
    # the original next(iter(...)) raised StopIteration on empty dicts.
    if state_dict and next(iter(state_dict)).startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def load_checkpoint(
        model: Union[CLIP, CustomTextCLIP],
        checkpoint_path: str,
        strict: bool = True,
        weights_only: bool = True,
        device='cpu',
):
    """Load checkpoint weights into ``model`` in-place.

    Converts third-party/legacy state_dict layouts as needed before loading.

    Args:
        model: Destination CLIP/CustomTextCLIP model (mutated in-place).
        checkpoint_path: Path to weights; .npz/.npy are treated as big_vision
            (SigLIP) checkpoints and take a separate load path.
        strict: Passed to model.load_state_dict.
        weights_only: Safe-unpickling flag forwarded to load_state_dict.
        device: Device to map loaded tensors onto.

    Returns:
        The load_state_dict incompatible-keys result, or {} for the big_vision path.
    """
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        # Separate path loading numpy big_vision (SigLIP) weights
        from open_clip.convert import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)

    # Detect & convert 3rd party state_dicts -> open_clip
    state_dict = convert_state_dict(model, state_dict)

    # Detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # correct if logit_scale differs in being scaler vs 1d param
    if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
        state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)

    # correct if logit_bias differs in being scaler vs 1d param
    if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
        state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)

    # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
    if 'logit_bias' not in state_dict and model.logit_bias is not None:
        # Initialize the missing bias with zeros shaped like logit_scale.
        state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

    # Certain text transformers no longer expect position_ids after transformers==4.31
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    # Interpolate position embeddings if image/text sizes differ from the checkpoint.
    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    # Finally, load the massaged state_dict into model
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    if incompatible_keys.missing_keys:
        print("Missing keys:", incompatible_keys.missing_keys)
    if incompatible_keys.unexpected_keys:
        print("Unexpected keys:", incompatible_keys.unexpected_keys)

    logging.info(f"Missing keys: {incompatible_keys.missing_keys}")
    return incompatible_keys
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def create_model(
|
| 231 |
+
model_name: str,
|
| 232 |
+
pretrained: Optional[str] = None,
|
| 233 |
+
precision: str = 'fp32',
|
| 234 |
+
device: Union[str, torch.device] = 'cpu',
|
| 235 |
+
jit: bool = False,
|
| 236 |
+
force_quick_gelu: bool = False,
|
| 237 |
+
force_custom_text: bool = False,
|
| 238 |
+
force_patch_dropout: Optional[float] = None,
|
| 239 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
| 240 |
+
force_preprocess_cfg: Optional[Dict[str, Any]] = None,
|
| 241 |
+
pretrained_image: bool = False,
|
| 242 |
+
pretrained_hf: bool = True,
|
| 243 |
+
cache_dir: Optional[str] = None,
|
| 244 |
+
output_dict: Optional[bool] = None,
|
| 245 |
+
require_pretrained: bool = False,
|
| 246 |
+
load_weights_only: bool = True,
|
| 247 |
+
long_clip: Optional[str] = 'disable',
|
| 248 |
+
**model_kwargs,
|
| 249 |
+
):
|
| 250 |
+
"""Creates and configures a contrastive vision-language model.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
model_name: Name of the model architecture to create. Can be a local model name
|
| 254 |
+
or a Hugging Face model ID prefixed with 'hf-hub:'.
|
| 255 |
+
pretrained: Tag/path for pretrained model weights. Can be:
|
| 256 |
+
- A pretrained tag name (e.g., 'openai')
|
| 257 |
+
- A path to local weights
|
| 258 |
+
- None to initialize with random weights
|
| 259 |
+
precision: Model precision/AMP configuration. Options:
|
| 260 |
+
- 'fp32': 32-bit floating point
|
| 261 |
+
- 'fp16'/'bf16': Mixed precision with FP32 for certain layers
|
| 262 |
+
- 'pure_fp16'/'pure_bf16': Pure 16-bit precision
|
| 263 |
+
device: Device to load the model on ('cpu', 'cuda', or torch.device object)
|
| 264 |
+
jit: If True, JIT compile the model
|
| 265 |
+
force_quick_gelu: Force use of QuickGELU activation
|
| 266 |
+
force_custom_text: Force use of custom text encoder
|
| 267 |
+
force_patch_dropout: Override default patch dropout value
|
| 268 |
+
force_image_size: Override default image size for vision encoder
|
| 269 |
+
force_preprocess_cfg: Override default preprocessing configuration
|
| 270 |
+
pretrained_image: Load pretrained weights for timm vision models
|
| 271 |
+
pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
|
| 272 |
+
cache_dir: Override default cache directory for downloaded model files
|
| 273 |
+
output_dict: If True and model supports it, return dictionary of features
|
| 274 |
+
require_pretrained: Raise error if pretrained weights cannot be loaded
|
| 275 |
+
load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety)
|
| 276 |
+
**model_kwargs: Additional keyword arguments passed to model constructor
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
Created and configured model instance
|
| 280 |
+
|
| 281 |
+
Raises:
|
| 282 |
+
RuntimeError: If model config is not found or required pretrained weights
|
| 283 |
+
cannot be loaded
|
| 284 |
+
|
| 285 |
+
Examples:
|
| 286 |
+
# Create basic CLIP model
|
| 287 |
+
model = create_model('ViT-B/32')
|
| 288 |
+
|
| 289 |
+
# Create CLIP model with mixed precision on GPU
|
| 290 |
+
model = create_model('ViT-B/32', precision='fp16', device='cuda')
|
| 291 |
+
|
| 292 |
+
# Load pretrained OpenAI weights
|
| 293 |
+
model = create_model('ViT-B/32', pretrained='openai')
|
| 294 |
+
|
| 295 |
+
# Load Hugging Face model
|
| 296 |
+
model = create_model('hf-hub:organization/model-name')
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
force_preprocess_cfg = force_preprocess_cfg or {}
|
| 300 |
+
preprocess_cfg = asdict(PreprocessCfg())
|
| 301 |
+
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
|
| 302 |
+
if has_hf_hub_prefix:
|
| 303 |
+
model_id = model_name[len(HF_HUB_PREFIX):]
|
| 304 |
+
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
| 305 |
+
config = _get_hf_config(model_id, cache_dir=cache_dir)
|
| 306 |
+
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
|
| 307 |
+
model_cfg = config['model_cfg']
|
| 308 |
+
pretrained_hf = False # override, no need to load original HF text weights
|
| 309 |
+
else:
|
| 310 |
+
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
| 311 |
+
checkpoint_path = None
|
| 312 |
+
model_cfg = None
|
| 313 |
+
|
| 314 |
+
if isinstance(device, str):
|
| 315 |
+
device = torch.device(device)
|
| 316 |
+
|
| 317 |
+
model_cfg = model_cfg or get_model_config(model_name)
|
| 318 |
+
if model_cfg is not None:
|
| 319 |
+
logging.info(f'Loaded {model_name} model config.')
|
| 320 |
+
else:
|
| 321 |
+
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
| 322 |
+
raise RuntimeError(f'Model config for {model_name} not found.')
|
| 323 |
+
|
| 324 |
+
if force_quick_gelu:
|
| 325 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
| 326 |
+
model_cfg["quick_gelu"] = True
|
| 327 |
+
|
| 328 |
+
if force_patch_dropout is not None:
|
| 329 |
+
# override the default patch dropout value
|
| 330 |
+
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
| 331 |
+
|
| 332 |
+
if force_image_size is not None:
|
| 333 |
+
# override model config's image size
|
| 334 |
+
model_cfg["vision_cfg"]["image_size"] = force_image_size
|
| 335 |
+
|
| 336 |
+
is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
|
| 337 |
+
if pretrained_image:
|
| 338 |
+
if is_timm_model:
|
| 339 |
+
# pretrained weight loading for timm models set via vision_cfg
|
| 340 |
+
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
| 341 |
+
else:
|
| 342 |
+
assert False, 'pretrained image towers currently only supported for timm models'
|
| 343 |
+
|
| 344 |
+
# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
|
| 345 |
+
cast_dtype = get_cast_dtype(precision)
|
| 346 |
+
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
|
| 347 |
+
if is_hf_model:
|
| 348 |
+
# load pretrained weights for HF text model IFF no CLIP weights being loaded
|
| 349 |
+
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
|
| 350 |
+
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
|
| 351 |
+
|
| 352 |
+
model_cfg.update({"long_clip": long_clip})
|
| 353 |
+
model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)
|
| 354 |
+
if custom_text:
|
| 355 |
+
if "multimodal_cfg" in model_cfg:
|
| 356 |
+
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
|
| 357 |
+
else:
|
| 358 |
+
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
|
| 359 |
+
else:
|
| 360 |
+
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
| 361 |
+
|
| 362 |
+
if precision in ("fp16", "bf16"):
|
| 363 |
+
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
| 364 |
+
# manual mixed precision that matches original OpenAI behaviour
|
| 365 |
+
if is_timm_model:
|
| 366 |
+
# FIXME this is a bit janky, create timm based model in low-precision and
|
| 367 |
+
# then cast only LayerNormFp32 instances back to float32 so they don't break.
|
| 368 |
+
# Why? The convert_weights_to_lp fn only works with native models.
|
| 369 |
+
model.to(device=device, dtype=dtype)
|
| 370 |
+
from .transformer import LayerNormFp32
|
| 371 |
+
|
| 372 |
+
def _convert_ln(m):
|
| 373 |
+
if isinstance(m, LayerNormFp32):
|
| 374 |
+
m.weight.data = m.weight.data.to(torch.float32)
|
| 375 |
+
m.bias.data = m.bias.data.to(torch.float32)
|
| 376 |
+
model.apply(_convert_ln)
|
| 377 |
+
else:
|
| 378 |
+
model.to(device=device)
|
| 379 |
+
convert_weights_to_lp(model, dtype=dtype)
|
| 380 |
+
elif precision in ("pure_fp16", "pure_bf16"):
|
| 381 |
+
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
| 382 |
+
model.to(device=device, dtype=dtype)
|
| 383 |
+
else:
|
| 384 |
+
model.to(device=device)
|
| 385 |
+
|
| 386 |
+
pretrained_loaded = False
|
| 387 |
+
if pretrained:
|
| 388 |
+
checkpoint_path = ''
|
| 389 |
+
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
| 390 |
+
if pretrained_cfg:
|
| 391 |
+
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
| 392 |
+
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
|
| 393 |
+
pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
|
| 394 |
+
model_quick_gelu = model_cfg.get('quick_gelu', False)
|
| 395 |
+
if pretrained_quick_gelu and not model_quick_gelu:
|
| 396 |
+
warnings.warn(
|
| 397 |
+
f'These pretrained weights were trained with QuickGELU activation but the model config does '
|
| 398 |
+
f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
|
| 399 |
+
elif not pretrained_quick_gelu and model_quick_gelu:
|
| 400 |
+
warnings.warn(
|
| 401 |
+
f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
|
| 402 |
+
f'model config, consider using a model config without QuickGELU or disable override flags.')
|
| 403 |
+
elif os.path.exists(pretrained):
|
| 404 |
+
checkpoint_path = pretrained
|
| 405 |
+
|
| 406 |
+
if checkpoint_path:
|
| 407 |
+
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
| 408 |
+
load_checkpoint(model, checkpoint_path, weights_only=load_weights_only, strict=False)
|
| 409 |
+
else:
|
| 410 |
+
error_str = (
|
| 411 |
+
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
| 412 |
+
f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
|
| 413 |
+
logging.warning(error_str)
|
| 414 |
+
raise RuntimeError(error_str)
|
| 415 |
+
pretrained_loaded = True
|
| 416 |
+
elif has_hf_hub_prefix:
|
| 417 |
+
logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
|
| 418 |
+
load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
|
| 419 |
+
pretrained_loaded = True
|
| 420 |
+
|
| 421 |
+
if require_pretrained and not pretrained_loaded:
|
| 422 |
+
# callers of create_model_from_pretrained always expect pretrained weights
|
| 423 |
+
raise RuntimeError(
|
| 424 |
+
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
|
| 425 |
+
|
| 426 |
+
if output_dict and hasattr(model, "output_dict"):
|
| 427 |
+
model.output_dict = True
|
| 428 |
+
|
| 429 |
+
if jit:
|
| 430 |
+
model = torch.jit.script(model)
|
| 431 |
+
|
| 432 |
+
# set image preprocessing configuration in model attributes for convenience
|
| 433 |
+
if getattr(model.visual, 'image_size', None) is not None:
|
| 434 |
+
# use image_size set on model creation (via config or force_image_size arg)
|
| 435 |
+
force_preprocess_cfg['size'] = model.visual.image_size
|
| 436 |
+
set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
|
| 437 |
+
|
| 438 |
+
return model
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def create_loss(args):
    """Build the training loss module matching the run configuration.

    Dispatch priority: distillation > CoCa > SigLIP > plain CLIP contrastive.

    Args:
        args: parsed training arguments. Reads ``distill``, ``model``,
            ``siglip``, ``local_loss``, ``gather_with_grad``, ``rank``,
            ``world_size``, ``horovod``, plus the CoCa loss weights and
            ``loss_dist_impl`` for SigLIP.

    Returns:
        An instantiated loss module appropriate for the configured run.
    """
    # Keyword set shared by every contrastive-style loss below.
    contrastive_kwargs = dict(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )

    if args.distill:
        return DistillClipLoss(**contrastive_kwargs)

    if "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            **contrastive_kwargs,
        )

    if args.siglip:
        assert not args.horovod, "Horovod not currently supported for SigLip"
        return SigLipLoss(
            rank=args.rank,
            world_size=args.world_size,
            dist_impl=args.loss_dist_impl,  # siglip has multiple distributed implementations to choose from
        )

    # Default: standard CLIP contrastive loss.
    return ClipLoss(**contrastive_kwargs)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def create_model_and_transforms(
    model_name: str,
    pretrained: Optional[str] = None,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_patch_dropout: Optional[float] = None,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    image_interpolation: Optional[str] = None,
    image_resize_mode: Optional[str] = None,  # only effective for inference
    aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
    pretrained_image: bool = False,
    pretrained_hf: bool = True,
    cache_dir: Optional[str] = None,
    output_dict: Optional[bool] = None,
    load_weights_only: bool = True,
    long_clip: Optional[str] = 'disable',
    use_imagecrop_aug: Optional[bool] = False,
    max_boxes: Optional[int] = 10,
    local_method: str = 'grids',
    **model_kwargs,
):
    """Create a model plus its train/val image preprocessing transforms.

    Thin convenience wrapper around :func:`create_model` that also derives
    the train- and eval-time image transforms from the model's resolved
    preprocessing configuration.

    Returns:
        Tuple of ``(model, preprocess_train, preprocess_val)``.
    """
    # Collect any caller-forced preprocessing overrides into one dict.
    force_preprocess_cfg = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        load_weights_only=load_weights_only,
        long_clip=long_clip,
        **model_kwargs,
    )

    # The model carries the final (merged) preprocess config after creation.
    pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)

    preprocess_train = image_transform_v2(
        pp_cfg,
        is_train=True,
        use_imagecrop_aug=use_imagecrop_aug,
        max_boxes=max_boxes,
        local_method=local_method,
        aug_cfg=aug_cfg,
    )
    preprocess_val = image_transform_v2(pp_cfg, is_train=False)

    return model, preprocess_train, preprocess_val
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def create_model_from_pretrained(
    model_name: str,
    pretrained: Optional[str] = None,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    image_interpolation: Optional[str] = None,
    image_resize_mode: Optional[str] = None,  # only effective for inference
    return_transform: bool = True,
    cache_dir: Optional[str] = None,
    load_weights_only: bool = True,
    **model_kwargs,
):
    """Create a model that is guaranteed to have pretrained weights loaded.

    Delegates to :func:`create_model` with ``require_pretrained=True`` so a
    missing checkpoint raises instead of silently returning random weights.

    Returns:
        ``model`` alone when ``return_transform`` is False, otherwise the
        tuple ``(model, preprocess)`` where ``preprocess`` is the eval-time
        image transform.
    """
    # Fold caller-forced preprocessing overrides into a single dict.
    force_preprocess_cfg = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        cache_dir=cache_dir,
        require_pretrained=True,  # hard-fail if weights cannot be loaded
        load_weights_only=load_weights_only,
        **model_kwargs,
    )

    if not return_transform:
        return model

    # Eval-only transform built from the model's resolved preprocess config.
    preprocess = image_transform_v2(
        PreprocessCfg(**model.visual.preprocess_cfg),
        is_train=False,
    )

    return model, preprocess
|