# EarthEmbeddingExplorer / compute_embeddings.py
# Uploaded by ML4RS-Anonymous ("Upload all files"), commit eb1aec4 (verified).
#!/usr/bin/env python3
"""
Compute Embeddings for Major-TOM Sentinel-2 Images
This script generates embeddings for Sentinel-2 imagery using various models:
- DINOv2: Vision Transformer trained with self-supervised learning
- SigLIP: Vision-Language model with sigmoid loss
- FarSLIP: Remote sensing fine-tuned CLIP
- SatCLIP: Satellite imagery CLIP with location awareness
Usage:
python compute_embeddings.py --model dinov2 --device cuda:1
python compute_embeddings.py --model siglip --device cuda:5
python compute_embeddings.py --model satclip --device cuda:3
python compute_embeddings.py --model farslip --device cuda:4
Author: Generated by Copilot
"""
import os
import sys
import argparse
import logging
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm.auto import tqdm
# Prepend this script's directory to sys.path (unless it is already there) so
# the local `models` package can be imported from any working directory.
PROJECT_ROOT = Path(__file__).absolute().parent
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))
from models.load_config import load_and_process_config
# =============================================================================
# Configuration
# =============================================================================
# Machine-specific locations of the inputs and of the output root.
METADATA_PATH = Path("/data1/zyj/Core-S2L2A-249k/Core_S2L2A_249k_crop_384x384_metadata.parquet")
IMAGE_PARQUET_DIR = Path("/data1/zyj/Core-S2L2A-249k/images")
OUTPUT_BASE_DIR = Path("/data1/zyj/EarthEmbeddings/Core-S2L2A-249k")
# Metadata columns dropped from every output row.
COLUMNS_TO_REMOVE = [
    'cloud_cover',
    'nodata',
    'geometry_wkt',
    'bands',
    'image_shape',
    'image_dtype',
]
# Metadata columns renamed in the output (old name -> new name).
COLUMNS_RENAME = dict(crs='utm_crs')
# Pixel bbox [x_min, y_min, x_max, y_max] of the centred 384x384 crop taken
# from the 1068x1068 original: the offset is (1068 - 384) // 2 = 342.
_SOURCE_TILE_PX = 1068
_CROP_PX = 384
_CROP_OFFSET = (_SOURCE_TILE_PX - _CROP_PX) // 2
PIXEL_BBOX = [_CROP_OFFSET, _CROP_OFFSET, _CROP_OFFSET + _CROP_PX, _CROP_OFFSET + _CROP_PX]
# Output parquet per model: OUTPUT_BASE_DIR / <model key> / <file name>.
MODEL_OUTPUT_PATHS = {
    key: OUTPUT_BASE_DIR / key / filename
    for key, filename in (
        ('dinov2', 'DINOv2_crop_384x384.parquet'),
        ('siglip', 'SigLIP_crop_384x384.parquet'),
        ('farslip', 'FarSLIP_crop_384x384.parquet'),
        ('satclip', 'SatCLIP_crop_384x384.parquet'),
    )
}
# Default per-model batch sizes (main() lets --batch-size override them).
BATCH_SIZES = dict(dinov2=64, siglip=64, farslip=64, satclip=128)
# =============================================================================
# Setup Logging
# =============================================================================
def setup_logging(model_name: str):
    """
    Configure root logging to write to a per-model log file and to stdout.

    The log file is ``<PROJECT_ROOT>/logs/compute_embeddings_<model_name>.log``
    and is opened in append mode (``logging.FileHandler`` default), so
    repeated runs accumulate in the same file. As with any
    ``logging.basicConfig`` call, nothing is reconfigured if the root logger
    already has handlers.

    Args:
        model_name: Model identifier, used to name the log file.

    Returns:
        logging.Logger: this module's logger.
    """
    log_dir = PROJECT_ROOT / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"compute_embeddings_{model_name}.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            # Explicit UTF-8: the script logs non-ASCII text (e.g. "✓ ..."),
            # which must not depend on the locale's default encoding.
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    return logging.getLogger(__name__)
# =============================================================================
# Image Preprocessing Functions
# =============================================================================
def decode_image_bytes(row) -> np.ndarray:
    """
    Rebuild an image array from the raw bytes stored in a parquet row.

    Args:
        row: mapping-like record (e.g. a pandas Series) providing
            'image_bytes' (raw buffer), 'image_shape' (sequence of ints) and
            'image_dtype' (anything accepted by ``np.dtype``).

    Returns:
        np.ndarray with shape ``row['image_shape']`` and dtype
        ``row['image_dtype']`` (for this dataset presumably (H, W, 12)
        uint16 — confirm against the parquet metadata). ``np.frombuffer``
        does not copy, so the result shares memory with the buffer and is
        read-only for ``bytes`` input.
    """
    target_shape = tuple(int(dim) for dim in row['image_shape'])
    pixel_dtype = np.dtype(row['image_dtype'])
    return np.frombuffer(row['image_bytes'], dtype=pixel_dtype).reshape(target_shape)
def extract_rgb_image(img_array: np.ndarray, clip_max: float = 4000.0) -> Image.Image:
    """
    Render a true-colour 8-bit RGB image from a multi-band Sentinel-2 array.

    The band axis is assumed to follow the 12-band ordering
    [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12], so red,
    green and blue are B04, B03 and B02 (indices 3, 2, 1).

    Args:
        img_array: array indexed as (row, col, band); only bands 1-3 are read.
        clip_max: value mapped to full intensity (255); larger values
            saturate and negative values clamp to 0.

    Returns:
        PIL.Image: RGB image
    """
    # Band positions for (R, G, B) = (B04, B03, B02).
    red, green, blue = 3, 2, 1
    reflectance = img_array[:, :, [red, green, blue]].astype(np.float32)
    # Scale into [0, 1]; the uint8 cast below truncates rather than rounds.
    unit_interval = np.clip(reflectance / clip_max, 0, 1)
    pixels = (unit_interval * 255).astype(np.uint8)
    return Image.fromarray(pixels)
# =============================================================================
# Model Loading Functions
# =============================================================================
def load_model(model_name: str, device: str, config: dict):
    """
    Instantiate the embedding model selected on the command line.

    Each model class is imported lazily inside its branch, so only the chosen
    model's module is imported. Checkpoint paths come from
    ``config[model_name]`` when present, otherwise from the defaults below.
    Every model is created with ``embedding_path=None`` because this script
    generates embeddings rather than loading precomputed ones.

    Args:
        model_name: One of 'dinov2', 'siglip', 'farslip', 'satclip'
        device: Device string like 'cuda:0' or 'cpu'
        config: Configuration mapping (possibly empty)

    Returns:
        Model instance

    Raises:
        ValueError: if ``model_name`` is not one of the supported models.
    """
    logger = logging.getLogger(__name__)
    if model_name not in ('dinov2', 'siglip', 'farslip', 'satclip'):
        raise ValueError(f"Unknown model: {model_name}")
    # Per-model overrides, e.g. config['siglip']['ckpt_path'].
    model_config = config.get(model_name, {})

    if model_name == 'satclip':
        from models.satclip_ms_model import SatCLIPMSModel
        model = SatCLIPMSModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/SatCLIP/satclip-vit16-l40.ckpt'),
            embedding_path=None,
            device=device,
        )
        logger.info(f"SatCLIP-MS model loaded on {device}")
        return model

    if model_name == 'farslip':
        from models.farslip_model import FarSLIPModel
        model = FarSLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt'),
            model_name='ViT-B-16',
            embedding_path=None,
            device=device,
        )
        logger.info(f"FarSLIP model loaded on {device}")
        return model

    if model_name == 'siglip':
        from models.siglip_model import SigLIPModel
        model = SigLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin'),
            model_name='ViT-SO400M-14-SigLIP-384',
            tokenizer_path=model_config.get('tokenizer_path', './checkpoints/ViT-SO400M-14-SigLIP-384'),
            embedding_path=None,
            device=device,
        )
        # Explicitly clear any precomputed-embedding state on the wrapper.
        model.df_embed = None
        model.image_embeddings = None
        logger.info(f"SigLIP model loaded on {device}")
        return model

    # Only 'dinov2' remains after the membership check above.
    from models.dinov2_model import DINOv2Model
    model = DINOv2Model(
        ckpt_path=model_config.get('ckpt_path', '/data1/zyj/checkpoints/dinov2-large'),
        model_name='facebook/dinov2-large',
        embedding_path=None,
        device=device,
    )
    logger.info(f"DINOv2 model loaded on {device}")
    return model
# =============================================================================
# Embedding Computation Functions
# =============================================================================
def compute_embedding_single(model, model_name: str, img_array: np.ndarray) -> np.ndarray:
    """
    Embed one image and return it as a flat numpy vector.

    RGB models ('dinov2', 'siglip', 'farslip') receive a PIL image built by
    ``extract_rgb_image``; 'satclip' receives the raw multi-band array with
    ``is_multispectral=True``.

    Args:
        model: Model instance exposing ``encode_image``
        model_name: Model identifier
        img_array: decoded image array (row, col, band)

    Returns:
        np.ndarray: 1D embedding vector, or None when the model returns None
        or ``model_name`` is not recognised.
    """
    if model_name == 'satclip':
        feature = model.encode_image(img_array, is_multispectral=True)
    elif model_name in ('dinov2', 'siglip', 'farslip'):
        feature = model.encode_image(extract_rgb_image(img_array))
    else:
        return None
    return None if feature is None else feature.cpu().numpy().flatten()
def compute_embedding_batch(model, model_name: str, img_arrays: list) -> list:
    """
    Compute embeddings for a batch of images.

    Uses ``model.encode_images`` (one call for the whole batch) when the model
    provides it. Falls back to one ``model.encode_image`` call per image if
    that method is missing, raises, returns None, or returns a number of
    features different from the number of inputs (which would otherwise
    misalign embeddings with their rows).

    RGB models ('dinov2', 'siglip', 'farslip') receive PIL images built by
    ``extract_rgb_image``; 'satclip' receives the raw arrays with
    ``is_multispectral=True``.

    Args:
        model: Model instance
        model_name: Model identifier
        img_arrays: List of decoded image arrays (one per image)

    Returns:
        List with exactly one entry per input, in input order: a 1D numpy
        embedding, or None for images that failed. An unknown ``model_name``
        yields a list of None.
    """
    logger = logging.getLogger(__name__)
    n_images = len(img_arrays)

    if model_name in ['dinov2', 'siglip', 'farslip']:
        # These models use RGB input
        inputs = [extract_rgb_image(arr) for arr in img_arrays]
        encode_kwargs = {}
    elif model_name == 'satclip':
        # SatCLIP consumes the multi-spectral arrays directly
        inputs = img_arrays
        encode_kwargs = {'is_multispectral': True}
    else:
        return [None] * n_images

    if hasattr(model, 'encode_images'):
        try:
            features = model.encode_images(inputs, **encode_kwargs)
            if features is not None:
                if len(features) == n_images:
                    return [feature.cpu().numpy().flatten() for feature in features]
                logger.warning(
                    f"{model_name}: encode_images returned {len(features)} features "
                    f"for {n_images} images; falling back to per-image encoding"
                )
        except Exception as e:
            # Best-effort: the per-image path below still runs.
            logger.debug(f"{model_name}: batch encoding failed ({e}); falling back to per-image encoding")

    # Per-image fallback: one failure only loses that image.
    results = []
    for item in inputs:
        try:
            feature = model.encode_image(item, **encode_kwargs)
            results.append(None if feature is None else feature.cpu().numpy().flatten())
        except Exception as e:
            logger.warning(f"{model_name}: single-image encoding failed: {e}")
            results.append(None)
    return results
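# NOTE: legacy row-by-row version of process_parquet_file, superseded by the
# batched implementation below; kept for reference only.
# def process_parquet_file(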
# file_path: Path,
# model,
# model_name: str,
# batch_size: int = 64
# ) -> pd.DataFrame:
# """
# Process a single parquet file and generate embeddings.
# Args:
# file_path: Path to input parquet file
# model: Model instance
# model_name: Model identifier
# batch_size: Batch size for processing
# Returns:
# DataFrame with embeddings
# """
# logger = logging.getLogger(__name__)
# # Load data
# df = pd.read_parquet(file_path)
# embeddings_list = []
# valid_indices = []
# # Process in batches (for future batch optimization)
# for idx, row in df.iterrows():
# try:
# # Decode image
# img_array = decode_image_bytes(row)
# # Compute embedding
# embedding = compute_embedding_single(model, model_name, img_array)
# if embedding is not None:
# embeddings_list.append(embedding)
# valid_indices.append(idx)
# except Exception as e:
# logger.warning(f"Error processing row {idx}: {e}")
# continue
# if not embeddings_list:
# logger.warning(f"No valid embeddings for {file_path.name}")
# return None
# # Build result DataFrame
# result_df = df.loc[valid_indices].copy()
# # Remove unwanted columns
# cols_to_drop = [c for c in COLUMNS_TO_REMOVE if c in result_df.columns]
# if cols_to_drop:
# result_df = result_df.drop(columns=cols_to_drop)
# # Remove image_bytes (large binary data)
# if 'image_bytes' in result_df.columns:
# result_df = result_df.drop(columns=['image_bytes'])
# # Remove geometry column (binary)
# if 'geometry' in result_df.columns:
# result_df = result_df.drop(columns=['geometry'])
# # Rename columns
# result_df = result_df.rename(columns=COLUMNS_RENAME)
# # Add pixel_bbox
# result_df['pixel_bbox'] = [PIXEL_BBOX] * len(result_df)
# # Add embedding
# result_df['embedding'] = embeddings_list
# return result_df
def _embed_row_range(df, start, stop, model, model_name, file_name, logger):
    """Decode and embed rows ``start:stop`` (positional) of ``df``; return {row position: embedding}."""
    positions = []
    arrays = []
    for idx in range(start, stop):
        try:
            arrays.append(decode_image_bytes(df.iloc[idx]))
        except Exception as e:
            logger.warning(f"Error decoding row {idx}: {e}")
            continue
        positions.append(idx)
    if not arrays:
        return {}

    try:
        embeddings = compute_embedding_batch(model, model_name, arrays)
    except Exception as e:
        logger.warning(f"Error computing batch embeddings for {file_name} rows {start}-{stop - 1}: {e}")
        embeddings = None
    if embeddings is not None and len(embeddings) != len(arrays):
        # A result list of the wrong length cannot be aligned with the rows.
        logger.warning(
            f"Batch returned {len(embeddings)} embeddings for {len(arrays)} images "
            f"in {file_name} rows {start}-{stop - 1}; retrying per image"
        )
        embeddings = None

    if embeddings is None:
        # Fall back to single-image processing so one bad image cannot sink the batch.
        embeddings = []
        for idx, arr in zip(positions, arrays):
            try:
                embeddings.append(compute_embedding_single(model, model_name, arr))
            except Exception as inner_e:
                logger.warning(f"Error processing row {idx}: {inner_e}")
                embeddings.append(None)

    return {idx: emb for idx, emb in zip(positions, embeddings) if emb is not None}


def _build_output_frame(df, valid_indices, embeddings):
    """Keep the embedded rows, drop/rename metadata columns, and append pixel_bbox and embedding."""
    result_df = df.iloc[valid_indices].copy()
    # COLUMNS_TO_REMOVE plus the large raw image bytes and the binary geometry column.
    cols_to_drop = [c for c in (*COLUMNS_TO_REMOVE, 'image_bytes', 'geometry') if c in result_df.columns]
    if cols_to_drop:
        result_df = result_df.drop(columns=cols_to_drop)
    result_df = result_df.rename(columns=COLUMNS_RENAME)
    # A fresh list per row, so rows never alias one shared (mutable) bbox object.
    result_df['pixel_bbox'] = [list(PIXEL_BBOX) for _ in range(len(result_df))]
    result_df['embedding'] = embeddings
    return result_df


def process_parquet_file(
    file_path: Path,
    model,
    model_name: str,
    batch_size: int = 64
) -> "pd.DataFrame | None":
    """
    Embed every image in one parquet shard using batch processing.

    Rows are decoded and embedded ``batch_size`` at a time. Rows whose image
    cannot be decoded or embedded are dropped with a warning. In the output,
    COLUMNS_TO_REMOVE, 'image_bytes' and 'geometry' are removed,
    COLUMNS_RENAME is applied, and 'pixel_bbox' and 'embedding' columns are
    appended.

    Args:
        file_path: Path to input parquet file
        model: Model instance
        model_name: Model identifier
        batch_size: Number of rows embedded per call (must be >= 1)

    Returns:
        DataFrame with one row per successfully embedded image (original
        index preserved), or None if no row produced an embedding.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    logger = logging.getLogger(__name__)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    df = pd.read_parquet(file_path)
    n_rows = len(df)
    # Row position -> embedding, for rows that decoded and embedded successfully.
    embeddings_by_row = {}
    for batch_start in range(0, n_rows, batch_size):
        batch_end = min(batch_start + batch_size, n_rows)
        embeddings_by_row.update(
            _embed_row_range(df, batch_start, batch_end, model, model_name, file_path.name, logger)
        )

    if not embeddings_by_row:
        logger.warning(f"No valid embeddings for {file_path.name}")
        return None

    valid_indices = sorted(embeddings_by_row)
    return _build_output_frame(df, valid_indices, [embeddings_by_row[i] for i in valid_indices])
# =============================================================================
# Main Processing Pipeline
# =============================================================================
def _log_output_checks(final_df: pd.DataFrame, logger: logging.Logger) -> None:
    """Log sanity checks that the column removals, renames and additions reached the merged output."""
    logger.info(f"Final columns: {list(final_df.columns)}")
    # Columns that were supposed to be dropped but are still present.
    leftover = [c for c in COLUMNS_TO_REMOVE if c in final_df.columns]
    if leftover:
        logger.warning(f"Columns still present that should be removed: {leftover}")
    else:
        logger.info("✓ All unwanted columns removed")
    if 'utm_crs' in final_df.columns and 'crs' not in final_df.columns:
        logger.info("✓ Column 'crs' renamed to 'utm_crs'")
    if 'pixel_bbox' in final_df.columns:
        logger.info("✓ Column 'pixel_bbox' added")


def main():
    """
    Command-line entry point: embed every input shard with one model and write
    the merged result to that model's entry in MODEL_OUTPUT_PATHS.

    Returns:
        int: exit status, 0 on success and 1 when no rows were embedded.
    """
    parser = argparse.ArgumentParser(description='Compute embeddings for Major-TOM images')
    parser.add_argument('--model', type=str, required=True,
                        choices=['dinov2', 'siglip', 'farslip', 'satclip'],
                        help='Model to use for embedding computation')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device to run on (e.g., cuda:0, cuda:1, cpu)')
    parser.add_argument('--batch-size', type=int, default=None,
                        help='Batch size for processing (default: model-specific)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum number of files to process (for testing)')
    args = parser.parse_args()
    # Reject negative values up front instead of letting them reach the
    # batching loop or the file slice (a negative slice drops files from the end).
    if args.batch_size is not None and args.batch_size < 0:
        parser.error("--batch-size must not be negative (0 selects the model default)")
    if args.max_files is not None and args.max_files < 0:
        parser.error("--max-files must not be negative")

    # Setup logging
    logger = setup_logging(args.model)
    logger.info("=" * 80)
    logger.info(f"Computing {args.model.upper()} embeddings")
    logger.info(f"Timestamp: {datetime.now().isoformat()}")
    logger.info(f"Device: {args.device}")
    logger.info("=" * 80)

    # Load configuration
    config = load_and_process_config()
    if config is None:
        logger.warning("No config file found, using default paths")
        config = {}

    # `or` also maps an explicit --batch-size 0 to the model default.
    batch_size = args.batch_size or BATCH_SIZES.get(args.model, 64)
    logger.info(f"Batch size: {batch_size}")

    output_path = MODEL_OUTPUT_PATHS[args.model]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output path: {output_path}")

    logger.info(f"Loading {args.model} model...")
    model = load_model(args.model, args.device, config)

    parquet_files = sorted(IMAGE_PARQUET_DIR.glob("batch_*.parquet"))
    # `is not None` so that --max-files 0 really means zero files.
    if args.max_files is not None:
        parquet_files = parquet_files[:args.max_files]
    logger.info(f"Found {len(parquet_files)} input files")
    if not parquet_files:
        logger.warning(f"No input files matching batch_*.parquet selected from {IMAGE_PARQUET_DIR}")

    all_results = []
    for file_path in tqdm(parquet_files, desc=f"Processing {args.model}"):
        try:
            result_df = process_parquet_file(file_path, model, args.model, batch_size)
        except Exception as e:
            # logger.exception records the traceback in the log file as well,
            # not only on stderr; processing continues with the next file.
            logger.exception(f"Error processing {file_path.name}: {e}")
            continue
        if result_df is not None:
            all_results.append(result_df)
            logger.info(f"[{file_path.name}] Processed {len(result_df)} rows")

    if not all_results:
        logger.error("No data processed!")
        return 1

    logger.info("Merging all results...")
    final_df = pd.concat(all_results, ignore_index=True)
    _log_output_checks(final_df, logger)

    logger.info(f"Saving to {output_path}...")
    final_df.to_parquet(output_path, index=False)
    logger.info("=" * 80)
    logger.info("Processing complete!")
    logger.info(f"Total rows: {len(final_df):,}")
    logger.info(f"Embedding dimension: {len(final_df['embedding'].iloc[0])}")
    logger.info(f"Output file: {output_path}")
    logger.info("=" * 80)
    return 0


if __name__ == "__main__":
    sys.exit(main())