#!/usr/bin/env python3
"""
Compute Embeddings for Major-TOM Sentinel-2 Images

This script generates embeddings for Sentinel-2 imagery using various models:
- DINOv2: Vision Transformer trained with self-supervised learning
- SigLIP: Vision-Language model with sigmoid loss
- FarSLIP: Remote sensing fine-tuned CLIP
- SatCLIP: Satellite imagery CLIP with location awareness

Usage:
    python compute_embeddings.py --model dinov2 --device cuda:1
    python compute_embeddings.py --model siglip --device cuda:5
    python compute_embeddings.py --model satclip --device cuda:3
    python compute_embeddings.py --model farslip --device cuda:4

Author: Generated by Copilot
"""

import os
import sys
import argparse
import logging
from pathlib import Path
from datetime import datetime
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm.auto import tqdm

# Add project root to path so the local `models` package is importable.
PROJECT_ROOT = Path(__file__).parent.absolute()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from models.load_config import load_and_process_config  # noqa: E402  (needs PROJECT_ROOT on sys.path)

# Module-level logger; handlers are attached by setup_logging().
logger = logging.getLogger(__name__)

# =============================================================================
# Configuration
# =============================================================================

METADATA_PATH = Path("/data1/zyj/Core-S2L2A-249k/Core_S2L2A_249k_crop_384x384_metadata.parquet")
IMAGE_PARQUET_DIR = Path("/data1/zyj/Core-S2L2A-249k/images")
OUTPUT_BASE_DIR = Path("/data1/zyj/EarthEmbeddings/Core-S2L2A-249k")

# Columns to remove from output
COLUMNS_TO_REMOVE = ['cloud_cover', 'nodata', 'geometry_wkt', 'bands', 'image_shape', 'image_dtype']

# Columns to rename
COLUMNS_RENAME = {'crs': 'utm_crs'}

# Pixel bbox for center 384x384 crop from 1068x1068 original
# (1068 - 384) / 2 = 342
PIXEL_BBOX = [342, 342, 726, 726]  # [x_min, y_min, x_max, y_max]

# Model output paths (keys also define the set of supported models)
MODEL_OUTPUT_PATHS = {
    'dinov2': OUTPUT_BASE_DIR / 'dinov2' / 'DINOv2_crop_384x384.parquet',
    'siglip': OUTPUT_BASE_DIR / 'siglip' / 'SigLIP_crop_384x384.parquet',
    'farslip': OUTPUT_BASE_DIR / 'farslip' / 'FarSLIP_crop_384x384.parquet',
    'satclip': OUTPUT_BASE_DIR / 'satclip' / 'SatCLIP_crop_384x384.parquet',
}

# Batch sizes for different models
BATCH_SIZES = {
    'dinov2': 64,
    'siglip': 64,
    'farslip': 64,
    'satclip': 128,
}

# Models that consume an RGB PIL image (built from bands B04/B03/B02);
# 'satclip' instead receives the raw multi-band array.
RGB_MODELS = ('dinov2', 'siglip', 'farslip')


# =============================================================================
# Setup Logging
# =============================================================================

def setup_logging(model_name: str) -> logging.Logger:
    """
    Configure logging to both a per-model log file and stdout.

    The log file is ``PROJECT_ROOT / "logs" / f"compute_embeddings_{model_name}.log"``.

    Args:
        model_name: Model identifier, used to name the log file.

    Returns:
        The module logger.
    """
    log_dir = PROJECT_ROOT / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"compute_embeddings_{model_name}.log"

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            # Explicit UTF-8: messages contain non-ASCII characters ("✓") that
            # would fail to encode under a non-UTF-8 locale default.
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    return logger


# =============================================================================
# Image Preprocessing Functions
# =============================================================================

def decode_image_bytes(row) -> np.ndarray:
    """
    Decode image bytes from a parquet row into a numpy array.

    Args:
        row: pandas Series with 'image_bytes', 'image_shape', 'image_dtype'.

    Returns:
        A read-only np.ndarray view over ``row['image_bytes']`` reshaped to
        ``row['image_shape']``. For this dataset it is expected to be
        (H, W, 12) uint16, but that is not validated here.

    Raises:
        ValueError: if the byte length is inconsistent with shape and dtype.
    """
    shape = tuple(int(dim) for dim in row['image_shape'])
    dtype = np.dtype(row['image_dtype'])
    img_flat = np.frombuffer(row['image_bytes'], dtype=dtype)
    return img_flat.reshape(shape)


def extract_rgb_image(img_array: np.ndarray, clip_max: float = 4000.0) -> Image.Image:
    """
    Extract RGB channels from a 12-band Sentinel-2 array.

    Sentinel-2 Bands: [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12]
    RGB Mapping: R=B04(idx 3), G=B03(idx 2), B=B02(idx 1)

    Args:
        img_array: numpy array of shape (H, W, C) with C >= 4 (12 for S2 L2A).
        clip_max: Reflectance value mapped to full brightness (255); larger
            values are clipped. Must be positive.

    Returns:
        PIL.Image: 8-bit RGB image.

    Raises:
        ValueError: if ``img_array`` is not (H, W, C>=4) or ``clip_max`` <= 0.
    """
    if img_array.ndim != 3 or img_array.shape[2] < 4:
        raise ValueError(f"Expected an (H, W, C>=4) band array, got shape {img_array.shape}")
    if clip_max <= 0:
        raise ValueError(f"clip_max must be positive, got {clip_max}")

    # Select RGB Channels: R=B04(3), G=B03(2), B=B02(1)
    rgb_bands = img_array[:, :, [3, 2, 1]].astype(np.float32)

    # Normalize to [0, 1] and clip bright outliers
    rgb_normalized = np.clip(rgb_bands / clip_max, 0, 1)

    # Convert to 8-bit
    rgb_uint8 = (rgb_normalized * 255).astype(np.uint8)

    return Image.fromarray(rgb_uint8)


# =============================================================================
# Model Loading Functions
# =============================================================================

def load_model(model_name: str, device: str, config: dict):
    """
    Load the specified model.

    Args:
        model_name: One of 'dinov2', 'siglip', 'farslip', 'satclip'
        device: Device string like 'cuda:0' or 'cpu'
        config: Configuration dictionary returned by load_and_process_config()
            (may be empty); per-model 'ckpt_path'/'tokenizer_path' override
            the defaults below.

    Returns:
        Model instance

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    if model_name == 'dinov2':
        from models.dinov2_model import DINOv2Model
        model_config = config.get('dinov2', {})
        model = DINOv2Model(
            ckpt_path=model_config.get('ckpt_path', '/data1/zyj/checkpoints/dinov2-large'),
            model_name='facebook/dinov2-large',
            embedding_path=None,  # We're generating, not loading
            device=device
        )
        logger.info(f"DINOv2 model loaded on {device}")
        return model

    elif model_name == 'siglip':
        from models.siglip_model import SigLIPModel
        model_config = config.get('siglip', {})
        model = SigLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin'),
            model_name='ViT-SO400M-14-SigLIP-384',
            tokenizer_path=model_config.get('tokenizer_path', './checkpoints/ViT-SO400M-14-SigLIP-384'),
            embedding_path=None,
            device=device
        )
        # Disable embedding loading since we set path to None
        model.df_embed = None
        model.image_embeddings = None
        logger.info(f"SigLIP model loaded on {device}")
        return model

    elif model_name == 'farslip':
        from models.farslip_model import FarSLIPModel
        model_config = config.get('farslip', {})
        model = FarSLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt'),
            model_name='ViT-B-16',
            embedding_path=None,
            device=device
        )
        logger.info(f"FarSLIP model loaded on {device}")
        return model

    elif model_name == 'satclip':
        from models.satclip_ms_model import SatCLIPMSModel
        model_config = config.get('satclip', {})
        model = SatCLIPMSModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/SatCLIP/satclip-vit16-l40.ckpt'),
            embedding_path=None,
            device=device
        )
        logger.info(f"SatCLIP-MS model loaded on {device}")
        return model

    else:
        raise ValueError(f"Unknown model: {model_name}")


# =============================================================================
# Embedding Computation Functions
# =============================================================================

def _feature_to_numpy(feature) -> np.ndarray:
    """Convert one model output (torch tensor or array-like) to a flat numpy vector."""
    if isinstance(feature, torch.Tensor):
        # detach() so conversion also works if the model did not run under no_grad.
        return feature.detach().cpu().numpy().flatten()
    return np.asarray(feature).flatten()


def compute_embedding_single(model, model_name: str, img_array: np.ndarray) -> Optional[np.ndarray]:
    """
    Compute embedding for a single image.

    Args:
        model: Model instance
        model_name: Model identifier
        img_array: numpy array of shape (H, W, 12)

    Returns:
        1D embedding vector, or None if the model returned None or
        ``model_name`` is not recognised.
    """
    if model_name in RGB_MODELS:
        # These models use RGB input
        feature = model.encode_image(extract_rgb_image(img_array))
    elif model_name == 'satclip':
        # SatCLIP can use multi-spectral input directly
        feature = model.encode_image(img_array, is_multispectral=True)
    else:
        return None

    return None if feature is None else _feature_to_numpy(feature)


def _encode_with_fallback(model, inputs: list, **encode_kwargs) -> List[Optional[np.ndarray]]:
    """
    Encode ``inputs`` with ``model.encode_images`` if available, else one by one.

    The batched result is only used when it yields exactly one feature per
    input; otherwise (or on any error) every input is re-encoded individually
    with ``model.encode_image`` so results stay aligned with ``inputs``.

    Returns:
        One entry per input: a 1D numpy vector, or None for failed items.
    """
    n_inputs = len(inputs)

    if hasattr(model, 'encode_images'):
        try:
            features = model.encode_images(inputs, **encode_kwargs)
            if features is not None:
                if len(features) == n_inputs:
                    return [_feature_to_numpy(features[i]) for i in range(n_inputs)]
                logger.warning(
                    f"encode_images returned {len(features)} features for {n_inputs} inputs; "
                    f"falling back to single-image encoding"
                )
        except Exception as e:  # best-effort: fall back to per-image encoding
            logger.warning(f"Batch encoding failed ({e}); falling back to single-image encoding")

    results: List[Optional[np.ndarray]] = []
    for item in inputs:
        try:
            feature = model.encode_image(item, **encode_kwargs)
            results.append(None if feature is None else _feature_to_numpy(feature))
        except Exception as e:  # keep going; this item is recorded as failed
            logger.warning(f"Single-image encoding failed: {e}")
            results.append(None)
    return results


def compute_embedding_batch(model, model_name: str, img_arrays: list) -> List[Optional[np.ndarray]]:
    """
    Compute embeddings for a batch of images.

    Uses the model's ``encode_images`` when available and falls back to
    single-image processing otherwise (or when the batch call fails).

    Args:
        model: Model instance
        model_name: Model identifier
        img_arrays: List of numpy arrays of shape (H, W, 12)

    Returns:
        List with one entry per input: a 1D embedding vector (numpy array),
        or None for failed items. All None for an unknown ``model_name``.
    """
    if model_name in RGB_MODELS:
        # These models use RGB input
        rgb_imgs = [extract_rgb_image(arr) for arr in img_arrays]
        return _encode_with_fallback(model, rgb_imgs)

    if model_name == 'satclip':
        # SatCLIP uses multi-spectral input
        return _encode_with_fallback(model, img_arrays, is_multispectral=True)

    return [None] * len(img_arrays)


def _decode_batch(df: pd.DataFrame, start: int, end: int) -> Tuple[List[np.ndarray], List[int]]:
    """Decode rows [start, end) of ``df``; return (arrays, positions) for rows that decoded."""
    arrays: List[np.ndarray] = []
    positions: List[int] = []
    for pos in range(start, end):
        try:
            img_array = decode_image_bytes(df.iloc[pos])
        except Exception as e:
            logger.warning(f"Error decoding row {pos}: {e}")
            continue
        arrays.append(img_array)
        positions.append(pos)
    return arrays, positions


def _embed_batch(model, model_name: str, arrays: List[np.ndarray],
                 positions: List[int]) -> List[Optional[np.ndarray]]:
    """Embed one decoded batch, retrying image-by-image if the batch call fails."""
    try:
        embeddings = compute_embedding_batch(model, model_name, arrays)
        if len(embeddings) != len(arrays):
            raise ValueError(f"got {len(embeddings)} embeddings for {len(arrays)} images")
        return embeddings
    except Exception as e:
        logger.warning(f"Error computing batch embeddings: {e}")

    results: List[Optional[np.ndarray]] = []
    for pos, img_array in zip(positions, arrays):
        try:
            results.append(compute_embedding_single(model, model_name, img_array))
        except Exception as inner_e:
            logger.warning(f"Error processing row {pos}: {inner_e}")
            results.append(None)
    return results


def process_parquet_file(
    file_path: Path,
    model,
    model_name: str,
    batch_size: int = 64
) -> Optional[pd.DataFrame]:
    """
    Process a single parquet file and generate embeddings using batch processing.

    Rows that fail to decode or embed are dropped from the result.

    Args:
        file_path: Path to input parquet file
        model: Model instance
        model_name: Model identifier
        batch_size: Batch size for processing (must be >= 1)

    Returns:
        DataFrame of the surviving rows with image/binary columns removed,
        COLUMNS_RENAME applied, and 'pixel_bbox' and 'embedding' columns
        added; None if no row produced an embedding.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    df = pd.read_parquet(file_path)
    n_rows = len(df)
    embeddings: List[Optional[np.ndarray]] = [None] * n_rows

    for batch_start in range(0, n_rows, batch_size):
        batch_end = min(batch_start + batch_size, n_rows)
        batch_arrays, batch_positions = _decode_batch(df, batch_start, batch_end)
        if not batch_arrays:
            continue
        batch_embeddings = _embed_batch(model, model_name, batch_arrays, batch_positions)
        for pos, embedding in zip(batch_positions, batch_embeddings):
            embeddings[pos] = embedding

    # Keep only rows that produced an embedding
    valid_positions = [i for i, emb in enumerate(embeddings) if emb is not None]
    if not valid_positions:
        logger.warning(f"No valid embeddings for {file_path.name}")
        return None

    result_df = df.iloc[valid_positions].copy()

    # Remove unwanted columns plus the large binary payloads (image_bytes, geometry)
    drop_cols = [c for c in (*COLUMNS_TO_REMOVE, 'image_bytes', 'geometry') if c in result_df.columns]
    result_df = result_df.drop(columns=drop_cols).rename(columns=COLUMNS_RENAME)

    # One independent list per row so rows never alias the module constant
    result_df['pixel_bbox'] = [list(PIXEL_BBOX) for _ in range(len(result_df))]
    result_df['embedding'] = [embeddings[i] for i in valid_positions]

    return result_df


# =============================================================================
# Main Processing Pipeline
# =============================================================================

def main() -> int:
    """Parse CLI args, embed every input parquet file, and write one merged parquet.

    Returns:
        Process exit code: 0 on success, 1 if no data was processed.
    """
    parser = argparse.ArgumentParser(description='Compute embeddings for Major-TOM images')
    parser.add_argument('--model', type=str, required=True,
                        choices=list(MODEL_OUTPUT_PATHS),
                        help='Model to use for embedding computation')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device to run on (e.g., cuda:0, cuda:1, cpu)')
    parser.add_argument('--batch-size', type=int, default=None,
                        help='Batch size for processing (default: model-specific)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum number of files to process (for testing)')
    args = parser.parse_args()

    if args.batch_size is not None and args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    if args.max_files is not None and args.max_files < 1:
        parser.error("--max-files must be a positive integer")

    # Setup logging
    setup_logging(args.model)
    logger.info("=" * 80)
    logger.info(f"Computing {args.model.upper()} embeddings")
    logger.info(f"Timestamp: {datetime.now().isoformat()}")
    logger.info(f"Device: {args.device}")
    logger.info("=" * 80)

    # Load configuration
    config = load_and_process_config()
    if config is None:
        logger.warning("No config file found, using default paths")
        config = {}

    # Determine batch size
    batch_size = args.batch_size if args.batch_size is not None else BATCH_SIZES.get(args.model, 64)
    logger.info(f"Batch size: {batch_size}")

    # Get output path
    output_path = MODEL_OUTPUT_PATHS[args.model]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output path: {output_path}")

    # Load model
    logger.info(f"Loading {args.model} model...")
    model = load_model(args.model, args.device, config)

    # Get input files
    parquet_files = sorted(IMAGE_PARQUET_DIR.glob("batch_*.parquet"))
    if args.max_files is not None:
        parquet_files = parquet_files[:args.max_files]
    logger.info(f"Found {len(parquet_files)} input files")

    # Process files
    all_results = []
    total_rows = 0

    for file_path in tqdm(parquet_files, desc=f"Processing {args.model}"):
        try:
            result_df = process_parquet_file(file_path, model, args.model, batch_size)
        except Exception as e:
            # logger.exception records the traceback in the log file as well
            logger.exception(f"Error processing {file_path.name}: {e}")
            continue

        if result_df is not None:
            all_results.append(result_df)
            total_rows += len(result_df)
            logger.info(f"[{file_path.name}] Processed {len(result_df)} rows")

    if not all_results:
        logger.error("No data processed!")
        return 1

    # Merge and save
    logger.info("Merging all results...")
    final_df = pd.concat(all_results, ignore_index=True)

    # Validate columns
    logger.info(f"Final columns: {list(final_df.columns)}")

    # Check for columns that should have been removed
    leftover = [c for c in COLUMNS_TO_REMOVE if c in final_df.columns]
    if leftover:
        logger.warning(f"Columns still present that should be removed: {leftover}")
    else:
        logger.info("✓ All unwanted columns removed")

    # Check for renamed columns
    if 'utm_crs' in final_df.columns and 'crs' not in final_df.columns:
        logger.info("✓ Column 'crs' renamed to 'utm_crs'")

    # Check for pixel_bbox
    if 'pixel_bbox' in final_df.columns:
        logger.info("✓ Column 'pixel_bbox' added")

    # Save
    logger.info(f"Saving to {output_path}...")
    final_df.to_parquet(output_path, index=False)

    logger.info("=" * 80)
    logger.info("Processing complete!")
    logger.info(f"Total rows: {len(final_df):,}")
    logger.info(f"Embedding dimension: {len(final_df['embedding'].iloc[0])}")
    logger.info(f"Output file: {output_path}")
    logger.info("=" * 80)

    return 0


if __name__ == "__main__":
    sys.exit(main())