|
|
|
|
|
|
|
|
""" |
|
|
Compute Embeddings for Major-TOM Sentinel-2 Images |
|
|
|
|
|
This script generates embeddings for Sentinel-2 imagery using various models: |
|
|
- DINOv2: Vision Transformer trained with self-supervised learning |
|
|
- SigLIP: Vision-Language model with sigmoid loss |
|
|
- FarSLIP: Remote sensing fine-tuned CLIP |
|
|
- SatCLIP: Satellite imagery CLIP with location awareness |
|
|
|
|
|
Usage: |
|
|
python compute_embeddings.py --model dinov2 --device cuda:1 |
|
|
python compute_embeddings.py --model siglip --device cuda:5 |
|
|
python compute_embeddings.py --model satclip --device cuda:3 |
|
|
python compute_embeddings.py --model farslip --device cuda:4 |
|
|
|
|
|
Author: Generated by Copilot |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
from PIL import Image |
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
|
|
|
# Absolute path of the directory containing this script; also used as the
# root for the "logs" directory created by setup_logging().
PROJECT_ROOT = Path(__file__).parent.absolute()
# Make project-local packages (e.g. `models`) importable regardless of the
# working directory the script is launched from.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
|
from models.load_config import load_and_process_config |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Input/output locations (host-specific absolute paths).
# ---------------------------------------------------------------------------
# Parquet with per-image metadata for the 249k crop dataset.
METADATA_PATH = Path("/data1/zyj/Core-S2L2A-249k/Core_S2L2A_249k_crop_384x384_metadata.parquet")
# Directory holding the image parquet shards (matched as batch_*.parquet in main()).
IMAGE_PARQUET_DIR = Path("/data1/zyj/Core-S2L2A-249k/images")
# Root under which per-model embedding parquet files are written.
OUTPUT_BASE_DIR = Path("/data1/zyj/EarthEmbeddings/Core-S2L2A-249k")

# Metadata columns dropped from the output (raw-image bookkeeping not needed
# alongside embeddings).
COLUMNS_TO_REMOVE = ['cloud_cover', 'nodata', 'geometry_wkt', 'bands', 'image_shape', 'image_dtype']

# Column renames applied to the output DataFrame.
COLUMNS_RENAME = {'crs': 'utm_crs'}

# Pixel bounding box [x0, y0, x1, y1] recorded for every row.
# Presumably the 384x384 crop window inside the original tile — TODO confirm
# against the dataset generation script.
PIXEL_BBOX = [342, 342, 726, 726]

# Destination parquet file per model.
MODEL_OUTPUT_PATHS = {
    'dinov2': OUTPUT_BASE_DIR / 'dinov2' / 'DINOv2_crop_384x384.parquet',
    'siglip': OUTPUT_BASE_DIR / 'siglip' / 'SigLIP_crop_384x384.parquet',
    'farslip': OUTPUT_BASE_DIR / 'farslip' / 'FarSLIP_crop_384x384.parquet',
    'satclip': OUTPUT_BASE_DIR / 'satclip' / 'SatCLIP_crop_384x384.parquet',
}

# Default batch size per model (overridable via --batch-size).
BATCH_SIZES = {
    'dinov2': 64,
    'siglip': 64,
    'farslip': 64,
    'satclip': 128,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_logging(model_name: str) -> logging.Logger:
    """
    Configure root logging to both a per-model file and stdout.

    Args:
        model_name: Model identifier used to name the log file
            (logs/compute_embeddings_<model>.log under PROJECT_ROOT).

    Returns:
        logging.Logger: Logger for this module.
    """
    log_dir = PROJECT_ROOT / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"compute_embeddings_{model_name}.log"

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            # Explicit UTF-8 so log contents don't depend on the host locale.
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(sys.stdout)
        ],
        # force=True: without it, basicConfig is a silent no-op whenever the
        # root logger already has handlers (e.g. an imported library configured
        # logging first), and the log file would never be created.
        force=True,
    )
    return logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode_image_bytes(row) -> np.ndarray:
    """
    Reconstruct an image array from the raw bytes stored in a parquet row.

    Args:
        row: pandas Series providing 'image_bytes', 'image_shape', 'image_dtype'

    Returns:
        np.ndarray of shape (H, W, 12) with uint16 values
    """
    target_shape = tuple(int(dim) for dim in row['image_shape'])
    flat = np.frombuffer(row['image_bytes'], dtype=np.dtype(row['image_dtype']))
    return flat.reshape(target_shape)
|
|
|
|
|
|
|
|
def extract_rgb_image(img_array: np.ndarray, clip_max: float = 4000.0) -> Image.Image:
    """
    Build a true-color RGB image from a 12-band Sentinel-2 array.

    Sentinel-2 Bands: [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12]
    RGB Mapping: R=B04(idx 3), G=B03(idx 2), B=B02(idx 1)

    Args:
        img_array: numpy array of shape (H, W, 12)
        clip_max: Value to clip reflectance data for visualization

    Returns:
        PIL.Image: RGB image
    """
    # Select R/G/B bands and scale reflectance into [0, 1].
    composite = img_array[:, :, [3, 2, 1]].astype(np.float32) / clip_max
    composite = np.clip(composite, 0, 1)
    # Quantize to 8-bit for PIL.
    as_bytes = (composite * 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(model_name: str, device: str, config: dict):
    """
    Instantiate the requested embedding model on the given device.

    Imports are deferred to the matching branch so only the selected model's
    dependencies are loaded.

    Args:
        model_name: One of 'dinov2', 'siglip', 'farslip', 'satclip'
        device: Device string like 'cuda:0' or 'cpu'
        config: Configuration dictionary from local.yaml

    Returns:
        Model instance

    Raises:
        ValueError: If model_name is not recognized.
    """
    logger = logging.getLogger(__name__)
    cfg = config.get(model_name, {})

    if model_name == 'dinov2':
        from models.dinov2_model import DINOv2Model
        instance = DINOv2Model(
            ckpt_path=cfg.get('ckpt_path', '/data1/zyj/checkpoints/dinov2-large'),
            model_name='facebook/dinov2-large',
            embedding_path=None,
            device=device,
        )
        logger.info(f"DINOv2 model loaded on {device}")
        return instance

    if model_name == 'siglip':
        from models.siglip_model import SigLIPModel
        instance = SigLIPModel(
            ckpt_path=cfg.get('ckpt_path', './checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin'),
            model_name='ViT-SO400M-14-SigLIP-384',
            tokenizer_path=cfg.get('tokenizer_path', './checkpoints/ViT-SO400M-14-SigLIP-384'),
            embedding_path=None,
            device=device,
        )
        # Clear any precomputed embedding state carried by the wrapper.
        instance.df_embed = None
        instance.image_embeddings = None
        logger.info(f"SigLIP model loaded on {device}")
        return instance

    if model_name == 'farslip':
        from models.farslip_model import FarSLIPModel
        instance = FarSLIPModel(
            ckpt_path=cfg.get('ckpt_path', './checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt'),
            model_name='ViT-B-16',
            embedding_path=None,
            device=device,
        )
        logger.info(f"FarSLIP model loaded on {device}")
        return instance

    if model_name == 'satclip':
        from models.satclip_ms_model import SatCLIPMSModel
        instance = SatCLIPMSModel(
            ckpt_path=cfg.get('ckpt_path', './checkpoints/SatCLIP/satclip-vit16-l40.ckpt'),
            embedding_path=None,
            device=device,
        )
        logger.info(f"SatCLIP-MS model loaded on {device}")
        return instance

    raise ValueError(f"Unknown model: {model_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_embedding_single(model, model_name: str, img_array: np.ndarray) -> np.ndarray:
    """
    Compute the embedding for a single image.

    Args:
        model: Model instance
        model_name: Model identifier
        img_array: numpy array of shape (H, W, 12)

    Returns:
        np.ndarray: 1D embedding vector, or None if the model produced nothing
        or the model name is unknown.
    """
    if model_name == 'satclip':
        # SatCLIP consumes the raw multispectral stack directly.
        raw_feature = model.encode_image(img_array, is_multispectral=True)
        return raw_feature.cpu().numpy().flatten() if raw_feature is not None else None

    if model_name in ('dinov2', 'siglip', 'farslip'):
        # RGB-only models get a true-color composite of the Sentinel-2 bands.
        composite = extract_rgb_image(img_array)
        rgb_feature = model.encode_image(composite)
        return rgb_feature.cpu().numpy().flatten() if rgb_feature is not None else None

    return None
|
|
|
|
|
|
|
|
def _flatten_feature(feature) -> np.ndarray:
    """Convert one model output tensor to a flat 1D numpy vector (None passes through)."""
    if feature is None:
        return None
    return feature.cpu().numpy().flatten()


def compute_embedding_batch(model, model_name: str, img_arrays: list) -> list:
    """
    Compute embeddings for a batch of images.

    Prefers the model's native batch API (encode_images) and falls back to
    per-image encoding when it is unavailable or fails. Failures are logged
    instead of silently swallowed, so systematic errors (e.g. a broken
    checkpoint) are visible in the logs rather than surfacing only as an
    all-None output.

    Args:
        model: Model instance
        model_name: Model identifier
        img_arrays: List of numpy arrays of shape (H, W, 12)

    Returns:
        List of 1D embedding vectors (numpy arrays), None for failed items.
        Unknown model names yield a list of None of the same length.
    """
    logger = logging.getLogger(__name__)
    n_images = len(img_arrays)

    # Prepare model-specific inputs once; both paths below share them.
    if model_name in ['dinov2', 'siglip', 'farslip']:
        # RGB-only models consume true-color composites.
        inputs = [extract_rgb_image(arr) for arr in img_arrays]
        multispectral = False
    elif model_name == 'satclip':
        # SatCLIP consumes the raw multispectral stacks.
        inputs = img_arrays
        multispectral = True
    else:
        return [None] * n_images

    # Fast path: the model's native batch API, when available.
    if hasattr(model, 'encode_images'):
        try:
            if multispectral:
                features = model.encode_images(inputs, is_multispectral=True)
            else:
                features = model.encode_images(inputs)
            if features is not None:
                return [_flatten_feature(features[i]) for i in range(len(features))]
        except Exception as e:
            logger.warning(f"Batch encoding failed for {model_name}, falling back to per-image: {e}")

    # Fallback: encode one image at a time; failed items become None.
    results = []
    for item in inputs:
        try:
            if multispectral:
                feature = model.encode_image(item, is_multispectral=True)
            else:
                feature = model.encode_image(item)
            results.append(_flatten_feature(feature))
        except Exception as e:
            logger.warning(f"Per-image encoding failed for {model_name}: {e}")
            results.append(None)
    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_parquet_file(
    file_path: Path,
    model,
    model_name: str,
    batch_size: int = 64
) -> pd.DataFrame:
    """
    Process a single parquet file and generate embeddings using batch processing.

    Rows that fail to decode or embed are dropped from the output; if every
    row fails, returns None.

    Args:
        file_path: Path to input parquet file
        model: Model instance
        model_name: Model identifier
        batch_size: Batch size for processing

    Returns:
        DataFrame with embeddings, or None if no row produced a valid embedding.
    """
    logger = logging.getLogger(__name__)

    df = pd.read_parquet(file_path)
    n_rows = len(df)

    # Index-aligned with df rows; valid_mask marks rows that got an embedding.
    embeddings_list = [None] * n_rows
    valid_mask = [False] * n_rows

    for batch_start in range(0, n_rows, batch_size):
        batch_end = min(batch_start + batch_size, n_rows)
        batch_indices = list(range(batch_start, batch_end))

        # Decode this batch; undecodable rows are skipped (logged, not fatal).
        batch_arrays = []
        batch_valid_indices = []

        for idx in batch_indices:
            try:
                row = df.iloc[idx]
                img_array = decode_image_bytes(row)
                batch_arrays.append(img_array)
                batch_valid_indices.append(idx)
            except Exception as e:
                logger.warning(f"Error decoding row {idx}: {e}")
                continue

        if not batch_arrays:
            continue

        try:
            # Batch encode; individual failures come back as None entries.
            batch_embeddings = compute_embedding_batch(model, model_name, batch_arrays)

            for i, idx in enumerate(batch_valid_indices):
                if batch_embeddings[i] is not None:
                    embeddings_list[idx] = batch_embeddings[i]
                    valid_mask[idx] = True

        except Exception as e:
            logger.warning(f"Error computing batch embeddings: {e}")

            # Whole-batch failure: retry each image individually so one bad
            # image doesn't sink the rest of the batch.
            for i, idx in enumerate(batch_valid_indices):
                try:
                    embedding = compute_embedding_single(model, model_name, batch_arrays[i])
                    if embedding is not None:
                        embeddings_list[idx] = embedding
                        valid_mask[idx] = True
                except Exception as inner_e:
                    logger.warning(f"Error processing row {idx}: {inner_e}")
                    continue

    # Keep only rows that produced an embedding.
    valid_indices = [i for i, v in enumerate(valid_mask) if v]

    if not valid_indices:
        logger.warning(f"No valid embeddings for {file_path.name}")
        return None

    result_df = df.iloc[valid_indices].copy()
    valid_embeddings = [embeddings_list[i] for i in valid_indices]

    # Drop raw-image bookkeeping columns not needed alongside embeddings.
    cols_to_drop = [c for c in COLUMNS_TO_REMOVE if c in result_df.columns]
    if cols_to_drop:
        result_df = result_df.drop(columns=cols_to_drop)

    # Raw image payload is large and no longer needed.
    if 'image_bytes' in result_df.columns:
        result_df = result_df.drop(columns=['image_bytes'])

    if 'geometry' in result_df.columns:
        result_df = result_df.drop(columns=['geometry'])

    # e.g. 'crs' -> 'utm_crs' per COLUMNS_RENAME.
    result_df = result_df.rename(columns=COLUMNS_RENAME)

    # Same crop window for every row.
    result_df['pixel_bbox'] = [PIXEL_BBOX] * len(result_df)

    result_df['embedding'] = valid_embeddings

    return result_df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse args, load the model, embed all input shards,
    and write one merged parquet file. Returns 0 on success, 1 on failure."""
    parser = argparse.ArgumentParser(description='Compute embeddings for Major-TOM images')
    parser.add_argument('--model', type=str, required=True,
                        choices=['dinov2', 'siglip', 'farslip', 'satclip'],
                        help='Model to use for embedding computation')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device to run on (e.g., cuda:0, cuda:1, cpu)')
    parser.add_argument('--batch-size', type=int, default=None,
                        help='Batch size for processing (default: model-specific)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum number of files to process (for testing)')

    args = parser.parse_args()

    logger = setup_logging(args.model)

    logger.info("=" * 80)
    logger.info(f"Computing {args.model.upper()} embeddings")
    logger.info(f"Timestamp: {datetime.now().isoformat()}")
    logger.info(f"Device: {args.device}")
    logger.info("=" * 80)

    # Missing config is non-fatal: load_model falls back to hard-coded paths.
    config = load_and_process_config()
    if config is None:
        logger.warning("No config file found, using default paths")
        config = {}

    # CLI flag overrides the per-model default.
    batch_size = args.batch_size or BATCH_SIZES.get(args.model, 64)
    logger.info(f"Batch size: {batch_size}")

    output_path = MODEL_OUTPUT_PATHS[args.model]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output path: {output_path}")

    logger.info(f"Loading {args.model} model...")
    model = load_model(args.model, args.device, config)

    # Sorted for deterministic processing order; --max-files truncates for tests.
    parquet_files = sorted(IMAGE_PARQUET_DIR.glob("batch_*.parquet"))
    if args.max_files:
        parquet_files = parquet_files[:args.max_files]

    logger.info(f"Found {len(parquet_files)} input files")

    all_results = []
    total_rows = 0

    for file_path in tqdm(parquet_files, desc=f"Processing {args.model}"):
        try:
            result_df = process_parquet_file(file_path, model, args.model, batch_size)

            if result_df is not None:
                all_results.append(result_df)
                total_rows += len(result_df)
                logger.info(f"[{file_path.name}] Processed {len(result_df)} rows")

        # Per-file failures are logged and skipped so one bad shard doesn't
        # abort the whole run.
        except Exception as e:
            logger.error(f"Error processing {file_path.name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    if all_results:
        logger.info("Merging all results...")
        final_df = pd.concat(all_results, ignore_index=True)

        logger.info(f"Final columns: {list(final_df.columns)}")

        # Sanity check: the cleanup done per-file should have removed these.
        removed = [c for c in COLUMNS_TO_REMOVE if c in final_df.columns]
        if removed:
            logger.warning(f"Columns still present that should be removed: {removed}")
        else:
            logger.info("✓ All unwanted columns removed")

        if 'utm_crs' in final_df.columns and 'crs' not in final_df.columns:
            logger.info("✓ Column 'crs' renamed to 'utm_crs'")

        if 'pixel_bbox' in final_df.columns:
            logger.info("✓ Column 'pixel_bbox' added")

        logger.info(f"Saving to {output_path}...")
        final_df.to_parquet(output_path, index=False)

        logger.info(f"=" * 80)
        logger.info(f"Processing complete!")
        logger.info(f"Total rows: {len(final_df):,}")
        logger.info(f"Embedding dimension: {len(final_df['embedding'].iloc[0])}")
        logger.info(f"Output file: {output_path}")
        logger.info(f"=" * 80)

    else:
        logger.error("No data processed!")
        return 1

    return 0
|
|
|
|
|
|
|
|
# Script entry point: propagate main()'s 0/1 result as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
|
|
|